fujistoo closed this issue 1 year ago

What is the estimated training time? It seems that the `bbox` part is pretty time-consuming, both during training and testing. Also, I wanted to confirm that during testing only one image gets processed at a time?
During training, there shouldn't be much difference. During testing, the reconstruction of an individual image will indeed take four times longer (roughly 40% increased inference time overall). Right now, only one image is processed at a time and the code is not runtime-optimized! The patching could also be done in parallel to speed up inference.
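For instance, the four grid patches could be folded into the batch dimension inside `test_step` so the diffusion model reconstructs them in one forward pass. A minimal, untested sketch, assuming `GaussianDiffusion` broadcasts over the batch dimension and accepts one box per sample, as it does in `training_step`:

```python
# Hypothetical batched-patch inference (sketch, not the repo's actual code).
bbox = self.boxes.sample_grid(input)              # [1, K, 4], K grid patches
K = bbox.shape[1]
batched_input = input.repeat(K, 1, 1, 1)          # repeat the single image K times
batched_boxes = bbox.squeeze(0).unsqueeze(-1)     # [K, 4, 1], one box per copy
noise = gen_noise(self.cfg, batched_input.shape).to(self.device) \
    if self.cfg.get('noisetype') is not None else None
loss_diff, reco = self.diffusion(batched_input, t=self.test_timesteps - 1,
                                 box=batched_boxes, noise=noise)
# reco[k] now holds the reconstruction for patch k; stitch as in the per-patch loop.
```

This keeps inference single-image but replaces the sequential per-patch loop with one batched call.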
Thanks! I made some changes myself for non-medical images, but wasn't sure which part went wrong. None of the reconstructed images got denoised...? The resulting image comes from `reco = reco_patched.clone()` after 99 epochs.
That is very hard to tell without knowing the changes you made. 99 epochs should be enough to get a somewhat meaningful reconstruction for the IXI dataset. Which parts did you change?
```python
# DDPM_2D_patched.py
import torch
import numpy as np
import torchio as tio
import torch.optim as optim
from typing import Any
import lightning as L
from src.models.diffusionmodules.cond_DDPM import GaussianDiffusion
from src.models.diffusionmodules.OpenAI_Unet import UNetModel as OpenAI_UNet
from src.utils.diffusionmodules.patch_sampling import BoxSampler
from src.utils.diffusionmodules.generate_noise import gen_noise
from src.utils.diffusionmodules.utils_eval import _test_step, _test_end, get_eval_dictionary, get_eval_metrics_dictionary
# metrics
from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score
from customized.pre_processing import Tiler  # only referenced by the commented-out tiler path below
import wandb
def compute_roc(predictions, labels):
    _fpr, _tpr, _ = roc_curve(labels.astype(int), predictions, pos_label=1)
roc_auc = auc(_fpr, _tpr)
return roc_auc, _fpr, _tpr, _
def compute_prc(predictions, labels):
precisions, recalls, thresholds = precision_recall_curve(labels.astype(int), predictions)
auprc = average_precision_score(labels.astype(int), predictions)
return auprc, precisions, recalls, thresholds
class DDPM_2D(L.LightningModule):
def __init__(self,cfg,prefix=None):
super().__init__()
self.cfg = cfg
        # Model
image_size = (int(cfg.get('image_size',400)),)*2 # default 400 or from config.
model = OpenAI_UNet(
image_size = image_size,
in_channels = 3,
model_channels = cfg.get('unet_dim',64),
out_channels = 3,
num_res_blocks = cfg.get('num_res_blocks',3),
            attention_resolutions = tuple(cfg.get('att_res',[int(image_size[0]/32),int(image_size[0]/16), int(image_size[0]/8)])), # 32, 16, 8
            dropout=cfg.get('dropout_unet',0), # original repo default is 0.1
channel_mult=cfg.get('dim_mults',[1, 2, 4, 8]),
conv_resample=True,
dims=2,
num_classes=None,
use_checkpoint=True,
use_fp16=True,
num_heads=cfg.get('num_heads',1),
num_head_channels=64,
num_heads_upsample=-1,
use_scale_shift_norm=True,
resblock_updown=True,
use_new_attention_order=True,
use_spatial_transformer=False,
transformer_depth=1,
)
model.convert_to_fp16()
timesteps = cfg.get('timesteps',1000)
self.test_timesteps = cfg.get('test_timesteps',150)
sampling_timesteps = cfg.get('sampling_timesteps',self.test_timesteps)
self.diffusion = GaussianDiffusion(
model,
image_size = image_size, # only important when sampling
timesteps = timesteps, # number of steps
sampling_timesteps = sampling_timesteps,
objective = cfg.get('objective','pred_x0'), # pred_noise or pred_x0
channels = 1,
loss_type = cfg.get('loss','l1'), # L1 or L2
p2_loss_weight_gamma = cfg.get('p2_gamma',0),
inpaint = cfg.get('inpaint',False),
cfg=cfg
)
self.boxes = BoxSampler(cfg) # initialize box sampler
self.prefix = prefix
self.save_hyperparameters()
def forward(self):
return None
def training_step(self, batch, batch_idx: int):
# process batch
input = batch["image"]
# generate bboxes for DDPM
if self.cfg.get('grid_boxes',True): # sample boxes from a grid
bbox = torch.zeros([input.shape[0],4],dtype=int)
bboxes = self.boxes.sample_grid(input)
ind = torch.randint(0,bboxes.shape[1],(input.shape[0],))
for j in range(input.shape[0]):
bbox[j] = bboxes[j,ind[j]]
bbox = bbox.unsqueeze(-1)
else: # sample boxes randomly
bbox = self.boxes.sample_single_box(input)
# generate noise
if self.cfg.get('noisetype') is not None:
noise = gen_noise(self.cfg, input.shape).to(self.device)
else:
noise = None
# reconstruct
loss, reco = self.diffusion(input, box=bbox,noise=noise)
self.log(f'train/loss', loss, prog_bar=False, on_step=False, on_epoch=True, batch_size=input.shape[0],sync_dist=True)
return {"loss": loss}
def validation_step(self, batch: Any, batch_idx: int):
# input = batch['vol'][tio.DATA].squeeze(-1)
input = batch["image"]
# generate bboxes for DDPM
if self.cfg.get('grid_boxes',False): # sample boxes from a grid
bbox = torch.zeros([input.shape[0],4],dtype=int)
bboxes = self.boxes.sample_grid(input)
ind = torch.randint(0,bboxes.shape[1],(input.shape[0],))
for j in range(input.shape[0]):
bbox[j] = bboxes[j,ind[j]]
bbox = bbox.unsqueeze(-1)
else: # sample boxes randomly
bbox = self.boxes.sample_single_box(input)
# generate noise
if self.cfg.get('noisetype') is not None:
noise = gen_noise(self.cfg, input.shape).to(self.device)
else:
noise = None
loss, reco = self.diffusion(input, box=bbox, noise=noise)
self.log(f'val/loss_comb', loss, prog_bar=False, on_step=False, on_epoch=True, batch_size=input.shape[0],sync_dist=True)
return {"loss": loss}
def on_test_start(self):
self.metrics_dict = get_eval_metrics_dictionary()
# self.eval_dict = get_eval_dictionary()
# self.inds = []
# self.latentSpace_slice = []
# self.new_size = [160,190,160]
# self.diffs_list = []
# self.seg_list = []
# if not hasattr(self,'threshold'):
# self.threshold = {}
def test_step(self, batch: Any, batch_idx: int):
# self.dataset = batch['Dataset']
input = batch["image"] # 1chw
mask = batch["mask"].unsqueeze(1).expand(-1,3,-1,-1) # 1hw
# if self.cfg.get('num_eval_slices', input.size(4)) != input.size(4):
# num_slices = self.cfg.get('num_eval_slices', input.size(4)) # number of center slices to evaluate. If not set, the whole Volume is evaluated
# start_slice = int((input.size(4) - num_slices) / 2)
# input = input[...,start_slice:start_slice+num_slices]
# # data_orig = data_orig[...,start_slice:start_slice+num_slices]
# # data_seg = data_seg[...,start_slice:start_slice+num_slices]
# # data_mask = data_mask[...,start_slice:start_slice+num_slices]
# ind_offset = start_slice
# else:
# ind_offset = 0
# final_volume = torch.zeros([input.size(2), input.size(3), input.size(4)], device = self.device)
# reorder depth to batch dimension
assert input.shape[0] == 1, "Batch size must be 1"
# input = input.squeeze(0).permute(3,0,1,2) # [B,C,H,W,D] -> [D,C,H,W]
# input = input.squeeze(0).permute(1,2,0)
# latentSpace.append(torch.tensor([0],dtype=float).repeat(input.shape[0])) # dummy latent space
# generate bboxes for DDPM
bbox = self.boxes.sample_grid(input)
reco_patched = torch.zeros_like(input)
# generate noise
if self.cfg.get('noisetype') is not None:
noise = gen_noise(self.cfg, input.shape).to(self.device)
else:
noise = None
# use tiler
# tiles = self.tiler.tile(input)
# loss, reco = self.diffusion(tiles, box=bbox, noise=noise)
# over 4 tiles
for k in range(bbox.shape[1]):
box = bbox[:,k]
# reconstruct
loss_diff, reco = self.diffusion(input,t=self.test_timesteps-1, box=box,noise=noise)
if reco.shape[1] == 2:
reco = reco[:,0:1,:,:]
for j in range(reco_patched.shape[0]):
if self.cfg.get('overlap',False): # cut out the overlap region
grid = self.boxes.sample_grid_cut(input)
box_cut = grid[:,k]
if self.cfg.get('agg_overlap','cut') == 'cut': # cut out the overlap region
reco_patched[j,:,box_cut[j,1]:box_cut[j,3],box_cut[j,0]:box_cut[j,2]] = reco[j,:,box_cut[j,1]:box_cut[j,3],box_cut[j,0]:box_cut[j,2]]
elif self.cfg.get('agg_overlap','cut') == 'avg': # average the overlap region
reco_patched[j,:,box[j,1]:box[j,3],box[j,0]:box[j,2]] = reco_patched[j,:,box[j,1]:box[j,3],box[j,0]:box[j,2]] + reco[j,:,box[j,1]:box[j,3],box[j,0]:box[j,2]]
else:
reco_patched[j,:,box[j,1]:box[j,3],box[j,0]:box[j,2]] = reco[j,:,box[j,1]:box[j,3],box[j,0]:box[j,2]]
        if self.cfg.get('overlap',False) and self.cfg.get('agg_overlap','cut') == 'avg': # average the intersection of all patches
            count = torch.zeros_like(reco_patched)  # renamed from `mask` to avoid shadowing the ground-truth mask used below
            # count how many patches cover each pixel
            for k in range(bbox.shape[1]):
                box = bbox[:,k]
                for l in range(count.shape[0]):
                    count[l,:,box[l,1]:box[l,3],box[l,0]:box[l,2]] += 1
            # divide by the per-pixel coverage count to average the overlap regions
            reco_patched = reco_patched/count
        reco = reco_patched.clone()
        recon = reco.squeeze(0).permute(1,2,0)   # [H,W,3]
        input = input.squeeze(0).permute(1,2,0)  # [H,W,3]
        diff = torch.abs(input - recon)          # pixel-wise anomaly score
        # align the ground-truth mask with diff's [H,W,C] layout before flattening,
        # so predictions and labels correspond element-wise
        mask_hwc = mask[0].permute(1,2,0)
        recon = (recon.cpu().numpy()*255).astype("uint8")  # [H,W,3] uint8 for logging
        AUC, _fpr, _tpr, _threshs = compute_roc(diff.cpu().flatten(), np.array(mask_hwc.cpu().flatten()).astype(bool))
        AUPRC, _precisions, _recalls, _threshs = compute_prc(diff.cpu().flatten(), np.array(mask_hwc.cpu().flatten()).astype(bool))
self.metrics_dict['AUROC'].append(AUC)
self.metrics_dict['AUPRC'].append(AUPRC)
self.logger.experiment.log({"test/recon": wandb.Image(recon)})
# AnomalyScoreComb.append(loss_diff.cpu())
# AnomalyScoreReg.append(AnomalyScoreComb) # dummy
# AnomalyScoreReco.append(AnomalyScoreComb) # dummy
# # reassamble the reconstruction volume
# final_volume = reco.clone().squeeze()
# final_volume = final_volume.permute(1,2,0) # to HxWxD
# # average across slices to get volume-based scores
# self.latentSpace_slice.extend(latentSpace)
# self.eval_dict['latentSpace'].append(torch.mean(torch.stack(latentSpace),0))
# AnomalyScoreComb_vol = np.mean(AnomalyScoreComb)
# AnomalyScoreReg_vol = AnomalyScoreComb_vol # dummy
# AnomalyScoreReco_vol = AnomalyScoreComb_vol # dummy
# self.eval_dict['AnomalyScoreRegPerVol'].append(AnomalyScoreReg_vol)
# if not self.cfg.get('use_postprocessed_score', True):
# self.eval_dict['AnomalyScoreRecoPerVol'].append(AnomalyScoreReco_vol)
# self.eval_dict['AnomalyScoreCombPerVol'].append(AnomalyScoreComb_vol)
# final_volume = final_volume.unsqueeze(0)
# final_volume = final_volume.unsqueeze(0)
# # calculate metrics
# _test_step(self, final_volume, data_orig, data_seg, data_mask, batch_idx, ID, label)
# def on_test_end(self) :
# # calculate metrics
# _test_end(self) # everything that is independent of the model choice
    def on_test_end(self):
        self.metrics_dict['AUROC'] = np.mean(self.metrics_dict['AUROC'])
        self.metrics_dict['AUPRC'] = np.mean(self.metrics_dict['AUPRC'])
        print('AUROC:', self.metrics_dict['AUROC'])
        print('AUPRC:', self.metrics_dict['AUPRC'])
def configure_optimizers(self):
return optim.Adam(self.parameters(), lr=self.cfg.lr)
def update_prefix(self, prefix):
self.prefix = prefix
```

This is the `DDPM_2D_patched.py`. The changes I have made so far: `image_size` (I didn't rescale the images), and during `test_step` the `num_eval_slices` part is commented out so the entire image gets passed in normally, because the images I deal with are ordinary BCHW tensors, unlike volumetric medical images, which carry an additional depth dimension. The underlying UNet and GaussianDiffusion classes and configs are left untouched. The whole idea was to make the codebase accommodate non-volumetric ("normal") images.
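For context, the shape difference is roughly this; a minimal, hypothetical sketch (shapes assumed from the code above, not taken from the repo):

```python
import torch

# Volumetric (medical) input: [B, C, H, W, D]. The original code selects center
# slices via num_eval_slices and folds depth into the batch dimension:
vol = torch.rand(1, 1, 128, 128, 32)
slices = vol.squeeze(0).permute(3, 0, 1, 2)   # [D, C, H, W]: one slice per batch entry

# Ordinary 2D input is already [B, C, H, W], so it can be passed straight through:
img = torch.rand(1, 3, 400, 400)              # no depth axis; num_eval_slices is unnecessary
```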
At first glance, I cannot see an error. Does the training work, i.e., is the loss decreasing?
Yeah, it did, which is why the resulting image baffles me a bit. But it makes sense to comment out the `num_eval_slices` part, right? Since that section is specific to volumetric data and only handles the extra slice dimension? (I haven't had any encounters with volumetric images yet; newbie on that.)
Have you checked whether the model checkpoint gets loaded properly (given you are re-evaluating)? Also, maybe debug in the evaluation and look at the input and output directly at the reconstruction step.
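Something along these lines, for example; a hypothetical sketch where the checkpoint path and data loader are placeholders (`cfg` is restored via `save_hyperparameters()`):

```python
import torch
import torchvision.utils as vutils

# Load the trained weights explicitly instead of relying on Trainer state.
model = DDPM_2D.load_from_checkpoint("path/to/last.ckpt")  # placeholder path
model.eval()

batch = next(iter(test_loader))            # placeholder loader
input = batch["image"].to(model.device)    # [1, 3, H, W]

# Inspect input vs. output directly at the reconstruction step.
with torch.no_grad():
    bbox = model.boxes.sample_grid(input)
    _, reco = model.diffusion(input, t=model.test_timesteps - 1,
                              box=bbox[:, 0], noise=None)
vutils.save_image(torch.cat([input.cpu(), reco.cpu()], dim=0), "debug_input_vs_reco.png")
```

If the saved pair already looks wrong (e.g., the reconstruction is pure noise), the problem is in the model or checkpoint, not in the patch stitching.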
Closing this (stale).