MedicineToken / MedSegDiff

Medical Image Segmentation with Diffusion Model

difference between sample, x_noisy, org, cal, cal_out #64

Closed alqurri77 closed 1 year ago

alqurri77 commented 1 year ago

Hi,

I need to calculate the model accuracy (for example, Dice). For that I need the model prediction and the ground truth. Which of these is the model prediction: sample, x_noisy, org, cal, cal_out? And what does each one mean?

sample, x_noisy, org, cal, cal_out = sample_fn(
                model,
                (args.batch_size, 3, args.image_size, args.image_size), img,
                step = args.diffusion_steps,
                clip_denoised=args.clip_denoised,
                model_kwargs=model_kwargs,
            )
saisusmitha commented 1 year ago

As far as I know, the model prediction mask is "sample" in the call above. Can you share code to extract the corresponding ground truth mask?
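Once you have a prediction and a ground truth mask, Dice itself is easy to compute. A minimal sketch, assuming "sample"'s last channel is the predicted mask and "m" from the dataloader is the ground truth (both assumptions worth verifying):

import torch

def dice_coeff(pred, target, eps=1e-6):
    # soft Dice over a batch of binary masks shaped [B, 1, H, W]
    pred = pred.float().flatten(1)
    target = target.float().flatten(1)
    inter = (pred * target).sum(dim=1)
    union = pred.sum(dim=1) + target.sum(dim=1)
    return ((2 * inter + eps) / (union + eps)).mean()

# hypothetical usage, if those shape assumptions hold:
# pred = (sample[:, -1, :, :].unsqueeze(1) > 0.5).float()
# print(dice_coeff(pred, m.to(pred.device)))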

alqurri77 commented 1 year ago

Below is my modified sampling.py, but I'm not sure about parts of it. For example, why th.tensor(sample)[:,-1,:,:].unsqueeze(1) instead of just 'sample'?


from torch.nn.modules.loss import CrossEntropyLoss
import argparse
import os
import sys
import random
sys.path.append(".")
import numpy as np
import time
import torch as th
from PIL import Image
import torch.distributed as dist
from guided_diffusion.dist_util import setup_dist, dev, load_state_dict
from guided_diffusion.logger import configure, log
from guided_diffusion.bratsloader import BRATSDataset, BRATSDataset3D
from guided_diffusion.isicloader import ISICDataset
import torchvision.utils as vutils
from guided_diffusion.utils import staple
from guided_diffusion.script_util import (
    NUM_CLASSES,
    model_and_diffusion_defaults,
    create_model_and_diffusion,
    args_to_dict,
)
import torchvision.transforms as transforms

#--------------------
# NOTE: despite the name, this is plain cross-entropy, not Dice
dice_loss = CrossEntropyLoss()  # DiceLoss(1)
val_losses = []
#-------------

seed=10
th.manual_seed(seed)
th.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)

def visualize(img):
    _min = img.min()
    _max = img.max()
    normalized_img = (img - _min)/ (_max - _min)
    return normalized_img

def main():
    args = create_argparser()  # returns a ready namespace; no parse_args() needed
    setup_dist(args)
    configure(dir=args.out_dir)
    print("args.data_name ",args.data_name )
    if args.data_name == 'ISIC':
        tran_list = [transforms.Resize((args.image_size,args.image_size)), transforms.ToTensor(),]
        transform_test = transforms.Compose(tran_list)

        ds = ISICDataset(args, args.data_dir, transform_test, mode = 'Test')
        args.in_ch = 4
    elif args.data_name == 'BRATS':
        tran_list = [transforms.Resize((args.image_size,args.image_size)),]
        transform_test = transforms.Compose(tran_list)

        ds = BRATSDataset3D(args.data_dir,transform_test)
        args.in_ch = 5
    datal = th.utils.data.DataLoader(
        ds,
        batch_size=1,
        shuffle=True)
    data = iter(datal)

    log("creating model and diffusion...")

    model, diffusion = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )
    all_images = []

    state_dict = load_state_dict(args.model_path, map_location="cpu")
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        # strip a leading 'module.' left over from DataParallel checkpoints
        name = k[7:] if k.startswith('module.') else k
        new_state_dict[name] = v

    #----------------------------------------
    model.load_state_dict(new_state_dict)

    model.to(dev())
    if args.use_fp16:
        model.convert_to_fp16()
    model.eval()
    #while len(all_images) * args.batch_size < args.num_samples:
    my_cou=0
    while my_cou * args.batch_size < args.num_samples:
        print("len(all_images)= ",len(all_images),"args.batch_size= ",  args.batch_size, " args.num_samples ", args.num_samples )
        b, m = next(data)  #should return an image from the dataloader "data"
        c = th.randn_like(b[:, :1, ...])
        img = th.cat((b, c), dim=1)     # add a noise channel
        if args.data_name == 'ISIC':
            slice_ID="1000"#path[0].split("_")[-1].split('.')[0]
        elif args.data_name == 'BRATS':
            # slice_ID=path[0].split("_")[2] + "_" + path[0].split("_")[4]
            slice_ID="1000"#path[0].split("_")[-3] + "_" + path[0].split("slice")[-1].split('.nii')[0]

        log("sampling...")

        start = th.cuda.Event(enable_timing=True)
        end = th.cuda.Event(enable_timing=True)
        enslist = []

        for i in range(args.num_ensemble):  # generate an ensemble of num_ensemble masks
            print("i= ",i,"args.num_ensemble= ", args.num_ensemble)
            model_kwargs = {}
            start.record()
            sample_fn = (
                diffusion.p_sample_loop_known if not args.use_ddim else diffusion.ddim_sample_loop_known
            )
            sample, x_noisy, org, cal, cal_out = sample_fn(
                model,
                (args.batch_size, 3, args.image_size, args.image_size), img,
                step = args.diffusion_steps,
                clip_denoised=args.clip_denoised,
                model_kwargs=model_kwargs,
            )
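            # As far as I can tell from p_sample_loop_known (worth verifying in
            # gaussian_diffusion.py):
            #   sample  - final denoised tensor; its last channel is the predicted
            #             mask, hence the [:, -1, :, :] slicing below
            #   x_noisy - the noised input the reverse process started from
            #   org     - the original input passed in (image + extra channel)
            #   cal     - the anchor/calibration segmentation from the UNet branch
            #   cal_out - calibrated combination of cal and the sampled mask,
            #             which is what gets ensembled below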

            end.record()
            th.cuda.synchronize()
            print('time for 1 sample', start.elapsed_time(end))  #time measurement for the generation of 1 sample

            co = th.tensor(cal_out)
            enslist.append(co)
           #-------------------------------------------
            print("sample",sample.shape)
            print("org",th.tensor(org)[:,:-1,:,:].shape)
            #val_loss =dice_loss (sample, th.tensor(org)[:,:-1,:,:], softmax=True)
            sample2=th.tensor(sample)[:,-1,:,:].unsqueeze(1)
            print("sample2",sample2.shape)

            target =m.cpu()#th.tensor(org)[:,:-1,:,:]# torch.argmax(th.tensor(org)[:,:-1,:,:], dim=1)#th.tensor(org)[:,:-1,:,:]# torch.argmax(th.tensor(org)[:,:-1,:,:], dim=1)
            print("target",target.shape)
            val_loss = dice_loss (sample2.cpu(),target [:] )#        .long() )#  softmax=True)      #
            val_losses.append(val_loss.item())
            #-------------------------------------

            if args.debug:
                # print('sample size is',sample.size())
                # print('org size is',org.size())
                # print('cal size is',cal.size())
                if args.data_name == 'ISIC':
                    s = th.tensor(sample)[:,-1,:,:].unsqueeze(1).repeat(1, 3, 1, 1)
                    o = th.tensor(org)[:,:-1,:,:]
                    c = th.tensor(cal).repeat(1, 3, 1, 1)
                    co = co.repeat(1, 3, 1, 1)
                    print("o",o.shape)
                    print("s",s.shape)
                    print("c",c.shape)
                    print("co",co.shape)
                elif args.data_name == 'BRATS':
                    s = th.tensor(sample)[:,-1,:,:].unsqueeze(1)
                    m = th.tensor(m.to(device = 'cuda:0'))[:,0,:,:].unsqueeze(1)
                    o1 = th.tensor(org)[:,0,:,:].unsqueeze(1)
                    o2 = th.tensor(org)[:,1,:,:].unsqueeze(1)
                    o3 = th.tensor(org)[:,2,:,:].unsqueeze(1)
                    o4 = th.tensor(org)[:,3,:,:].unsqueeze(1)
                    c = th.tensor(cal)

                # o1..o4 only exist in the BRATS branch; ISIC has the RGB image in o
                if args.data_name == 'BRATS':
                    tup = (o1/o1.max(), o2/o2.max(), o3/o3.max(), o4/o4.max(), m, s, c, co)
                else:
                    tup = (o, s, c, co)

                compose = th.cat(tup, 0)
                vutils.save_image(compose, fp = args.out_dir +str(slice_ID)+'_output'+str(i)+".jpg", nrow = 1, padding = 10)
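        # staple() fuses the ensemble of cal_out masks into a single
        # consensus segmentation (STAPLE algorithm)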
        ensres = staple(th.stack(enslist,dim=0)).squeeze(0)

        print("enslist",len(enslist))
        #print(np.unique( ensres.cpu()))
        vutils.save_image(ensres, fp = args.out_dir +str(slice_ID)+'_output_ens'+".jpg", nrow = 1, padding = 10)
        my_cou= my_cou+1
    mean_val_loss = sum(val_losses) / len(val_losses)
    print("mean_val_loss", mean_val_loss)

def create_argparser():
    defaults = dict(
        data_name = 'BRATS',
        data_dir="../dataset/brats2020/testing",
        clip_denoised=True,
        num_samples=1,
        batch_size=1,
        use_ddim=False,
        model_path="",
        num_ensemble=5,      #number of samples in the ensemble
        gpu_dev = "0",
        out_dir='./results/',
        multi_gpu = None, #"0,1,2"
        debug = False
    )

    my_args = dict(
        data_name='ISIC',#'BRATS',#'ISIC',
        data_dir='/tmp/ahmed/isic',#'/tmp/ahmed/oasis/MICCAI_BraTS2020_TrainingData/BraTS20_Training_001',#'/tmp/ahmed/isic',
        out_dir='/tmp/ahmed/out',

        model_path='/tmp/ahmed/out/savedmodel000020.pt',
        num_ensemble=1,# 5,
        num_samples=1,#4,
        clip_denoised=True,

        image_size=256,
        num_channels=128,
        class_cond=False,
        num_res_blocks= 2,
        num_heads= 1,
        learn_sigma= True,
        use_scale_shift_norm= False,
        attention_resolutions= "16",
        diffusion_steps=1000,
        noise_schedule= 'linear' ,
        rescale_learned_sigmas= False,
        rescale_timesteps= False,
        lr= 1e-4,
        batch_size=1 ,# 8
        debug=False
    )
    # build a plain namespace instead of an argparse parser, since every value
    # is hard-coded here; later dicts override earlier ones
    args = argparse.Namespace()
    for d in (defaults, model_and_diffusion_defaults(), my_args):
        for k, v in d.items():
            setattr(args, k, v)
    print(args.data_name)
    return args

if __name__ == "__main__":
    main()
saisusmitha commented 1 year ago

@alqurri77 Where should this metrics code be added? Is there a validation loop or similar code? Please let me know where to integrate the metrics if you know.

alqurri77 commented 1 year ago

At the top of the code above, this line assigns the metric: dice_loss = CrossEntropyLoss()
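Note that despite the variable name, CrossEntropyLoss is not Dice. If you want an actual Dice loss, a hypothetical drop-in, reusing the dice_coeff helper sketched earlier in this thread:

class DiceLoss(th.nn.Module):
    # 1 - soft Dice; expects predictions and targets in [0, 1]
    def forward(self, pred, target):
        return 1.0 - dice_coeff(pred, target)

dice_loss = DiceLoss()  # would replace the CrossEntropyLoss placeholder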

alqurri77 commented 1 year ago

I think the validation loop is this one, but I'm not sure:

while my_cou * args.batch_size < args.num_samples:
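If that is the loop, the natural place is inside it, right after sample_fn returns. A rough sketch, assuming m from the dataloader is the ground-truth mask, sample is roughly in [0, 1] (check its range first), and dice_coeff is the helper from earlier:

pred = (sample[:, -1, :, :].unsqueeze(1) > 0.5).float()  # binarize the mask channel
gt = m.to(pred.device).float()
val_losses.append(dice_coeff(pred, gt).item())           # per-sample Dice, averaged at the end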