cvlab-stonybrook / SAMPath

Repository for "SAM-Path: A Segment Anything Model for Semantic Segmentation in Digital Pathology" (MedAGI2023, MICCAI2023 workshop)

How to run inference with the trained model? #1

Open KeyaoZhao opened 9 months ago

KeyaoZhao commented 9 months ago

Hello! I wonder how to run inference with the trained model. Can you provide the inference code? Thanks a lot.

jingweizhang-xyz commented 9 months ago

You can define your test dataloader and add a line in the main function: trainer.predict(pl_module, test_dataloader). Then set the training epochs to 0 or comment out the trainer.fit(...) call. Check here if you are not familiar with PyTorch Lightning: https://lightning.ai/docs/pytorch/stable/deploy/production_basic.html. If you need any help, please let me know. I may work on it after this holiday.
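For anyone who wants a starting point, here is a minimal sketch of that suggestion. It assumes pl_module (the LightningModule) and a test dataset are already built as in this repo's main function; test_dataset is a placeholder name for your own dataset:

from pytorch_lightning import Trainer
from torch.utils.data import DataLoader

# test_dataset is hypothetical; replace with your own Dataset instance.
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

trainer = Trainer(devices=1, max_epochs=0)  # max_epochs=0 makes trainer.fit a no-op
# trainer.fit(pl_module, train_dataloader)  # skipped / commented out for inference
predictions = trainer.predict(pl_module, test_dataloader)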

Stark320 commented 9 months ago

Can you provide your inference script? Thanks

jingweizhang-xyz commented 9 months ago

I am travelling and will provide the inference script after Christmas.

windygoo commented 8 months ago

import torch
import tqdm
from argparse import ArgumentParser
from pytorch_lightning import seed_everything

# get_model and get_metrics are helpers from the SAMPath codebase;
# adjust this import to wherever they are defined in your checkout.
from main import get_model, get_metrics

if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument("--config", default='configs.BCSS', type=str, help="config file path (default: None)")
    parser.add_argument('--devices', type=lambda s: [int(item) for item in s.split(',')], default=[0])
    parser.add_argument('--project', type=str, default="mFoV")
    parser.add_argument('--name', type=str, default="test_sam_prompt")
    parser.add_argument('--seed', type=int, default=42)
    args = parser.parse_args()

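# Dynamically import the config module named by --config (e.g. configs.BCSS) and take its cfg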
module = __import__(args.config, globals(), locals(), ['cfg'])
cfg = module.cfg

cfg["project"] = args.project
cfg["devices"] = args.devices
cfg["name"] = args.name
cfg["seed"] = args.seed

seed_everything(cfg["seed"])
print(cfg)
# main(cfg)

metrics_calculator = get_metrics(cfg=cfg)

sam_model = get_model(cfg)
ckpt = torch.load(
    'model.ckpt', map_location='cpu'
)

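# The checkpoint keys carry a 6-character prefix (presumably 'model.' from the
# LightningModule); strip it so they match sam_model's own state_dict keys.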
updated_state_dict = {k[6:]: v for k, v in ckpt['state_dict'].items() if k[6:] in sam_model.state_dict()}
sam_model.load_state_dict(updated_state_dict)
sam_model.eval()

import cv2 as cv
import albumentations as A

from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset

class ImageMaskDataset(Dataset):
    def __init__(self):
        dataset = 'BCSS'
        mode = 'test'
        with open(f'../datasets/{dataset}/{mode}_files.txt', 'r') as f:
            self.img_paths = f.read().splitlines()

        self.dataset = dataset
        self.transform = A.Compose(
            [getattr(A, tf_dict.pop('type'))(**tf_dict) for tf_dict in cfg.data.get(mode).transform]
            + [ToTensorV2()], p=1)

        import pandas as pd
        import numpy as np

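        # Override the txt file list with the cross-validation CSV;
        # fold < 0 appears to mark the held-out test split.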
        df = pd.read_csv('/mnt/Xsky/szy/Former/SAMPath/dataset_cfg/BCSS_cv.csv', header=0)
        df = df[df['fold'] < 0]
        self.img_paths = np.asarray(df.iloc[:, 0])

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, index: int):
        assert index < len(self), 'index range error'

        index = index % len(self)
        # img_path = '../' + self.img_paths[index]
        img_path = f'/mnt/Xsky/szy/Former/datasets/merged_dataset/img/{self.img_paths[index]}'

        image = cv.imread(img_path + '.jpg')
        image = cv.cvtColor(image, cv.COLOR_BGR2RGB)

        mask = cv.imread(img_path.replace('img', 'mask') + '.png', cv.IMREAD_GRAYSCALE)

        ret = self.transform(image=image, mask=mask)
        image, mask = ret["image"], ret["mask"]

        return image, mask.long()

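# Re-read the config with mmengine so nested fields support attribute access
# (e.g. cfg.data.batch_size_per_gpu below).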
from mmengine.config import Config

cfg = Config.fromfile('../config/BCSS.py')

from torch.utils.data import DataLoader

test_dataset = ImageMaskDataset()
test_loader = DataLoader(
    test_dataset,
    batch_size=cfg.data.batch_size_per_gpu,
    shuffle=False,
    num_workers=cfg.data.num_workers,
    drop_last=False
)

device = 'cuda:0'
metrics_calculator = metrics_calculator.to(device)
import sys

from torchmetrics import MetricCollection, JaccardIndex, F1Score, ClasswiseWrapper

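# Label 0 is treated as background/ignore; the remaining channels are the tissue classes.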
ignore_index = 0
num_classes = 6
epoch_iterator = tqdm.tqdm(test_loader, file=sys.stdout, desc="Test (X / X Steps)",
                           dynamic_ncols=True)
epoch = 0
sam_model.to(device)

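# Note: wrapping the loop below in torch.no_grad() would avoid keeping gradients during inference.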
for data_iter_step, (images, true_masks) in enumerate(epoch_iterator):
    epoch_iterator.set_description(
        "Epoch=%d: Test (%d / %d Steps) " % (epoch, data_iter_step, len(test_loader)))

    images = images.to(device)
    true_masks = true_masks.to(device)

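    # Pixels labeled 0 in the ground truth are excluded from evaluation.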
    ignored_masks = torch.eq(true_masks, 0).long()

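    # Forward pass: take the first output (per-image logit maps), stack into a
    # (B, C, H, W) tensor, then argmax over the non-background channels
    # (offset by 1 to restore the original class ids).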
    pred_masks = sam_model(images)[0]
    pred_masks = torch.stack(pred_masks, dim=0)

    pred_masks = torch.argmax(pred_masks[:, 1:, ...], dim=1) + 1
    pred_masks = pred_masks * (1 - ignored_masks)

    metrics_calculator.update(pred_masks, true_masks)

print(metrics_calculator.compute())
NaokiThread commented 7 months ago

I want to try the pretrained weights (https://wandb.ai/jingwezhang/sam_finetune_loss/reports/BCSS_fusion_focal_0125_iou_00625--Vmlldzo2MzMyMTk3?accessToken=667u6cvye77pufxjwu45g8er2pkvcin06sno9wv11sh6nx96r9618k2rn1jt8kva) on TCGA pathology images. Could you please tell me how to run the evaluation code? A sample script would be appreciated.

jingweizhang-xyz commented 7 months ago

Have a try using windygoo's script. If it does not work, please let me know.

NaokiThread commented 6 months ago

Thank you for your response! With windygoo's script and some revisions, I managed to run the inference.

Hsuan2021 commented 2 months ago

With windygoo's script I still cannot run the inference. Could you please provide your revised script?