cvlab-stonybrook / SAMPath

Repository for "SAM-Path: A Segment Anything Model for Semantic Segmentation in Digital Pathology" (MedAGI2023, MICCAI2023 workshop)

Does the BCSS.py configuration file match the predict.py ? #10

Open NanCheng2001 opened 3 months ago

NanCheng2001 commented 3 months ago

I used the predict script you mentioned in the comments section. My predict.py is as follows:

import sys

import cv2 as cv
import numpy as np
import pandas as pd
import torch
import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2
from argparse import ArgumentParser
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning import seed_everything
from mmengine import Config

from main import get_model, get_metrics

class ImageMaskDataset(Dataset):
    def __init__(self):
        dataset = 'BCSS'
        mode = 'test'
        # with open(f'../datasets/{dataset}/{mode}_files.txt', 'r') as f:
        #     self.img_paths = f.read().splitlines()

        self.dataset = dataset
        # NOTE: this reads the module-level `cfg`, which by the time the dataset
        # is constructed has been re-assigned to the mmengine Config (see below).
        self.transform = A.Compose(
            [getattr(A, tf_dict.pop('type'))(**tf_dict) for tf_dict in cfg.data.get(mode).transform]
            + [ToTensorV2()], p=1)

        # Rows with fold < 0 in BCSS_cv.csv are treated as the test split.
        df = pd.read_csv('/mnt/project/SAM/SAMPath/SAMPath/dataset_cfg/BCSS_cv.csv', header=0)
        df = df[df['fold'] < 0]
        self.img_paths = np.asarray(df.iloc[:, 0])

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, index: int):
        assert index < len(self), 'index range error'

        index = index % len(self)
        # img_path = '../' + self.img_paths[index]
        img_path = f'/mnt/dataset/BCSS/merged_dataset/img/{self.img_paths[index]}'

        image = cv.imread(img_path + '.jpg')
        image = cv.cvtColor(image, cv.COLOR_BGR2RGB)

        # The mask lives next to the image, with 'img' replaced by 'mask' in the path.
        mask = cv.imread(img_path.replace('img', 'mask') + '.png', cv.IMREAD_GRAYSCALE)

        ret = self.transform(image=image, mask=mask)
        image, mask = ret["image"], ret["mask"]

        return image, mask.long()

if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument("--config", default='configs.BCSS', type=str, help="config file path (default: None)")
    parser.add_argument('--devices', type=lambda s: [int(item) for item in s.split(',')], default=[0])
    parser.add_argument('--project', type=str, default="mFoV")
    parser.add_argument('--name', type=str, default="test_sam_prompt")
    parser.add_argument('--seed', type=int, default=42)
    args = parser.parse_args()

    # Import e.g. configs.BCSS as a module and grab its Box-based cfg.
    module = __import__(args.config, globals(), locals(), ['cfg'])
    cfg = module.cfg

    cfg["project"] = args.project
    cfg["devices"] = args.devices
    cfg["name"] = args.name
    cfg["seed"] = args.seed

    seed_everything(cfg["seed"])
    print(cfg)
    # main(cfg)

    metrics_calculator = get_metrics(cfg=cfg)

    sam_model = get_model(cfg)
    ckpt = torch.load(
        '/mnt/project/SAM/SAMPath/SAMPath/checkpoints/model.ckpt', map_location='cuda:0'
    )

    # Strip the "model." prefix that the Lightning checkpoint adds to each key.
    updated_state_dict = {k[6:]: v for k, v in ckpt['state_dict'].items() if k[6:] in sam_model.state_dict()}
    sam_model.load_state_dict(updated_state_dict)
    sam_model.eval()
    # Re-load the config with mmengine so that cfg.data.* exists -- this is the
    # line that fails (see the traceback below).
    cfg = Config.fromfile('/mnt/project/SAM/SAMPath/SAMPath/configs/BCSS.py')

    test_dataset = ImageMaskDataset()

    # test_loader = DataLoader(
    #     test_dataset,
    #     batch_size=cfg.batch_size,
    #     shuffle=False,
    #     num_workers=cfg.num_workers,
    #     drop_last=False
    # )

    test_loader = DataLoader(
        test_dataset,
        batch_size=cfg.data.batch_size_per_gpu,
        shuffle=False,
        num_workers=cfg.data.num_workers,
        drop_last=False
    )

    device = 'cuda:0'
    metrics_calculator = metrics_calculator.to(device)
    epoch_iterator = tqdm.tqdm(test_loader, file=sys.stdout, desc="Test (X / X Steps)",
                               dynamic_ncols=True)
    epoch = 0
    sam_model.to(device)

    for data_iter_step, (images, true_masks) in enumerate(epoch_iterator):
        epoch_iterator.set_description(
            "Epoch=%d: Test (%d / %d Steps) " % (epoch, data_iter_step, len(test_loader)))

        images = images.to(device)
        true_masks = true_masks.to(device)

        # Pixels labeled 0 belong to the ignored/background class.
        ignored_masks = torch.eq(true_masks, 0).long()

        pred_masks = sam_model(images)[0]
        pred_masks = torch.stack(pred_masks, dim=0)

        # Argmax over the foreground channels only (channel 0 is background),
        # then zero out the ignored pixels.
        pred_masks = torch.argmax(pred_masks[:, 1:, ...], dim=1) + 1
        pred_masks = pred_masks * (1 - ignored_masks)

        metrics_calculator.update(pred_masks, true_masks)

    print(metrics_calculator.compute())

However, an error occurred. After looking into it, I believe the error is caused by this line:

cfg = Config.fromfile('/mnt/project/SAM/SAMPath/SAMPath/configs/BCSS.py')

My BCSS.py configuration file is as follows:

from box import Box

config = {
    "batch_size": 6,
    "accumulate_grad_batches": 2,
    "num_workers": 4,
    "out_dir": "/mnt/project/SAM/SAMPath/SAMPath/output",
    "opt": {
        "num_epochs": 32,
        "learning_rate": 1e-4,
        "weight_decay": 1e-2, #1e-2,
        "precision": 32, # "16-mixed"
        "steps":  [72 * 25, 72 * 29],
        "warmup_steps": 72,
    },
    "model": {
        "type": 'vit_b',
        "checkpoint": "/mnt/project/SAM/SAMPath/SAMPath/checkpoints/sam_vit_b_01ec64.pth",
        "freeze": {
            "image_encoder": True,
            "prompt_encoder": True,
            "mask_decoder": False,
        },
        "prompt_dim": 256,
        "prompt_decoder": False,
        "dense_prompt_decoder": False,

        "extra_encoder": 'hipt',
        "extra_type": "fusion",
        "extra_checkpoint":  "/mnt/project/SAM/SAMPath/SAMPath/checkpoints/vit256_small_dino.pth",
    },
    "loss": {
        "focal_cof": 0.25,
        "dice_cof": 0.75,
        "ce_cof": 0.0,
        "iou_cof": 0.0625,
    },
    "dataset": {
        "dataset_root": "/mnt/dataset/BCSS/merged_dataset",
        "dataset_csv_path": "/mnt/project/SAM/SAMPath/SAMPath/dataset_cfg/BCSS_cv.csv",
        "val_fold_id": 0,
        "num_classes": 6,

        "ignored_classes": (0),
        "ignored_classes_metric": None, # if we do not count background, set to 1 (bg class)
        "image_hw": (1024, 1024), # default is 1024, 1024

        "feature_input": False, # or "True" for *.pt features
        "dataset_mean": (0.485, 0.456, 0.406),
        "dataset_std": (0.229, 0.224, 0.225),
    }
}

cfg = Box(config)
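For what it's worth, a quick check (my own snippet, not from the repo) shows that this Box config does not expose the cfg.data.* attributes that predict.py reads:

from box import Box

cfg = Box(config)
print('batch_size' in cfg)   # True  -- the flat key the training script uses
print('dataset' in cfg)      # True  -- dataset settings live under "dataset"
print('data' in cfg)         # False -- but predict.py reads cfg.data.batch_size_per_gpu,
                             #          cfg.data.num_workers and cfg.data.get('test').transform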

The error from running predict.py is as follows:

Traceback (most recent call last):
  File "/mnt/project/SAM/SAMPath/SAMPath/predict.py", line 87, in <module>
    cfg = Config.fromfile('/mnt/project/SAM/SAMPath/SAMPath/configs/BCSS.py')
  File "/home/pc2080ti/anaconda3/envs/SAMpath/lib/python3.8/site-packages/mmengine/config/config.py", line 492, in fromfile
    raise e
  File "/home/pc2080ti/anaconda3/envs/SAMpath/lib/python3.8/site-packages/mmengine/config/config.py", line 490, in fromfile
    cfg_dict, imported_names = Config._parse_lazy_import(filename)
  File "/home/pc2080ti/anaconda3/envs/SAMpath/lib/python3.8/site-packages/mmengine/config/config.py", line 1105, in _parse_lazy_import
    exec(
  File "/mnt/project/SAM/SAMPath/SAMPath/configs/BCSS.py", line 54, in <module>
    cfg = Box(config)
  File "/home/pc2080ti/anaconda3/envs/SAMpath/lib/python3.8/site-packages/mmengine/config/lazy.py", line 103, in __call__
    raise RuntimeError()
RuntimeError
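
If I read the traceback correctly, mmengine parses .py configs lazily: any imported name (here Box) becomes a LazyObject placeholder, and calling one at module level raises the bare RuntimeError above. A minimal reproduction (my own sketch, assuming any config module that calls an imported name at the top level):

# minimal_cfg.py -- a config that *calls* an imported name at module level
from box import Box
cfg = Box({'a': 1})

# loader.py
from mmengine import Config
cfg = Config.fromfile('minimal_cfg.py')  # RuntimeError from LazyObject.__call__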

So it seems that a Box-based config like this simply cannot be loaded with Config.fromfile, i.e. the BCSS.py configuration file does not match what predict.py expects?
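In case it helps, one workaround I am considering (untested; the A.Resize/A.Normalize pipeline below is my own guess at the test-time transform, not taken from the repo) is to skip mmengine entirely and build the transform and DataLoader from the Box config's own keys:

import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader

from configs.BCSS import cfg  # the Box config, imported as a plain Python module

h, w = cfg.dataset.image_hw
test_transform = A.Compose([
    A.Resize(height=h, width=w),  # assumed test-time resize; adjust to match training
    A.Normalize(mean=cfg.dataset.dataset_mean, std=cfg.dataset.dataset_std),
    ToTensorV2(),
], p=1)

# assumes ImageMaskDataset is changed to take test_transform directly
test_dataset = ImageMaskDataset()
test_loader = DataLoader(
    test_dataset,
    batch_size=cfg.batch_size,   # flat keys that this Box config actually has
    shuffle=False,
    num_workers=cfg.num_workers,
    drop_last=False,
)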

simzhangbest commented 2 weeks ago

Does the script predict.py work OK?