I used the predict file you mentioned in the comments section. My predict.py file is as follows:
import cv2 as cv
import albumentations as A
from argparse import ArgumentParser
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset
from argparse import ArgumentParser
from pytorch_lightning import seed_everything
from main import get_model,get_metrics
import torch,tqdm
from mmengine import Config
class ImageMaskDataset(Dataset):
def __init__(self):
dataset = 'BCSS'
mode = 'test'
# with open(f'../datasets/{dataset}/{mode}_files.txt', 'r') as f:
# self.img_paths = f.read().splitlines()
self.dataset = dataset
self.transform = A.Compose(
[getattr(A, tf_dict.pop('type'))(**tf_dict) for tf_dict in cfg.data.get(mode).transform]
+ [ToTensorV2()], p=1)
import pandas as pd
import numpy as np
df = pd.read_csv('//mnt/project/SAM/SAMPath/SAMPath/dataset_cfg/BCSS_cv.csv', header=0)
df = df[df['fold'] < 0]
self.img_paths = np.asarray(df.iloc[:, 0])
def __len__(self):
return len(self.img_paths)
def __getitem__(self, index: int):
assert index <= len(self), 'index range error'
index = index % len(self)
# img_path = '../' + self.img_paths[index]
img_path = f'/mnt/dataset/BCSS/merged_dataset/img/{self.img_paths[index]}'
image = cv.imread(img_path + '.jpg')
image = cv.cvtColor(image, cv.COLOR_BGR2RGB)
mask = cv.imread(img_path.replace('img', 'mask') + '.png', cv.IMREAD_GRAYSCALE)
ret = self.transform(image=image, mask=mask)
image, mask = ret["image"], ret["mask"]
return image, mask.long()
if __name__ == '__main__':
parser = ArgumentParser()
parser.add_argument("--config", default='configs.BCSS', type=str, help="config file path (default: None)")
parser.add_argument('--devices', type=lambda s: [int(item) for item in s.split(',')], default=[0])
parser.add_argument('--project', type=str, default="mFoV")
parser.add_argument('--name', type=str, default="test_sam_prompt")
parser.add_argument('--seed', type=int, default=42)
args = parser.parse_args()
module = __import__(args.config, globals(), locals(), ['cfg'])
cfg = module.cfg
cfg["project"] = args.project
cfg["devices"] = args.devices
cfg["name"] = args.name
cfg["seed"] = args.seed
seed_everything(cfg["seed"])
print(cfg)
# main(cfg)
metrics_calculator = get_metrics(cfg=cfg)
sam_model = get_model(cfg)
ckpt = torch.load(
'/mnt/project/SAM/SAMPath/SAMPath/checkpoints/model.ckpt', map_location='cuda:0'
)
updated_state_dict = {k[6:]: v for k, v in ckpt['state_dict'].items() if k[6:] in sam_model.state_dict()}
sam_model.load_state_dict(updated_state_dict)
sam_model.eval()
cfg = Config.fromfile('/mnt/project/SAM/SAMPath/SAMPath/configs/BCSS.py')
from torch.utils.data import DataLoader
test_dataset = ImageMaskDataset()
# test_loader = DataLoader(
# test_dataset,
# batch_size=cfg.batch_size,
# shuffle=False,
# num_workers=cfg.num_workers,
# drop_last=False
# )
test_loader = DataLoader(
test_dataset,
batch_size=cfg.data.batch_size_per_gpu,
shuffle=False,
num_workers=cfg.data.num_workers,
drop_last=False
)
device = 'cuda:0'
metrics_calculator = metrics_calculator.to(device)
import sys
from torchmetrics import MetricCollection, JaccardIndex, F1Score, ClasswiseWrapper
ignore_index = 0
num_classes = 6
epoch_iterator = tqdm.tqdm(test_loader, file=sys.stdout, desc="Test (X / X Steps)",
dynamic_ncols=True)
epoch = 0
sam_model.to(device)
for data_iter_step, (images, true_masks) in enumerate(epoch_iterator):
epoch_iterator.set_description(
"Epoch=%d: Test (%d / %d Steps) " % (epoch, data_iter_step, len(test_loader)))
images = images.to(device)
true_masks = true_masks.to(device)
ignored_masks = torch.eq(true_masks, 0).long()
pred_masks = sam_model(images)[0]
pred_masks = torch.stack(pred_masks, dim=0)
pred_masks = torch.argmax(pred_masks[:, 1:, ...], dim=1) + 1
pred_masks = pred_masks * (1 - ignored_masks)
metrics_calculator.update(pred_masks, true_masks)
print(metrics_calculator.compute())
However, an error occurred. After my summary, I believe that the error is caused by the code in that line:cfg = Config.fromfile('/mnt/project/SAM/SAMPath/SAMPath/configs/BCSS.py')
Similarly, my BCSS.py configuration file is as follows:
Traceback (most recent call last):
File "/mnt/project/SAM/SAMPath/SAMPath/predict.py", line 87, in <module>
cfg = Config.fromfile('/mnt/project/SAM/SAMPath/SAMPath/configs/BCSS.py')
File "/home/pc2080ti/anaconda3/envs/SAMpath/lib/python3.8/site-packages/mmengine/config/config.py", line 492, in fromfile
raise e
File "/home/pc2080ti/anaconda3/envs/SAMpath/lib/python3.8/site-packages/mmengine/config/config.py", line 490, in fromfile
cfg_dict, imported_names = Config._parse_lazy_import(filename)
File "/home/pc2080ti/anaconda3/envs/SAMpath/lib/python3.8/site-packages/mmengine/config/config.py", line 1105, in _parse_lazy_import
exec(
File "/mnt/project/SAM/SAMPath/SAMPath/configs/BCSS.py", line 54, in <module>
cfg = Box(config)
File "/home/pc2080ti/anaconda3/envs/SAMpath/lib/python3.8/site-packages/mmengine/config/lazy.py", line 103, in __call__
raise RuntimeError()
RuntimeError
I feel that the BCSS.py configuration file does not match the one used in predict.py?
I used the predict file you mentioned in the comments section. My predict.py file is as follows:
However, an error occurred. After my summary, I believe that the error is caused by the code in that line: cfg = Config.fromfile('/mnt/project/SAM/SAMPath/SAMPath/configs/BCSS.py') Similarly, my BCSS.py configuration file is as follows:
The error situation is as follows:
I feel that the BCSS.py configuration file does not match the one used in predict.py?