haomo-bh opened this issue 5 months ago
Hi, could you please check whether your training images are affine-registered? Our code is designed for deformable registration only, so it requires the images to be pre-aligned with an affine registration. TransMorph, for example, uses a separate affine network in front of the core TransMorph network. If you feed non-affine-registered images directly into CorrMLP, training will be very unstable.
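(For anyone checking this on their own data: below is a minimal sketch of what such an affine pre-alignment could look like with SimpleITK. SimpleITK is not part of the CorrMLP repo, and the file names, metric, and optimizer settings are illustrative placeholders only, not the authors' preprocessing.)

import SimpleITK as sitk

def affine_register(fixed_path, moving_path, out_path):
    # read both volumes as float images
    fixed = sitk.ReadImage(fixed_path, sitk.sitkFloat32)
    moving = sitk.ReadImage(moving_path, sitk.sitkFloat32)
    # initialize a 3D affine transform from the image geometry
    initial = sitk.CenteredTransformInitializer(
        fixed, moving, sitk.AffineTransform(3),
        sitk.CenteredTransformInitializerFilter.GEOMETRY)
    reg = sitk.ImageRegistrationMethod()
    reg.SetMetricAsMattesMutualInformation(numberOfHistogramBins=32)
    reg.SetOptimizerAsRegularStepGradientDescent(learningRate=1.0, minStep=1e-4, numberOfIterations=200)
    reg.SetOptimizerScalesFromPhysicalShift()
    reg.SetInterpolator(sitk.sitkLinear)
    reg.SetInitialTransform(initial, inPlace=False)
    transform = reg.Execute(fixed, moving)
    # resample the moving image into the fixed image space
    resampled = sitk.Resample(moving, fixed, transform, sitk.sitkLinear, 0.0, sitk.sitkFloat32)
    sitk.WriteImage(resampled, out_path)

affine_register('fixed.nii.gz', 'moving.nii.gz', 'moving_affine.nii.gz')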
Hello, thank you very much for the quick reply. The OASIS dataset I am using has already been rigidly pre-aligned, and I did not see this Dice fluctuation when training TransMatch on the same data. Could you please check whether my training script uses your code incorrectly? Because of GPU memory limits, I also cropped the data to the size given in your paper; I wonder whether that could have an effect. Thank you again for your prompt answer; my question was rushed and not as well prepared as it should have been, and I apologize for any inconvenience. Here is my training script:
from torch.utils.tensorboard import SummaryWriter
import os, glob
from CorrMLP import utils, losses
from CorrMLP.model import CorrMLP, SpatialTransformer_block
import sys
from torch.utils.data import DataLoader
from datas import datasets, trans
import numpy as np
import torch
from torchvision import transforms
from torch import optim
import torch.nn as nn
import matplotlib.pyplot as plt
from natsort import natsorted
class Logger(object):
    def __init__(self, save_dir):
        self.terminal = sys.stdout
        self.log = open(save_dir + "logfile.log", "a")

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        pass
def Dice(vol1, vol2, labels=None, nargout=1):
    if labels is None:
        labels = np.unique(np.concatenate((vol1, vol2)))
        labels = np.delete(labels, np.where(labels == 0))  # remove background
    dicem = np.zeros(len(labels))
    for idx, lab in enumerate(labels):
        vol1l = vol1 == lab
        vol2l = vol2 == lab
        top = 2 * np.sum(np.logical_and(vol1l, vol2l))
        bottom = np.sum(vol1l) + np.sum(vol2l)
        bottom = np.maximum(bottom, np.finfo(float).eps)  # add epsilon
        dicem[idx] = top / bottom
    if nargout == 1:
        return dicem
    else:
        return (dicem, labels)
def NJD(displacement):
    # count voxels with a negative Jacobian determinant (i.e. folding) in a displacement field
    D_y = (displacement[1:, :-1, :-1, :] - displacement[:-1, :-1, :-1, :])
    D_x = (displacement[:-1, 1:, :-1, :] - displacement[:-1, :-1, :-1, :])
    D_z = (displacement[:-1, :-1, 1:, :] - displacement[:-1, :-1, :-1, :])
    D1 = (D_x[..., 0] + 1) * ((D_y[..., 1] + 1) * (D_z[..., 2] + 1) - D_z[..., 1] * D_y[..., 2])
    D2 = (D_x[..., 1]) * (D_y[..., 0] * (D_z[..., 2] + 1) - D_y[..., 2] * D_x[..., 0])
    D3 = (D_x[..., 2]) * (D_y[..., 0] * D_z[..., 1] - (D_y[..., 1] + 1) * D_z[..., 0])
    Ja_value = D1 - D2 + D3
    return np.sum(Ja_value < 0)
def main():
    batch_size = 1
    Weights = [1.0, 1.0]  # [NCC weight, regularization weight]
    img_size = (144, 192, 160)
    save_dir = 'CorrMLP_ncc_{}_diffusion_{}/'.format(Weights[0], Weights[1])
    if not os.path.exists('experiments/' + save_dir):
        os.makedirs('experiments/' + save_dir)
    if not os.path.exists('logs/' + save_dir):
        os.makedirs('logs/' + save_dir)
    sys.stdout = Logger('logs/' + save_dir)
    train_dir = r'/home/mh/PythonCodes/OASIS_L2R_2021_task03/All/'
    val_dir = '/home/mh/PythonCodes/OASIS_L2R_2021_task03/Test/'
    lr = 1e-4  # learning rate
    epoch_start = 0
    max_epoch = 500  # max training epochs
    cont_training = False  # whether to continue training

    '''
    Initialize model
    '''
    model = CorrMLP()
    model.cuda()
    # spatial transformer (used to warp segmentations and grids during validation)
    SpatialTransformer = SpatialTransformer_block(mode='nearest')
    SpatialTransformer.cuda()
    SpatialTransformer.eval()

    '''
    Initialize training
    '''
    train_composed = transforms.Compose([trans.NumpyType((np.float32, np.float32))])
    val_composed = transforms.Compose([trans.NumpyType((np.float32, np.int16))])
    train_set = datasets.OASISBrainDataset(glob.glob(train_dir + '*.pkl'), transforms=train_composed)
    val_set = datasets.OASISBrainInferDataset(glob.glob(val_dir + '*.pkl'), transforms=val_composed)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=4, pin_memory=True, drop_last=True)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    Losses = [losses.NCC(win=9).loss, losses.Grad('l2').loss]
    best_dsc = 0
    writer = SummaryWriter(log_dir='logs/' + save_dir)
    for epoch in range(epoch_start, max_epoch):
        print('Training Starts')
        '''
        Training
        '''
        loss_all = utils.AverageMeter()
        idx = 0
        for data in train_loader:
            idx += 1
            model.train()
            # adjust_learning_rate(optimizer, epoch, max_epoch, lr)
            data = [t.cuda() for t in data]
            x = data[0][:, :, 8:152, :, 32:192]
            y = data[1][:, :, 8:152, :, 32:192]
            # x_in = torch.cat((x, y), dim=1)
            output, flow = model(y, x)
            loss_ncc = Losses[0](y, output) * Weights[0]
            loss_reg = Losses[1](np.zeros((1)), flow) * Weights[1]
            loss = loss_ncc + loss_reg
            loss_all.update(loss.item(), y.numel())
            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print('Iter {} of {} loss {:.4f}, Img Sim: {:.6f}, Reg: {:.6f}'.format(
                idx, len(train_loader), loss.item(), loss_ncc.item(), loss_reg.item()))
        writer.add_scalar('Loss/train', loss_all.avg, epoch)
        print('Epoch {} loss {:.4f}'.format(epoch, loss_all.avg))
        '''
        Validation
        '''
        eval_dsc = utils.AverageMeter()
        with torch.no_grad():
            for data in val_loader:
                model.eval()
                data = [t.cuda() for t in data]
                x = data[0][:, :, 8:152, :, 32:192]
                y = data[1][:, :, 8:152, :, 32:192]
                x_seg = data[2][:, :, 8:152, :, 32:192]
                y_seg = data[3][:, :, 8:152, :, 32:192]
                # x_in = torch.cat((x, y), dim=1)
                grid_img = mk_grid_img(8, 1, img_size)
                output = model(y, x)
                def_out = SpatialTransformer(x_seg.cuda().float(), output[1].cuda())
                def_grid = SpatialTransformer(grid_img.float(), output[1].cuda())
                dsc = utils.dice_val_VOI(def_out.long(), y_seg.long())
                eval_dsc.update(dsc.item(), x.size(0))
                print(eval_dsc.avg)
        best_dsc = max(eval_dsc.avg, best_dsc)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_dsc': best_dsc,
            'optimizer': optimizer.state_dict(),
        }, save_dir='experiments/' + save_dir, filename='dsc{:.4f}.pth.tar'.format(eval_dsc.avg))
        writer.add_scalar('DSC/validate', eval_dsc.avg, epoch)
        plt.switch_backend('agg')
        pred_fig = comput_fig(def_out)
        grid_fig = comput_fig(def_grid)
        x_fig = comput_fig(x_seg)
        tar_fig = comput_fig(y_seg)
        writer.add_figure('Grid', grid_fig, epoch)
        plt.close(grid_fig)
        writer.add_figure('input', x_fig, epoch)
        plt.close(x_fig)
        writer.add_figure('ground truth', tar_fig, epoch)
        plt.close(tar_fig)
        writer.add_figure('prediction', pred_fig, epoch)
        plt.close(pred_fig)
        loss_all.reset()
        del def_out, def_grid, grid_img, output
    writer.close()
def comput_fig(img):
    img = img.detach().cpu().numpy()[0, 0, 48:64, :, :]
    fig = plt.figure(figsize=(12, 12), dpi=180)
    for i in range(img.shape[0]):
        plt.subplot(4, 4, i + 1)
        plt.axis('off')
        plt.imshow(img[i, :, :], cmap='gray')
    fig.subplots_adjust(wspace=0, hspace=0)
    return fig
def adjust_learning_rate(optimizer, epoch, MAX_EPOCHES, INIT_LR, power=0.9):
    for param_group in optimizer.param_groups:
        param_group['lr'] = round(INIT_LR * np.power(1 - (epoch) / MAX_EPOCHES, power), 8)
def mk_grid_img(grid_step, line_thickness=1, grid_sz=(160, 192, 224)):
    grid_img = np.zeros(grid_sz)
    for j in range(0, grid_img.shape[1], grid_step):
        grid_img[:, j + line_thickness - 1, :] = 1
    for i in range(0, grid_img.shape[2], grid_step):
        grid_img[:, :, i + line_thickness - 1] = 1
    grid_img = grid_img[None, None, ...]
    grid_img = torch.from_numpy(grid_img).cuda()
    return grid_img
def save_checkpoint(state, save_dir='models', filename='checkpoint.pth.tar', max_model_num=8):
    torch.save(state, save_dir + filename)
    model_lists = natsorted(glob.glob(save_dir + '*'))
    while len(model_lists) > max_model_num:
        os.remove(model_lists[0])
        model_lists = natsorted(glob.glob(save_dir + '*'))
if __name__ == '__main__':
    '''
    GPU configuration
    '''
    GPU_iden = 0
    GPU_num = torch.cuda.device_count()
    print('Number of GPU: ' + str(GPU_num))
    for GPU_idx in range(GPU_num):
        GPU_name = torch.cuda.get_device_name(GPU_idx)
        print(' GPU #' + str(GPU_idx) + ': ' + GPU_name)
    torch.cuda.set_device(GPU_iden)
    GPU_avai = torch.cuda.is_available()
    print('Currently using: ' + torch.cuda.get_device_name(GPU_iden))
    print('If the GPU is available? ' + str(GPU_avai))
    torch.manual_seed(0)
    main()
Here are two examples of the training images used in my experiments. Could you please check manually whether there are any significant differences? Image.zip
Hello author, could I ask you for the cropped OASIS dataset? The data I cropped myself does not seem to be suitable.
Hi, I checked your code and everything looks fine. What is your PyTorch version? I used PyTorch 1.13 with CUDA 11.7.
I am afraid the 3D correlation computation can be unstable under different PyTorch versions. Can you set use_corr=False in the MLP_decoder and see whether training becomes more stable?
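One way to try this without knowing the exact constructor arguments is to flip the flag at runtime on any submodule that exposes it. This is only a sketch based on the attribute name suggested above; whether it actually disables the correlation path depends on how model.py reads the flag in forward(), so please verify against the repo (the intended change may instead be editing where MLP_decoder is instantiated in model.py).

from CorrMLP.model import CorrMLP

model = CorrMLP()
for name, module in model.named_modules():
    if hasattr(module, 'use_corr'):  # attribute name taken from the suggestion above
        print('disabling correlation in', name)
        module.use_corr = False      # has no effect if the flag is only read at construction time
model.cuda()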
Hello, thank you for your open-source code. Since your dataset is not public yet, I replaced your data reading with TransMorph's data reading pipeline. After running for about one day, the results did not meet my expectations: the loss keeps fluctuating up and down, and the post-registration Dice also keeps fluctuating. Could you tell me whether any special method is used for data reading? Note: I used the OASIS dataset for this test.
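For context, the TransMorph-style reading I swapped in essentially loads one (image, segmentation) pair per subject from a .pkl file and pairs random subjects on the fly. The sketch below is a simplified illustration of that pipeline, not the exact TransMorph code; the contents and ordering inside each .pkl come from TransMorph's preprocessing, so this is an assumption to show the general structure.

import pickle
import random
import numpy as np
import torch
from torch.utils.data import Dataset

def pkload(fname):
    # TransMorph-style helper: each .pkl holds the arrays for one subject
    with open(fname, 'rb') as f:
        return pickle.load(f)

class SimplePairDataset(Dataset):
    def __init__(self, paths):
        self.paths = paths

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        moving_path = self.paths[index]
        # pick a random different subject as the fixed image
        fixed_path = random.choice([p for p in self.paths if p != moving_path])
        x, x_seg = pkload(moving_path)  # assumed layout: (image, segmentation)
        y, y_seg = pkload(fixed_path)
        x = torch.from_numpy(np.ascontiguousarray(x[None, ...])).float()
        y = torch.from_numpy(np.ascontiguousarray(y[None, ...])).float()
        return x, y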