Closed alqurri77 closed 1 year ago
as far as I know - the model prediction mask is "sample" in the above formula. Can you give code to extract the ground truth mask for the same you mentioned?
Below is the code for sampling.py . but I'm not sure ... for example why th.tensor(sample)[:,-1,:,:].unsqueeze(1) instead of just 'sample'
from torch.nn.modules.loss import CrossEntropyLoss
import io as ahmed
import argparse
import os
from ssl import OP_NO_TLSv1
import nibabel as nib
# from visdom import Visdom
# viz = Visdom(port=8850)
import sys
import random
sys.path.append(".")
import numpy as np
import time
import torch as th
from PIL import Image
import torch.distributed as dist
'''
from guided_diffusion import dist_util, logger
from guided_diffusion.bratsloader import BRATSDataset, BRATSDataset3D
from guided_diffusion.isicloader import ISICDataset
import torchvision.utils as vutils
from guided_diffusion.utils import staple
from guided_diffusion.script_util import (
NUM_CLASSES,
model_and_diffusion_defaults,
create_model_and_diffusion,
add_dict_to_argparser,
args_to_dict,
)
'''
import torchvision.transforms as transforms
from torchsummary import summary
#--------------------
dice_loss = CrossEntropyLoss()# DiceLoss(1) #
val_losses = []
#-------------
seed=10
th.manual_seed(seed)
th.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
def visualize(img):
_min = img.min()
_max = img.max()
normalized_img = (img - _min)/ (_max - _min)
return normalized_img
def main():
args = create_argparser()#.parse_args()
setup_dist(args)
configure(dir = args.out_dir)
print("args.data_name ",args.data_name )
if args.data_name == 'ISIC':
tran_list = [transforms.Resize((args.image_size,args.image_size)), transforms.ToTensor(),]
transform_test = transforms.Compose(tran_list)
ds = ISICDataset(args, args.data_dir, transform_test, mode = 'Test')
args.in_ch = 4
elif args.data_name == 'BRATS':
tran_list = [transforms.Resize((args.image_size,args.image_size)),]
transform_test = transforms.Compose(tran_list)
ds = BRATSDataset3D(args.data_dir,transform_test)
args.in_ch = 5
datal = th.utils.data.DataLoader(
ds,
batch_size=1,
shuffle=True)
data = iter(datal)
log("creating model and diffusion...")
model, diffusion = create_model_and_diffusion(
**args_to_dict(args, model_and_diffusion_defaults().keys())
)
all_images = []
state_dict = load_state_dict(args.model_path, map_location="cpu")
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
# name = k[7:] # remove `module.`
if 'module.' in k:
new_state_dict[k[7:]] = v
# load params
else:
new_state_dict = state_dict
#----------------------------------------
model.load_state_dict(new_state_dict)
model.to(dev())
if args.use_fp16:
model.convert_to_fp16()
model.eval()
#while len(all_images) * args.batch_size < args.num_samples:
my_cou=0
while my_cou * args.batch_size < args.num_samples:
print("len(all_images)= ",len(all_images),"args.batch_size= ", args.batch_size, " args.num_samples ", args.num_samples )
b, m = next(data) #should return an image from the dataloader "data"
c = th.randn_like(b[:, :1, ...])
img = th.cat((b, c), dim=1) #add a noise channel$
if args.data_name == 'ISIC':
slice_ID="1000"#path[0].split("_")[-1].split('.')[0]
elif args.data_name == 'BRATS':
# slice_ID=path[0].split("_")[2] + "_" + path[0].split("_")[4]
slice_ID="1000"#path[0].split("_")[-3] + "_" + path[0].split("slice")[-1].split('.nii')[0]
log("sampling...")
start = th.cuda.Event(enable_timing=True)
end = th.cuda.Event(enable_timing=True)
enslist = []
for i in range(args.num_ensemble): #this is for the generation of an ensemble of 5 masks.
print("i= ",i,"args.num_ensemble= ", args.num_ensemble)
model_kwargs = {}
start.record()
sample_fn = (
diffusion.p_sample_loop_known if not args.use_ddim else diffusion.ddim_sample_loop_known
)
sample, x_noisy, org, cal, cal_out = sample_fn(
model,
(args.batch_size, 3, args.image_size, args.image_size), img,
step = args.diffusion_steps,
clip_denoised=args.clip_denoised,
model_kwargs=model_kwargs,
)
end.record()
th.cuda.synchronize()
print('time for 1 sample', start.elapsed_time(end)) #time measurement for the generation of 1 sample
co = th.tensor(cal_out)
enslist.append(co)
#-------------------------------------------
print("sample",sample.shape)
print("org",th.tensor(org)[:,:-1,:,:].shape)
#val_loss =dice_loss (sample, th.tensor(org)[:,:-1,:,:], softmax=True)
sample2=th.tensor(sample)[:,-1,:,:].unsqueeze(1)
print("sample2",sample2.shape)
target =m.cpu()#th.tensor(org)[:,:-1,:,:]# torch.argmax(th.tensor(org)[:,:-1,:,:], dim=1)#th.tensor(org)[:,:-1,:,:]# torch.argmax(th.tensor(org)[:,:-1,:,:], dim=1)
print("target",target.shape)
val_loss = dice_loss (sample2.cpu(),target [:] )# .long() )# softmax=True) #
val_losses.append(val_loss.item())
#-------------------------------------
if args.debug:
# print('sample size is',sample.size())
# print('org size is',org.size())
# print('cal size is',cal.size())
if args.data_name == 'ISIC':
s = th.tensor(sample)[:,-1,:,:].unsqueeze(1).repeat(1, 3, 1, 1)
o = th.tensor(org)[:,:-1,:,:]
c = th.tensor(cal).repeat(1, 3, 1, 1)
co = co.repeat(1, 3, 1, 1)
print("o",o.shape)
print("s",s.shape)
print("c",c.shape)
print("co",co.shape)
elif args.data_name == 'BRATS':
s = th.tensor(sample)[:,-1,:,:].unsqueeze(1)
m = th.tensor(m.to(device = 'cuda:0'))[:,0,:,:].unsqueeze(1)
o1 = th.tensor(org)[:,0,:,:].unsqueeze(1)
o2 = th.tensor(org)[:,1,:,:].unsqueeze(1)
o3 = th.tensor(org)[:,2,:,:].unsqueeze(1)
o4 = th.tensor(org)[:,3,:,:].unsqueeze(1)
c = th.tensor(cal)
tup = (o1/o1.max(),o2/o2.max(),o3/o3.max(),o4/o4.max(),m,s,c,co)
compose = th.cat(tup,0)
vutils.save_image(compose, fp = args.out_dir +str(slice_ID)+'_output'+str(i)+".jpg", nrow = 1, padding = 10)
ensres = staple(th.stack(enslist,dim=0)).squeeze(0)
print("enslist",len(enslist))
#print(np.unique( ensres.cpu()))
vutils.save_image(ensres, fp = args.out_dir +str(slice_ID)+'_output_ens'+".jpg", nrow = 1, padding = 10)
my_cou= my_cou+1
mean_val_loss =(sum(val_losses) / len(val_losses) )
print("mean_val_loss",mean_val_loss)
def create_argparser():
defaults = dict(
data_name = 'BRATS',
data_dir="../dataset/brats2020/testing",
clip_denoised=True,
num_samples=1,
batch_size=1,
use_ddim=False,
model_path="",
num_ensemble=5, #number of samples in the ensemble
gpu_dev = "0",
out_dir='./results/',
multi_gpu = None, #"0,1,2"
debug = False
)
my_args = dict(
data_name='ISIC',#'BRATS',#'ISIC',
data_dir='/tmp/ahmed/isic',#'/tmp/ahmed/oasis/MICCAI_BraTS2020_TrainingData/BraTS20_Training_001',#'/tmp/ahmed/isic',
out_dir='/tmp/ahmed/out',
model_path='/tmp/ahmed/out/savedmodel000020.pt',
num_ensemble=1,# 5,
num_samples=1,#4,
clip_denoised=True,
image_size=256,
num_channels=128,
class_cond=False,
num_res_blocks= 2,
num_heads= 1,
learn_sigma= True,
use_scale_shift_norm= False,
attention_resolutions= "16",
diffusion_steps=1000,
noise_schedule= 'linear' ,
rescale_learned_sigmas= False,
rescale_timesteps= False,
lr= 1e-4,
batch_size=1 ,# 8
debug=False
)
defaults.update(model_and_diffusion_defaults())
parser =Args()# argparse.ArgumentParser()
add_dict_to_argparser(parser, defaults)
add_dict_to_argparser(parser, my_args)
print(parser.data_name)
return parser
if __name__ == "__main__":
main()
@alqurri77 Where to add this metrics code? do we have any validation loop or code? kindly let me know where to integrate the metrics code if you know.
on top of the above code this line assign the metrics:
dice_loss = CrossEntropyLoss()
I think the validation loop is this one, but I'm not sure:
while my_cou * args.batch_size < args.num_samples:
Hi;
I need to calculate the model accuracy (example dice loss). Hence, I need the model predication and grounds truth. What is the model predication out of those? : sample, x_noisy, org, cal, cal_out What each one means.