cvlab-stonybrook / SelfMedMAE

Code for the ISBI 2023 paper "Self Pre-training with Masked Autoencoders for Medical Image Classification and Segmentation"
Apache License 2.0
112 stars 12 forks source link

An issue when testing on my own dataset #13

Open wangbaoyuanGUET opened 5 months ago

wangbaoyuanGUET commented 5 months ago

Hi, dear developers! Here is my test code. Could you please tell me whether I wrote it correctly?

import torch
import numpy as np
from lib import networks
from lib import models
from lib.data.med_transforms import *
from lib.utils import set_seed, dist_setup, get_conf
from monai.losses import DiceCELoss, DiceLoss
from collections import defaultdict, OrderedDict
from monai.metrics import compute_meandice, compute_hausdorff_distance
from functools import partial
from lib.data.med_datasets import *
from lib.utils import SmoothedValue, concat_all_gather, LayerDecayValueAssigner
from monai.inferers import sliding_window_inference
from monai.data import decollate_batch
import nibabel as nib 

class Test:
    """Evaluate a trained segmentation checkpoint on the validation loader.

    For every case the model runs sliding-window inference, the predicted
    label map is resampled to the label's spatial shape, a Dice score for
    class 1 is printed, and the prediction is written as NIfTI using the
    label's affine.
    """

    def __init__(self, args):
        self.args = args
        self.model_name = args.proj_name
        # Not used anywhere in this file; kept so external code that may read
        # it keeps working.
        self.scaler = torch.cuda.amp.GradScaler()
        self.metric_funcs = OrderedDict([('Dice', compute_meandice), ('HD', partial(compute_hausdorff_distance, percentile=95))])

    def build_model(self):
        """Create the model/loss, load ``args.pretrain`` and move the model to ``args.gpu`` in eval mode."""
        # Use this instance's config, not a module-level global ``args``, so
        # the class also works when imported from another module.
        args = self.args
        print(f"=> creating model {self.model_name}")

        self.loss_fn = DiceCELoss(to_onehot_y=True,
                                  softmax=True,
                                  squared_pred=True,
                                  smooth_nr=args.smooth_nr,
                                  smooth_dr=args.smooth_dr)
        self.post_pred, self.post_label = get_post_transforms(args)
        self.model = getattr(models, self.model_name)(encoder=getattr(networks, args.enc_arch),
                                                      decoder=getattr(networks, args.dec_arch),
                                                      args=args)
        print("=> loading checkpoint")
        checkpoint = torch.load(args.pretrain, map_location='cpu')
        state_dict = checkpoint['state_dict']
        # strict=False: mismatched keys are reported in ``msg`` instead of
        # raising -- check the printed message when results look wrong.
        msg = self.model.load_state_dict(state_dict, strict=False)
        print(f"Loading messages: \n {msg}")
        print(f"=> Finish loading pretrained weights from {args.pretrain}")
        self.model.eval()
        self.model.cuda(args.gpu)

    def build_dataloader(self):
        """Create ``self.val_dataloader`` using the V2 test transforms."""
        print("=> creating test dataloader")
        args = self.args
        test_transform = get_testV2_transforms(args)

        self.val_dataloader = get_val_loader(args, args.batch_size, args.workers, test_transform)

    @torch.no_grad()
    def evaluate(self):
        """Run inference over the validation loader and report class-1 Dice.

        Predictions are saved to ``args.test_output_dir`` when that attribute
        is set (and truthy), otherwise to the original hard-coded directory;
        the directory is created if it does not exist.

        NOTE(review): only element ``[0]`` of each batch is evaluated and
        saved, so the loader is expected to use batch size 1.

        Returns:
            The mean Dice over all cases, or ``None`` if the loader was empty.
        """
        import os

        args = self.args
        self.build_dataloader()
        self.build_model()
        model = self.model
        output_dir = (getattr(args, 'test_output_dir', None)
                      or '/home/lzb/wby/3D_Project/SelfMedMAEv2.0/Test_Output')
        os.makedirs(output_dir, exist_ok=True)

        dice_list_case = []
        print("=> Start Evaluating")
        roi_size = (args.roi_x, args.roi_y, args.roi_z) if args.spatial_dim == 3 else None

        for batch_data in self.val_dataloader:
            image = batch_data['image'].to(args.gpu, non_blocking=True)
            target = batch_data['label'].to(args.gpu, non_blocking=True)
            original_affine = batch_data["label_meta_dict"]["affine"][0].numpy()
            _, _, h, w, d = target.shape
            target_shape = (h, w, d)
            img_name = os.path.basename(batch_data["image_meta_dict"]["filename_or_obj"][0])

            with torch.cuda.amp.autocast():
                logits = sliding_window_inference(image, roi_size=roi_size, sw_batch_size=4,
                                                  predictor=model, overlap=args.infer_overlap)
            # Softmax is monotonic, so argmax over the logits selects the same
            # class (up to floating-point ties) without copying the full
            # per-class probability volume to the host.
            val_output = torch.argmax(logits.float(), dim=1)[0].cpu().numpy().astype(np.uint8)
            target = target.cpu().numpy()[0, 0, :, :, :]
            val_output = resample_3d(img=val_output, target_size=target_shape)
            print(f'val_output shape is {val_output.shape} | target shape is {target_shape}')
            mean_dice = dice(val_output == 1, target == 1)
            print(f"=>Evaluating on {img_name}, Mean Dice: {mean_dice}")
            dice_list_case.append(mean_dice)
            nib.save(
                nib.Nifti1Image(val_output.astype(np.uint8), original_affine),
                os.path.join(output_dir, img_name),
            )

        if not dice_list_case:
            # np.mean([]) would print nan with a RuntimeWarning.
            print("Overall Mean Dice: no cases evaluated")
            return None
        overall = float(np.mean(dice_list_case))
        print("Overall Mean Dice: {}".format(overall))
        return overall

def resample_3d(img, target_size):
    """Resize a 3-D volume to ``target_size`` with nearest-neighbour zoom.

    ``order=0`` means every output voxel copies an input value, so integer
    label maps keep their labels; ``prefilter=False`` skips spline
    prefiltering.

    Args:
        img: array with exactly three dimensions.
        target_size: sequence of exactly three output extents.

    Returns:
        The array produced by ``scipy.ndimage.zoom`` with per-axis factors
        ``target / source``.
    """
    src_x, src_y, src_z = img.shape
    dst_x, dst_y, dst_z = target_size
    factors = (
        float(dst_x) / float(src_x),
        float(dst_y) / float(src_y),
        float(dst_z) / float(src_z),
    )
    from scipy import ndimage
    return ndimage.zoom(img, factors, order=0, prefilter=False)

def dice(x, y):
    """Return the Dice coefficient ``2 * sum(x * y) / (sum(x) + sum(y))``.

    Intended for binary masks (e.g. ``pred == c`` and ``label == c``), where
    ``x * y`` is the intersection. When ``y`` has no positive elements the
    result is 0.0, even if ``x`` is empty too.
    """
    overlap = np.sum(x * y)
    gt_total = np.sum(y)
    if gt_total == 0:
        return 0.0
    pred_total = np.sum(x)
    denominator = pred_total + gt_total
    return 2 * overlap / denominator

def compute_avg_metric(metric, meters, metric_name, batch_size, args):
    """Average a 2-D metric array and push it into ``meters[metric_name]``.

    The array is first averaged over axis 0 ignoring NaNs, then the
    resulting per-column values are averaged. For ``args.dataset == 'btcv'``
    the second average uses a masked array that also ignores ``inf`` values;
    otherwise ``np.nanmean`` is used.

    Args:
        metric: 2-D array (presumably batch x class -- confirm with callers).
        meters: mapping whose values expose ``update(value=..., n=...)``.
        metric_name: key under which the average is recorded.
        batch_size: weight passed as ``n`` to ``update``.
        args: config object providing ``dataset``.
    """
    assert len(metric.shape) == 2
    use_masked_mean = args.dataset == 'btcv'
    per_column = np.nanmean(metric, axis=0)
    if use_masked_mean:
        average = np.mean(np.ma.masked_invalid(per_column))
    else:
        average = np.nanmean(per_column)
    meters[metric_name].update(value=average, n=batch_size)

if __name__ == '__main__':
    # Kept as a module-level name: other code in this file may read the
    # global ``args``.
    args = get_conf()
    # Force test mode and a two-class (background + foreground) setup.
    args.test = True
    args.num_classes = 2
    Test(args).evaluate()