jingweizhang-xyz commented on August 16, 2024

You can define your test dataloader and add one line to the main function: trainer.predict(pl_module, test_dataloader). Then set the number of training epochs to 0, or comment out trainer.fit(...). If you are not familiar with PyTorch Lightning, see https://lightning.ai/docs/pytorch/stable/deploy/production_basic.html. If you need any help, please let me know; I may work on it after the holiday.

Stark320 commented on August 16, 2024

Can you provide your inference script? Thanks.

jingweizhang-xyz commented on August 16, 2024

I am travelling and will provide the inference script after Christmas.

windygoo commented on August 16, 2024

```python
# Evaluation on the BCSS test split, adapted from the repository's main.py: the
# __main__ block skips training (main(cfg)) and only evaluates. seed_everything,
# get_model and get_metrics are assumed to come from main.py's existing imports.
import sys
from argparse import ArgumentParser

import albumentations as A
import cv2 as cv
import numpy as np
import pandas as pd
import torch
import tqdm
from albumentations.pytorch import ToTensorV2
from mmengine.config import Config
from torch.utils.data import DataLoader, Dataset


class ImageMaskDataset(Dataset):
    """BCSS test images and masks, listed in a cross-validation CSV."""

    def __init__(self, data_cfg, mode='test'):
        # Build the albumentations pipeline from the config. Copy each dict first,
        # because pop() would otherwise delete 'type' from the config itself.
        tf_dicts = [dict(tf) for tf in data_cfg.get(mode).transform]
        self.transform = A.Compose(
            [getattr(A, tf.pop('type'))(**tf) for tf in tf_dicts] + [ToTensorV2()], p=1)

        # Rows with fold < 0 form the test split; the first column is the file stem.
        # These are windygoo's local paths; adjust them to your own layout.
        df = pd.read_csv('/mnt/Xsky/szy/Former/SAMPath/dataset_cfg/BCSS_cv.csv', header=0)
        df = df[df['fold'] < 0]
        self.img_paths = np.asarray(df.iloc[:, 0])

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, index: int):
        assert 0 <= index < len(self), 'index range error'

        img_path = f'/mnt/Xsky/szy/Former/datasets/merged_dataset/img/{self.img_paths[index]}'

        image = cv.imread(img_path + '.jpg')
        image = cv.cvtColor(image, cv.COLOR_BGR2RGB)
        # Masks live in a parallel 'mask' directory as single-channel PNGs.
        mask = cv.imread(img_path.replace('img', 'mask') + '.png', cv.IMREAD_GRAYSCALE)

        ret = self.transform(image=image, mask=mask)
        return ret["image"], ret["mask"].long()


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument("--config", default='configs.BCSS', type=str, help="config module path (default: configs.BCSS)")
    parser.add_argument('--devices', type=lambda s: [int(item) for item in s.split(',')], default=[0])
    parser.add_argument('--project', type=str, default="mFoV")
    parser.add_argument('--name', type=str, default="test_sam_prompt")
    parser.add_argument('--seed', type=int, default=42)
    args = parser.parse_args()

    # Load the Python config module (e.g. configs/BCSS.py) and apply CLI overrides.
    module = __import__(args.config, globals(), locals(), ['cfg'])
    cfg = module.cfg
    cfg["project"] = args.project
    cfg["devices"] = args.devices
    cfg["name"] = args.name
    cfg["seed"] = args.seed

    seed_everything(cfg["seed"])
    print(cfg)
    # main(cfg)  # training is skipped; only evaluation runs below

    device = 'cuda:0'
    metrics_calculator = get_metrics(cfg=cfg).to(device)

    # Load the Lightning checkpoint into the bare SAM model. k[6:] drops the
    # LightningModule attribute prefix (presumably 'model.'); keys the bare
    # model does not define are skipped.
    sam_model = get_model(cfg)
    ckpt = torch.load('model.ckpt', map_location='cpu')
    model_keys = sam_model.state_dict()
    updated_state_dict = {k[6:]: v for k, v in ckpt['state_dict'].items() if k[6:] in model_keys}
    sam_model.load_state_dict(updated_state_dict)
    sam_model.to(device)
    sam_model.eval()

    # Dataset and loader settings are read with mmengine from ../config/BCSS.py.
    data_cfg = Config.fromfile('../config/BCSS.py').data
    test_loader = DataLoader(
        ImageMaskDataset(data_cfg, mode='test'),
        batch_size=data_cfg.batch_size_per_gpu,
        shuffle=False,
        num_workers=data_cfg.num_workers,
        drop_last=False,
    )

    epoch_iterator = tqdm.tqdm(test_loader, file=sys.stdout, desc="Test (X / X Steps)",
                               dynamic_ncols=True)
    with torch.no_grad():  # inference only: no autograd graph, much less GPU memory
        for step, (images, true_masks) in enumerate(epoch_iterator):
            epoch_iterator.set_description("Test (%d / %d Steps)" % (step, len(test_loader)))

            images = images.to(device)
            true_masks = true_masks.to(device)
            # Pixels labelled 0 are treated as ignore/background.
            ignored_masks = torch.eq(true_masks, 0).long()

            # The model's first output is a list of per-image logit maps.
            pred_masks = torch.stack(sam_model(images)[0], dim=0)
            # Predict only among classes 1..C-1, then shift indices back to label ids.
            pred_masks = torch.argmax(pred_masks[:, 1:, ...], dim=1) + 1
            # Force ignored pixels to 0 so they agree with the ground truth there.
            pred_masks = pred_masks * (1 - ignored_masks)

            metrics_calculator.update(pred_masks, true_masks)

    print(metrics_calculator.compute())
```
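
The two non-obvious steps in windygoo's script, checkpoint key remapping and the ignore-label post-processing, can be isolated into small helpers that are easy to test. This is a sketch: remap_lightning_state_dict and postprocess_logits are hypothetical names, the "model." prefix is an assumption about how the LightningModule stores the SAM model, and label 0 is assumed to be the ignore class, as in the script. Unlike the script's k[6:], the remap helper checks the prefix before stripping it.

```python
import torch


def remap_lightning_state_dict(lightning_sd, model_keys, prefix="model."):
    """Hypothetical helper: strip an (assumed) LightningModule attribute prefix
    and keep only the keys that the bare model defines."""
    remapped = {}
    for key, value in lightning_sd.items():
        if key.startswith(prefix):
            bare = key[len(prefix):]
            if bare in model_keys:
                remapped[bare] = value
    return remapped


def postprocess_logits(logits, true_masks):
    """Hypothetical helper: (B, C, H, W) logits -> (B, H, W) label map.

    Class 0 is assumed to be the ignore label: it is never predicted, and
    pixels whose ground truth is 0 are forced to 0.
    """
    # Argmax over classes 1..C-1 only, then shift back so indices match label ids.
    pred = logits[:, 1:, ...].argmax(dim=1) + 1
    # Zero out pixels whose ground truth is the ignore label.
    return torch.where(true_masks == 0, torch.zeros_like(pred), pred)


if __name__ == "__main__":
    sd = {"model.encoder.w": 1, "model.decoder.b": 2, "criterion.weight": 3}
    print(remap_lightning_state_dict(sd, {"encoder.w", "decoder.b"}))
    # {'encoder.w': 1, 'decoder.b': 2}

    logits = torch.tensor([[[[9.0, 9.0]], [[1.0, 5.0]], [[2.0, 0.0]]]])  # (1, 3, 1, 2)
    true = torch.tensor([[[3, 0]]])  # (1, 1, 2); second pixel is ignored
    print(postprocess_logits(logits, true).tolist())  # [[[2, 0]]]
```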

NaokiThread commented on August 16, 2024

I want to try the pretrained weights (https://wandb.ai/jingwezhang/sam_finetune_loss/reports/BCSS_fusion_focal_0125_iou_00625--Vmlldzo2MzMyMTk3?accessToken=667u6cvye77pufxjwu45g8er2pkvcin06sno9wv11sh6nx96r9618k2rn1jt8kva) on TCGA pathology images. Could you please tell me how to run the evaluation code, ideally with some sample code?

jingweizhang-xyz commented on August 16, 2024

Give windygoo's script a try. If it does not work, please let me know.

NaokiThread commented on August 16, 2024

Thank you for your response! With windygoo's script and some revisions, I managed to run inference.

Hsuan2021 commented on August 16, 2024

I still cannot run inference with windygoo's script. Could you please share your revised script?
