Open · wangbaoyuanGUET opened this issue 5 months ago
Hi, dear developers! Here is my test code — could you please check whether I wrote it correctly?
import os

import torch
import numpy as np
from lib import networks
from lib import models
from lib.data.med_transforms import *
from lib.utils import set_seed, dist_setup, get_conf
from monai.losses import DiceCELoss, DiceLoss
from collections import defaultdict, OrderedDict
from monai.metrics import compute_meandice, compute_hausdorff_distance
from functools import partial
from lib.data.med_datasets import *
from lib.utils import SmoothedValue, concat_all_gather, LayerDecayValueAssigner
from monai.inferers import sliding_window_inference
from monai.data import decollate_batch
import nibabel as nib

# Where predicted label volumes are written when ``args.output_dir`` is not set.
# Kept equal to the previously hard-coded path so existing runs behave the same.
DEFAULT_OUTPUT_DIR = '/home/lzb/wby/3D_Project/SelfMedMAEv2.0/Test_Output'


class Test:
    """Evaluate a pretrained segmentation model on the validation loader.

    The model class is looked up by ``args.proj_name`` in ``lib.models``, with the
    encoder/decoder looked up by ``args.enc_arch`` / ``args.dec_arch`` in
    ``lib.networks``. Weights are loaded from ``args.pretrain``. Each case gets
    sliding-window inference and a Dice score for label 1, and its prediction is
    saved as a NIfTI file.
    """

    def __init__(self, args):
        self.args = args
        self.model_name = args.proj_name
        self.scaler = torch.cuda.amp.GradScaler()
        self.metric_funcs = OrderedDict([
            ('Dice', compute_meandice),
            ('HD', partial(compute_hausdorff_distance, percentile=95)),
        ])

    def build_model(self):
        """Create the model, load checkpoint weights, and put it in eval mode on ``args.gpu``."""
        # Read the config from the instance. The previous version used the module-level
        # global ``args``, which only exists when this file is run as a script.
        args = self.args
        print(f"=> creating model {self.model_name}")
        self.loss_fn = DiceCELoss(to_onehot_y=True, softmax=True, squared_pred=True,
                                  smooth_nr=args.smooth_nr, smooth_dr=args.smooth_dr)
        self.post_pred, self.post_label = get_post_transforms(args)
        self.model = getattr(models, self.model_name)(
            encoder=getattr(networks, args.enc_arch),
            decoder=getattr(networks, args.dec_arch),
            args=args,
        )
        print("=> loading checkpoint")
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(args.pretrain, map_location='cpu')
        state_dict = checkpoint['state_dict']
        # strict=False: missing/unexpected keys are reported in ``msg`` rather than raised.
        msg = self.model.load_state_dict(state_dict, strict=False)
        print(f"Loading messages: \n {msg}")
        print(f"=> Finish loading pretrained weights from {args.pretrain}")
        self.model.eval()
        self.model.cuda(args.gpu)

    def build_dataloader(self):
        """Create ``self.val_dataloader`` using the V2 test transforms."""
        print("=> creating test dataloader")
        args = self.args
        test_transform = get_testV2_transforms(args)
        self.val_dataloader = get_val_loader(args, args.batch_size, args.workers, test_transform)

    @torch.no_grad()
    def evaluate(self):
        """Run inference on every validation case, print Dice scores, and save predictions.

        Predictions are written to ``args.output_dir`` if set, else to
        ``DEFAULT_OUTPUT_DIR``. The directory is created if it does not exist.

        Returns:
            The mean of the per-case label-1 Dice scores, or ``nan`` if the loader
            yielded no cases.
        """
        args = self.args
        self.build_dataloader()
        self.build_model()
        model = self.model

        output_dir = getattr(args, 'output_dir', None) or DEFAULT_OUTPUT_DIR
        os.makedirs(output_dir, exist_ok=True)

        dice_list_case = []
        print("=> Start Evaluating")
        # NOTE(review): the non-3D path passes roi_size=None to sliding_window_inference;
        # confirm that is supported before using spatial_dim != 3.
        roi_size = (args.roi_x, args.roi_y, args.roi_z) if args.spatial_dim == 3 else None
        for batch_data in self.val_dataloader:
            image = batch_data['image'].to(args.gpu, non_blocking=True)
            target = batch_data['label'].to(args.gpu, non_blocking=True)
            original_affine = batch_data["label_meta_dict"]["affine"][0].numpy()
            _, _, h, w, d = target.shape
            target_shape = (h, w, d)
            img_name = os.path.basename(batch_data["image_meta_dict"]["filename_or_obj"][0])

            with torch.cuda.amp.autocast():
                val_output = sliding_window_inference(image, roi_size=roi_size, sw_batch_size=4,
                                                      predictor=model, overlap=args.infer_overlap)
            # Per-voxel class index. Only the first case of each batch ([0]) is scored,
            # so this presumably expects batch_size == 1 -- TODO confirm.
            val_output = torch.softmax(val_output, 1).cpu().numpy()
            val_output = np.argmax(val_output, axis=1).astype(np.uint8)[0]
            target = target.cpu().numpy()[0, 0, :, :, :]
            # Bring the prediction to the label's spatial size (nearest neighbour) before scoring.
            val_output = resample_3d(img=val_output, target_size=target_shape)
            print(f'val_output shape is {val_output.shape} | target shape is {target_shape}')

            mean_dice = dice(val_output == 1, target == 1)
            print(f"=>Evaluating on {img_name}, Mean Dice: {mean_dice}")
            dice_list_case.append(mean_dice)

            nib.save(
                nib.Nifti1Image(val_output.astype(np.uint8), original_affine),
                os.path.join(output_dir, img_name),
            )

        # Avoid np.mean([]) (RuntimeWarning) when the loader is empty.
        overall = np.mean(dice_list_case) if dice_list_case else float('nan')
        print("Overall Mean Dice: {}".format(overall))
        return overall


def resample_3d(img, target_size):
    """Resize a 3-D array to ``target_size`` using nearest-neighbour interpolation.

    ``order=0`` keeps the original voxel values (e.g. integer labels) without
    blending neighbouring values. Raises ValueError if ``img`` or ``target_size``
    is not 3-D.
    """
    import scipy.ndimage as ndimage

    imx, imy, imz = img.shape
    tx, ty, tz = target_size
    zoom_ratio = (float(tx) / float(imx), float(ty) / float(imy), float(tz) / float(imz))
    return ndimage.zoom(img, zoom_ratio, order=0, prefilter=False)


def dice(x, y):
    """Return the Dice coefficient ``2*|x*y| / (|x| + |y|)`` of two masks.

    ``y`` is treated as the reference. If ``y`` is empty (sums to 0), 0.0 is
    returned whatever ``x`` contains.
    """
    y_sum = np.sum(y)
    if y_sum == 0:
        return 0.0
    intersect = np.sum(x * y)
    x_sum = np.sum(x)
    return 2 * intersect / (x_sum + y_sum)
def compute_avg_metric(metric, meters, metric_name, batch_size, args):
    """Reduce a 2-D metric matrix to one number and record it in ``meters[metric_name]``.

    First ``np.nanmean`` over axis 0, then an average of the resulting vector.
    For ``args.dataset == 'btcv'`` that second average skips NaN/inf entries
    (``np.ma.masked_invalid``). For any other dataset it is ``np.nanmean``, which
    skips NaN but not inf. The value is recorded with weight ``n=batch_size``.
    """
    assert len(metric.shape) == 2
    use_masked_mean = args.dataset == 'btcv'
    per_column = np.nanmean(metric, axis=0)
    if use_masked_mean:
        cls_avg_metric = np.mean(np.ma.masked_invalid(per_column))
    else:
        cls_avg_metric = np.nanmean(per_column)
    meters[metric_name].update(value=cls_avg_metric, n=batch_size)


if __name__ == '__main__':
    args = get_conf()
    args.test = True
    # Forced to 2 here; the evaluation scores label 1 against background.
    args.num_classes = 2
    test_example = Test(args)
    test_example.evaluate()
Hi, dear developers! Here is my test code — could you please check whether I wrote it correctly?