Open haomo-bh opened 1 month ago
Hi, could you please check whether the training images are affine-registered? Our code is designed for deformable registration, which requires the image to be pre-registered via affine registration. It seems that TransMorph uses a separate affine network, followed by the core TransMorph network. If you directly put the non-affine-registered images into the CorrMLP, the training will be very unstable.
Hello, thank you very much for your timely response. The OASIS dataset I am using has already been rigidly pre-aligned, and I did not experience the Dice-fluctuation problem when training TransMatch on the same data. Could you please help me check whether I have misunderstood how to use your code? Due to GPU memory limitations, I cropped the data according to the parameters in your paper, and I wonder whether that may have had some impact. Finally, thank you again for your prompt answer; my earlier question may have been rushed and less prepared than I intended, and I apologize for any inconvenience caused. Here is my training script:
from torch.utils.tensorboard import SummaryWriter
import os, glob
from CorrMLP import utils, losses
from CorrMLP.model import CorrMLP, SpatialTransformer_block
import sys
from torch.utils.data import DataLoader
from datas import datasets, trans
import numpy as np
import torch
from torchvision import transforms
from torch import optim
import torch.nn as nn
import matplotlib.pyplot as plt
from natsort import natsorted
class Logger(object):
    """Tee object that duplicates everything written to stdout into a log file.

    Install with ``sys.stdout = Logger(save_dir)``; console output is preserved
    while a copy is appended to ``<save_dir>logfile.log``.
    """

    def __init__(self, save_dir):
        # Keep the real stdout so messages still reach the console.
        self.terminal = sys.stdout
        # Append mode: restarted runs extend the same log file.
        self.log = open(save_dir+"logfile.log", "a")

    def write(self, message):
        # Fan the message out to both destinations.
        for stream in (self.terminal, self.log):
            stream.write(message)

    def flush(self):
        # Part of the file-like protocol; buffering is left to the OS here.
        pass
def Dice(vol1, vol2, labels=None, nargout=1):
    """Compute the per-label Dice overlap between two label volumes.

    Args:
        vol1, vol2: integer label arrays of identical shape.
        labels: iterable of labels to score; defaults to every nonzero label
            present in either volume (background 0 is excluded).
        nargout: 1 -> return only the Dice array; otherwise also return labels.

    Returns:
        numpy array of Dice scores per label, optionally with the label list.
    """
    if labels is None:
        labels = np.unique(np.concatenate((vol1, vol2)))
        labels = np.delete(labels, np.where(labels == 0))  # drop background
    dicem = np.zeros(len(labels))
    for idx, lab in enumerate(labels):
        mask1 = vol1 == lab
        mask2 = vol2 == lab
        overlap = 2 * np.sum(np.logical_and(mask1, mask2))
        denom = np.sum(mask1) + np.sum(mask2)
        # Epsilon guards against 0/0 when a label is absent from both volumes.
        dicem[idx] = overlap / np.maximum(denom, np.finfo(float).eps)
    if nargout == 1:
        return dicem
    return (dicem, labels)
def NJD(displacement):
    """Count folded voxels: locations where the Jacobian determinant of the
    deformation (identity + displacement) is negative.

    Args:
        displacement: numpy array of shape (H, W, D, 3), a dense displacement
            field with component order matching the (axis1, axis0, axis2)
            convention used below.

    Returns:
        int: number of voxels with a negative Jacobian determinant.
    """
    # Forward differences of each displacement component along the three axes
    # (all cropped to a common (H-1, W-1, D-1) interior).
    D_y = (displacement[1:,:-1,:-1,:] - displacement[:-1,:-1,:-1,:])
    D_x = (displacement[:-1,1:,:-1,:] - displacement[:-1,:-1,:-1,:])
    D_z = (displacement[:-1,:-1,1:,:] - displacement[:-1,:-1,:-1,:])
    # Cofactor expansion of det(I + grad(u)) along the first row.
    D1 = (D_x[...,0]+1)*( (D_y[...,1]+1)*(D_z[...,2]+1) - D_z[...,1]*D_y[...,2])
    # BUG FIX: the second minor is Dy0*(Dz2+1) - Dy2*Dz0; the original code
    # used D_x[...,0] instead of D_z[...,0], corrupting the determinant.
    D2 = (D_x[...,1])*(D_y[...,0]*(D_z[...,2]+1) - D_y[...,2]*D_z[...,0])
    D3 = (D_x[...,2])*(D_y[...,0]*D_z[...,1] - (D_y[...,1]+1)*D_z[...,0])
    Ja_value = D1-D2+D3
    return np.sum(Ja_value<0)
def main():
    """End-to-end training/validation loop for CorrMLP on cropped OASIS volumes.

    Trains with an NCC similarity loss plus a diffusion (Grad) regularizer on
    the predicted flow, validates with Dice over warped segmentation labels,
    logs losses/figures to TensorBoard, and checkpoints after every epoch.
    """
    batch_size = 1
    # [similarity weight, flow-regularization weight]
    Weights = [1.0, 1.0]
    # Volume size AFTER the 8:152 / 32:192 crops applied below (GPU-memory workaround).
    img_size = (144, 192, 160)
    save_dir = 'CorrMLP_ncc_{}_diffusion_{}/'.format(Weights[0], Weights[1])
    if not os.path.exists('experiments/'+save_dir):
        os.makedirs('experiments/'+save_dir)
    if not os.path.exists('logs/'+save_dir):
        os.makedirs('logs/'+save_dir)
    # Mirror everything printed to stdout into a log file as well.
    sys.stdout = Logger('logs/'+save_dir)
    train_dir = r'/home/mh/PythonCodes/OASIS_L2R_2021_task03/All/'
    val_dir = '/home/mh/PythonCodes/OASIS_L2R_2021_task03/Test/'
    lr = 1e-4 # learning rate
    epoch_start = 0
    max_epoch = 500 # max training epochs
    # NOTE(review): cont_training is never read below -- resume logic is missing.
    cont_training = False #if continue training
    '''
    Initialize model
    '''
    model = CorrMLP()
    model.cuda()
    # Separate nearest-neighbour warper, used only at validation to warp labels.
    SpatialTransformer = SpatialTransformer_block(mode='nearest')
    SpatialTransformer.cuda()
    SpatialTransformer.eval()
    '''
    Initialize training
    '''
    train_composed = transforms.Compose([trans.NumpyType((np.float32, np.float32))])
    val_composed = transforms.Compose([trans.NumpyType((np.float32, np.int16))])
    train_set = datasets.OASISBrainDataset(glob.glob(train_dir + '*.pkl'), transforms=train_composed)
    val_set = datasets.OASISBrainInferDataset(glob.glob(val_dir + '*.pkl'), transforms=val_composed)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=4, pin_memory=True, drop_last=True)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # [image similarity loss, smoothness regularizer applied to the flow]
    Losses = [losses.NCC(win=9).loss, losses.Grad('l2').loss]
    best_dsc = 0
    writer = SummaryWriter(log_dir='logs/'+save_dir)
    for epoch in range(epoch_start, max_epoch):
        print('Training Starts')
        '''
        Training
        '''
        loss_all = utils.AverageMeter()
        idx = 0
        for data in train_loader:
            idx += 1
            model.train()
            # adjust_learning_rate(optimizer, epoch, max_epoch, lr)
            data = [t.cuda() for t in data]
            # Crop both volumes to img_size = (144, 192, 160) to fit GPU memory.
            x = data[0][:, :, 8:152, :, 32:192]
            y = data[1][:, :, 8:152, :, 32:192]
            # x_in = torch.cat((x,y), dim=1)
            # NOTE(review): argument order here is model(y, x). Confirm this
            # matches CorrMLP's expected (moving, fixed) convention -- a swapped
            # order would explain unstable Dice during training.
            output, flow = model(y, x)
            loss_ncc = Losses[0](y, output) * Weights[0]
            # Presumably Grad.loss ignores its first argument; np.zeros((1)) is a
            # placeholder -- TODO confirm against losses.Grad's signature.
            loss_reg = Losses[1](np.zeros((1)), flow) * Weights[1]
            loss = loss_ncc + loss_reg
            loss_all.update(loss.item(), y.numel())
            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print('Iter {} of {} loss {:.4f}, Img Sim: {:.6f}, Reg: {:.6f}'.format(idx, len(train_loader),
                                                                                   loss.item(),
                                                                                   loss_ncc.item(),
                                                                                   loss_reg.item()))
        writer.add_scalar('Loss/train', loss_all.avg, epoch)
        print('Epoch {} loss {:.4f}'.format(epoch, loss_all.avg))
        '''
        Validation
        '''
        eval_dsc = utils.AverageMeter()
        with torch.no_grad():
            for data in val_loader:
                model.eval()
                data = [t.cuda() for t in data]
                # Same crops as training so the flow and labels stay aligned.
                x = data[0][:, :, 8:152, :, 32:192]
                y = data[1][:, :, 8:152, :, 32:192]
                x_seg = data[2][:, :, 8:152, :, 32:192]
                y_seg = data[3][:, :, 8:152, :, 32:192]
                # x_in = torch.cat((x, y), dim=1)
                grid_img = mk_grid_img(8, 1, img_size)
                output = model(y, x)
                # output[1] is the predicted flow; warp labels with nearest-neighbour.
                def_out = SpatialTransformer(x_seg.cuda().float(), output[1].cuda())
                def_grid = SpatialTransformer(grid_img.float(), output[1].cuda())
                dsc = utils.dice_val_VOI(def_out.long(), y_seg.long())
                eval_dsc.update(dsc.item(), x.size(0))
        print(eval_dsc.avg)
        best_dsc = max(eval_dsc.avg, best_dsc)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_dsc': best_dsc,
            'optimizer': optimizer.state_dict(),
        }, save_dir='experiments/'+save_dir, filename='dsc{:.4f}.pth.tar'.format(eval_dsc.avg))
        writer.add_scalar('DSC/validate', eval_dsc.avg, epoch)
        plt.switch_backend('agg')
        # The figures below use the LAST validation sample (loop variables leak out).
        pred_fig = comput_fig(def_out)
        grid_fig = comput_fig(def_grid)
        x_fig = comput_fig(x_seg)
        tar_fig = comput_fig(y_seg)
        writer.add_figure('Grid', grid_fig, epoch)
        plt.close(grid_fig)
        writer.add_figure('input', x_fig, epoch)
        plt.close(x_fig)
        writer.add_figure('ground truth', tar_fig, epoch)
        plt.close(tar_fig)
        writer.add_figure('prediction', pred_fig, epoch)
        plt.close(pred_fig)
        loss_all.reset()
        # Free the large validation tensors before the next epoch.
        del def_out, def_grid, grid_img, output
    writer.close()
def comput_fig(img):
    """Render 16 axial slices (indices 48-63) of a (B, C, D, H, W) tensor as a
    4x4 grayscale grid and return the matplotlib figure."""
    slices = img.detach().cpu().numpy()[0, 0, 48:64, :, :]
    fig = plt.figure(figsize=(12, 12), dpi=180)
    for pos, sl in enumerate(slices, start=1):
        plt.subplot(4, 4, pos)
        plt.axis('off')
        plt.imshow(sl, cmap='gray')
    # Pack the tiles edge-to-edge.
    fig.subplots_adjust(wspace=0, hspace=0)
    return fig
def adjust_learning_rate(optimizer, epoch, MAX_EPOCHES, INIT_LR, power=0.9):
    """Apply polynomial learning-rate decay to every parameter group:
    lr = INIT_LR * (1 - epoch / MAX_EPOCHES) ** power, rounded to 8 decimals."""
    # The decayed value is identical for all groups, so compute it once.
    decayed = round(INIT_LR * np.power(1 - (epoch) / MAX_EPOCHES, power), 8)
    for group in optimizer.param_groups:
        group['lr'] = decayed
def mk_grid_img(grid_step, line_thickness=1, grid_sz=(160, 192, 224), device='cuda'):
    """Create a regular-grid image for visualizing a deformation field.

    Args:
        grid_step: spacing between grid lines, in voxels.
        line_thickness: width of each grid line in voxels. The original code
            drew a single plane shifted by (line_thickness - 1) and raised
            IndexError near the boundary for any thickness > 1; lines are now
            drawn as slice bands clipped to the volume (identical output for
            the default thickness of 1).
        grid_sz: spatial size of the grid volume.
        device: torch device for the returned tensor; defaults to 'cuda' to
            preserve the original behavior.

    Returns:
        Float tensor of shape (1, 1, *grid_sz): 1 on grid lines, 0 elsewhere.
    """
    grid_img = np.zeros(grid_sz)
    # Slices clip automatically at the upper bound, so no index can overflow.
    for j in range(0, grid_img.shape[1], grid_step):
        grid_img[:, j:j + line_thickness, :] = 1
    for i in range(0, grid_img.shape[2], grid_step):
        grid_img[:, :, i:i + line_thickness] = 1
    # Add batch and channel dimensions expected by the spatial transformer.
    grid_img = grid_img[None, None, ...]
    return torch.from_numpy(grid_img).to(device)
def save_checkpoint(state, save_dir='models', filename='checkpoint.pth.tar', max_model_num=8):
    """Serialize a checkpoint dict to save_dir+filename, then delete the
    oldest checkpoints (natural-sort order) until at most max_model_num remain.

    Note: save_dir is string-concatenated with filename, so callers are
    expected to pass it with a trailing separator.
    """
    torch.save(state, save_dir+filename)
    while True:
        snapshots = natsorted(glob.glob(save_dir + '*'))
        if len(snapshots) <= max_model_num:
            break
        # Natural sort puts the lowest-scoring/oldest name first; drop it.
        os.remove(snapshots[0])
if __name__ == '__main__':
    '''
    GPU configuration
    '''
    GPU_iden = 0  # index of the GPU used for training
    GPU_num = torch.cuda.device_count()
    print('Number of GPU: ' + str(GPU_num))
    # Enumerate all visible devices for the log.
    for GPU_idx in range(GPU_num):
        GPU_name = torch.cuda.get_device_name(GPU_idx)
        print(' GPU #' + str(GPU_idx) + ': ' + GPU_name)
    torch.cuda.set_device(GPU_iden)
    GPU_avai = torch.cuda.is_available()
    print('Currently using: ' + torch.cuda.get_device_name(GPU_iden))
    print('If the GPU is available? ' + str(GPU_avai))
    # Fix the CPU-side RNG seed for reproducibility; CUDA kernels may still
    # introduce nondeterminism.
    torch.manual_seed(0)
    main()
Hello, thank you for your open-source code. Since your dataset is not public yet, I replaced your data-reading pipeline with TransMorph's. After running for about one day, the results did not meet my expectations: the loss keeps fluctuating up and down, and the post-registration Dice fluctuates as well. Could you tell me whether any special method is used for data reading? Addendum: I used the OASIS dataset for this test.
![image](https://github.com/MungoMeng/Registration-CorrMLP/assets/166362962/52592978-61f4-435b-b7bb-9b0695638167)