Open KeyaoZhao opened 9 months ago
You can define your test dataloader and add a line to the main function: `trainer.predict(pl_module, test_dataloader)`. Then set the number of training epochs to 0, or comment out `trainer.fit(...)`. If you are not familiar with PyTorch Lightning, see https://lightning.ai/docs/pytorch/stable/deploy/production_basic.html. If you need any help, please let me know; I may work on it after this holiday.
Can you provide your inference script? Thanks.
I am travelling and will provide the inference script after Christmas.
if name == 'main': parser = ArgumentParser() parser.add_argument("--config", default='configs.BCSS', type=str, help="config file path (default: None)") parser.add_argument('--devices', type=lambda s: [int(item) for item in s.split(',')], default=[0]) parser.add_argument('--project', type=str, default="mFoV") parser.add_argument('--name', type=str, default="test_sam_prompt") parser.add_argument('--seed', type=int, default=42) args = parser.parse_args()
module = __import__(args.config, globals(), locals(), ['cfg'])
cfg = module.cfg
cfg["project"] = args.project
cfg["devices"] = args.devices
cfg["name"] = args.name
cfg["seed"] = args.seed
seed_everything(cfg["seed"])
print(cfg)
# main(cfg)
metrics_calculator = get_metrics(cfg=cfg)
sam_model = get_model(cfg)
ckpt = torch.load(
'model.ckpt', map_location='cpu'
)
updated_state_dict = {k[6:]: v for k, v in ckpt['state_dict'].items() if k[6:] in sam_model.state_dict()}
sam_model.load_state_dict(updated_state_dict)
sam_model.eval()
import cv2 as cv
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset
class ImageMaskDataset(Dataset):
def __init__(self):
dataset = 'BCSS'
mode = 'test'
with open(f'../datasets/{dataset}/{mode}_files.txt', 'r') as f:
self.img_paths = f.read().splitlines()
self.dataset = dataset
self.transform = A.Compose(
[getattr(A, tf_dict.pop('type'))(**tf_dict) for tf_dict in cfg.data.get(mode).transform]
+ [ToTensorV2()], p=1)
import pandas as pd
import numpy as np
df = pd.read_csv('/mnt/Xsky/szy/Former/SAMPath/dataset_cfg/BCSS_cv.csv', header=0)
df = df[df['fold'] < 0]
self.img_paths = np.asarray(df.iloc[:, 0])
def __len__(self):
return len(self.img_paths)
def __getitem__(self, index: int):
assert index <= len(self), 'index range error'
index = index % len(self)
# img_path = '../' + self.img_paths[index]
img_path = f'/mnt/Xsky/szy/Former/datasets/merged_dataset/img/{self.img_paths[index]}'
image = cv.imread(img_path + '.jpg')
image = cv.cvtColor(image, cv.COLOR_BGR2RGB)
mask = cv.imread(img_path.replace('img', 'mask') + '.png', cv.IMREAD_GRAYSCALE)
ret = self.transform(image=image, mask=mask)
image, mask = ret["image"], ret["mask"]
return image, mask.long()
from mmengine.config import Config
cfg = Config.fromfile('../config/BCSS.py')
from torch.utils.data import DataLoader
test_dataset = ImageMaskDataset()
test_loader = DataLoader(
test_dataset,
batch_size=cfg.data.batch_size_per_gpu,
shuffle=False,
num_workers=cfg.data.num_workers,
drop_last=False
)
device = 'cuda:0'
metrics_calculator = metrics_calculator.to(device)
import sys
from torchmetrics import MetricCollection, JaccardIndex, F1Score, ClasswiseWrapper
ignore_index = 0
num_classes = 6
epoch_iterator = tqdm.tqdm(test_loader, file=sys.stdout, desc="Test (X / X Steps)",
dynamic_ncols=True)
epoch = 0
sam_model.to(device)
for data_iter_step, (images, true_masks) in enumerate(epoch_iterator):
epoch_iterator.set_description(
"Epoch=%d: Test (%d / %d Steps) " % (epoch, data_iter_step, len(test_loader)))
images = images.to(device)
true_masks = true_masks.to(device)
ignored_masks = torch.eq(true_masks, 0).long()
pred_masks = sam_model(images)[0]
pred_masks = torch.stack(pred_masks, dim=0)
pred_masks = torch.argmax(pred_masks[:, 1:, ...], dim=1) + 1
pred_masks = pred_masks * (1 - ignored_masks)
metrics_calculator.update(pred_masks, true_masks)
print(metrics_calculator.compute())
I want to try the pretrained weights (https://wandb.ai/jingwezhang/sam_finetune_loss/reports/BCSS_fusion_focal_0125_iou_00625--Vmlldzo2MzMyMTk3?accessToken=667u6cvye77pufxjwu45g8er2pkvcin06sno9wv11sh6nx96r9618k2rn1jt8kva) on TCGA pathology images. Could you please explain how to run the evaluation code and share some sample code?
Try using windygoo's script. If it does not work, please let me know.
Thank you for your response! With windygoo's script and some revisions, I managed to run inference.
Even with windygoo's script, I still cannot run inference. Could you please share your revised script?
Hello! How do I run inference with the trained model? Could you provide the inference code? Thanks a lot.