MungoMeng / Registration-CorrMLP

[CVPR2024 Oral && Best Paper Candidate] CorrMLP: Correlation-aware MLP-based networks for deformable medical image registration
GNU General Public License v3.0
41 stars 2 forks source link

Dataset requirements #3

Open haomo-bh opened 1 month ago

haomo-bh commented 1 month ago

Hello, thank you for open-sourcing your code. Since your dataset is not publicly available yet, I replaced your data loading with TransMorph's data-loading pipeline. After training for about a day, the results did not meet my expectations: the loss keeps fluctuating up and down, and the Dice after registration fluctuates as well. Could you tell me whether you use any special method for data loading? Addendum: I used the OASIS dataset for testing. [image] [image]

MungoMeng commented 1 month ago

Hi, could you please check whether the training images are affine-registered? Our code is designed for deformable registration, which requires the image to be pre-registered via affine registration. It seems that TransMorph uses a separate affine network, followed by the core TransMorph network. If you directly put the non-affine-registered images into the CorrMLP, the training will be very unstable.

haomo-bh commented 1 month ago

Hello, thank you very much for your quick response. The OASIS dataset I am using has already been rigidly pre-aligned, and I did not see Dice fluctuations when training TransMatch on the same data. Could you please help me check whether my code reflects a misunderstanding of yours? Due to GPU memory limitations, I cropped the data according to the parameters in your paper; I wonder whether this might have some impact. Finally, thank you again for your prompt answer. My question may have been rushed and less well prepared than I intended; I apologize for any inconvenience. Here is my training script:

from torch.utils.tensorboard import SummaryWriter
import os, glob
from CorrMLP import utils, losses
from CorrMLP.model import CorrMLP, SpatialTransformer_block
import sys
from torch.utils.data import DataLoader
from datas import datasets, trans
import numpy as np
import torch
from torchvision import transforms
from torch import optim
import torch.nn as nn
import matplotlib.pyplot as plt
from natsort import natsorted

class Logger(object):
    """Tee for ``sys.stdout``: echo every write to the terminal and a log file.

    The log file is ``save_dir + "logfile.log"``, opened in append mode.
    ``save_dir`` is concatenated as-is, so it should end with a path separator.
    The file is line-buffered, so each completed line reaches disk right away.
    Without this, a crashed or killed training run could lose the tail of its
    log.
    """

    def __init__(self, save_dir):
        # Capture the real stdout before the caller replaces sys.stdout with us.
        self.terminal = sys.stdout
        # buffering=1 -> line-buffered text file (flushes on every newline).
        self.log = open(save_dir + "logfile.log", "a", buffering=1, encoding="utf-8")

    def write(self, message):
        """Write ``message`` to both the terminal and the log file."""
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        """Flush both underlying streams.

        This makes ``sys.stdout.flush()`` (e.g. at interpreter exit) actually
        push buffered output to the terminal and to disk.
        """
        self.terminal.flush()
        self.log.flush()

def Dice(vol1, vol2, labels=None, nargout=1):
    """Compute the Dice overlap between two label volumes, one score per label.

    Parameters
    ----------
    vol1, vol2 : array-like
        Label volumes, compared element-wise. When ``labels`` is omitted they
        are also concatenated along axis 0 to discover which labels occur.
    labels : sequence, optional
        Label values to score. By default, every value present in either
        volume except 0 (treated as background) is scored.
    nargout : int
        With 1, return only the score array. With any other value, return
        ``(scores, labels)``.

    Returns
    -------
    numpy.ndarray
        ``2 * |A & B| / (|A| + |B|)`` per label. The denominator is floored
        at machine epsilon, so a label absent from both volumes scores 0.
    """
    if labels is None:
        present = np.unique(np.concatenate((vol1, vol2)))
        labels = present[present != 0]  # drop background

    eps = np.finfo(float).eps
    scores = np.zeros(len(labels))
    for pos, lab in enumerate(labels):
        in1 = vol1 == lab
        in2 = vol2 == lab
        overlap = np.sum(np.logical_and(in1, in2))
        denom = np.maximum(np.sum(in1) + np.sum(in2), eps)
        scores[pos] = (2 * overlap) / denom

    return scores if nargout == 1 else (scores, labels)

def NJD(displacement):
    """Count voxels whose deformation Jacobian determinant is negative.

    ``displacement`` is indexed ``displacement[i, j, k, c]``: three spatial
    axes followed by a trailing axis of three displacement components.
    Forward differences are taken along each spatial axis, so the result
    covers a grid one voxel smaller per axis. The identity is added on the
    diagonal, and the 3x3 determinant is expanded by cofactors along the
    first row.

    The "+1" identity terms pair component 0 with axis 1 (``D_x``),
    component 1 with axis 0 (``D_y``) and component 2 with axis 2 (``D_z``).
    NOTE(review): this is the x/y-swapped convention inherited from the
    original code. Confirm it matches the channel order of the flow before
    relying on the count.

    Returns
    -------
    numpy integer
        Number of voxels with a negative determinant, i.e. where the mapping
        is orientation-reversing (folding).
    """
    D_y = (displacement[1:, :-1, :-1, :] - displacement[:-1, :-1, :-1, :])
    D_x = (displacement[:-1, 1:, :-1, :] - displacement[:-1, :-1, :-1, :])
    D_z = (displacement[:-1, :-1, 1:, :] - displacement[:-1, :-1, :-1, :])

    # Cofactor expansion along the first row of
    #   | D_x0+1  D_x1    D_x2   |
    #   | D_y0    D_y1+1  D_y2   |
    #   | D_z0    D_z1    D_z2+1 |
    D1 = (D_x[..., 0] + 1) * ((D_y[..., 1] + 1) * (D_z[..., 2] + 1) - D_z[..., 1] * D_y[..., 2])
    # The minor of D_x1 is D_y0*(D_z2+1) - D_y2*D_z0 (the last factor is D_z0, not D_x0).
    D2 = D_x[..., 1] * (D_y[..., 0] * (D_z[..., 2] + 1) - D_y[..., 2] * D_z[..., 0])
    D3 = D_x[..., 2] * (D_y[..., 0] * D_z[..., 1] - (D_y[..., 1] + 1) * D_z[..., 0])
    Ja_value = D1 - D2 + D3

    return np.sum(Ja_value < 0)

def main():
    """Train CorrMLP on OASIS ``.pkl`` volumes and validate after every epoch.

    Each epoch has two passes:

    - Training: one pass over the training set minimising
      ``Weights[0] * NCC(win=9) + Weights[1] * Grad('l2')``.
    - Validation: a no-grad pass over the validation set, scoring
      ``utils.dice_val_VOI`` between the warped moving segmentation and the
      fixed segmentation.

    Every epoch, a checkpoint named after the mean validation Dice is written
    (``save_checkpoint`` prunes older files). The training loss, validation
    Dice, and slice figures of the last validation case are logged to
    TensorBoard under ``logs/<save_dir>``.

    Every volume is cropped with ``[8:152, :, 32:192]`` on its three spatial
    axes. The result must match ``img_size``, because the visualisation grid
    is warped with the same flow.
    NOTE(review): assumes the .pkl volumes are 160x192x224 (TransMorph's
    OASIS preprocessing) -- confirm against the data.
    """
    batch_size = 1
    Weights = [1.0, 1.0]
    img_size = (144, 192, 160)
    # (batch, channel, D, H, W) crop shared by every volume below.
    crop = (slice(None), slice(None), slice(8, 152), slice(None), slice(32, 192))
    save_dir = 'CorrMLP_ncc_{}_diffusion_{}/'.format(Weights[0], Weights[1])
    os.makedirs('experiments/' + save_dir, exist_ok=True)
    os.makedirs('logs/' + save_dir, exist_ok=True)
    sys.stdout = Logger('logs/' + save_dir)

    train_dir = r'/home/mh/PythonCodes/OASIS_L2R_2021_task03/All/'
    val_dir = '/home/mh/PythonCodes/OASIS_L2R_2021_task03/Test/'

    lr = 1e-4  # learning rate
    epoch_start = 0
    max_epoch = 500  # max training epoch

    # Initialize model
    model = CorrMLP()
    model.cuda()

    # Warps segmentations / the grid image with nearest-neighbour sampling.
    SpatialTransformer = SpatialTransformer_block(mode='nearest')
    SpatialTransformer.cuda()
    SpatialTransformer.eval()

    # Initialize training
    train_composed = transforms.Compose([trans.NumpyType((np.float32, np.float32))])
    val_composed = transforms.Compose([trans.NumpyType((np.float32, np.int16))])

    train_set = datasets.OASISBrainDataset(glob.glob(train_dir + '*.pkl'), transforms=train_composed)
    val_set = datasets.OASISBrainInferDataset(glob.glob(val_dir + '*.pkl'), transforms=val_composed)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=4, pin_memory=True, drop_last=True)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    Losses = [losses.NCC(win=9).loss, losses.Grad('l2').loss]

    # The visualisation grid never changes: build it once, rather than
    # re-allocating a full-size volume on the GPU for every validation case.
    grid_img = mk_grid_img(8, 1, img_size).float()

    best_dsc = 0
    writer = SummaryWriter(log_dir='logs/' + save_dir)
    plt.switch_backend('agg')

    try:
        for epoch in range(epoch_start, max_epoch):
            print('Training Starts')
            # ---- Training ----
            loss_all = utils.AverageMeter()
            model.train()
            for idx, data in enumerate(train_loader, start=1):
                # adjust_learning_rate(optimizer, epoch, max_epoch, lr)
                data = [t.cuda() for t in data]
                x = data[0][crop]
                y = data[1][crop]
                output, flow = model(y, x)
                loss_ncc = Losses[0](y, output) * Weights[0]
                # NOTE(review): the first argument is a placeholder; presumably
                # Grad.loss ignores y_true -- confirm in losses.py.
                loss_reg = Losses[1](np.zeros((1)), flow) * Weights[1]
                loss = loss_ncc + loss_reg
                loss_all.update(loss.item(), y.numel())
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                print('Iter {} of {} loss {:.4f}, Img Sim: {:.6f}, Reg: {:.6f}'.format(idx, len(train_loader),
                                                                                       loss.item(),
                                                                                       loss_ncc.item(),
                                                                                       loss_reg.item()))
            writer.add_scalar('Loss/train', loss_all.avg, epoch)
            print('Epoch {} loss {:.4f}'.format(epoch, loss_all.avg))

            # ---- Validation ----
            eval_dsc = utils.AverageMeter()
            # Pre-bind so the figure/cleanup code below is safe if val_loader is empty.
            def_out = def_grid = x_seg = y_seg = None
            model.eval()
            with torch.no_grad():
                for data in val_loader:
                    data = [t.cuda() for t in data]
                    x = data[0][crop]
                    y = data[1][crop]
                    x_seg = data[2][crop]
                    y_seg = data[3][crop]
                    flow = model(y, x)[1]
                    def_out = SpatialTransformer(x_seg.float(), flow)
                    def_grid = SpatialTransformer(grid_img, flow)
                    dsc = utils.dice_val_VOI(def_out.long(), y_seg.long())
                    eval_dsc.update(dsc.item(), x.size(0))
                    print(eval_dsc.avg)
            best_dsc = max(eval_dsc.avg, best_dsc)
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_dsc': best_dsc,
                'optimizer': optimizer.state_dict(),
            }, save_dir='experiments/' + save_dir, filename='dsc{:.4f}.pth.tar'.format(eval_dsc.avg))
            writer.add_scalar('DSC/validate', eval_dsc.avg, epoch)
            if def_out is not None:
                # Figures come from the last validation case of this epoch.
                for tag, volume in (('Grid', def_grid), ('input', x_seg),
                                    ('ground truth', y_seg), ('prediction', def_out)):
                    fig = comput_fig(volume)
                    writer.add_figure(tag, fig, epoch)
                    plt.close(fig)
            loss_all.reset()
            # Release validation tensors before the next training epoch.
            del def_out, def_grid, x_seg, y_seg
    finally:
        writer.close()

def comput_fig(img):
    """Render slices 48..63 along the first spatial axis of ``img[0, 0]`` as a 4x4 grid.

    ``img`` is a torch tensor; it is detached and copied to the CPU first.
    Each slice is drawn in grayscale with axes hidden, and the panels are
    packed with no spacing. Returns the matplotlib Figure; the caller is
    responsible for closing it.
    """
    slices = img.detach().cpu().numpy()[0, 0, 48:64, :, :]
    fig = plt.figure(figsize=(12, 12), dpi=180)
    for panel, plane in enumerate(slices, start=1):
        plt.subplot(4, 4, panel)
        plt.axis('off')
        plt.imshow(plane, cmap='gray')
    fig.subplots_adjust(wspace=0, hspace=0)
    return fig

def adjust_learning_rate(optimizer, epoch, MAX_EPOCHES, INIT_LR, power=0.9):
    """Apply polynomial ("poly") learning-rate decay to every param group.

    Each group's ``'lr'`` becomes
    ``INIT_LR * (1 - epoch / MAX_EPOCHES) ** power``, rounded to 8 decimals.
    """
    for group in optimizer.param_groups:
        remaining = 1 - epoch / MAX_EPOCHES
        group['lr'] = round(INIT_LR * np.power(remaining, power), 8)

def mk_grid_img(grid_step, line_thickness=1, grid_sz=(160, 192, 224), device='cuda'):
    """Build a 3-D grid-line volume, e.g. for visualising a deformation field.

    Lines of value 1 start every ``grid_step`` voxels along axes 1 and 2.
    Each line spans the full extent of axis 0 and is ``line_thickness``
    voxels wide. Lines near the upper edge are clipped instead of raising
    ``IndexError``. All other voxels are 0.

    Args:
        grid_step: spacing, in voxels, between the starts of consecutive lines.
        line_thickness: width of each line in voxels (1 gives single-voxel lines).
        grid_sz: 3-D spatial shape of the volume.
        device: torch device for the result. Defaults to ``'cuda'``, the current
            CUDA device, which is what the previous ``.cuda()`` call produced.

    Returns:
        float64 tensor of shape ``(1, 1, *grid_sz)`` on ``device``.
    """
    grid_img = np.zeros(grid_sz)
    for j in range(0, grid_img.shape[1], grid_step):
        # Slicing (rather than a single index) gives real thickness and clips at the edge.
        grid_img[:, j:j + line_thickness, :] = 1
    for i in range(0, grid_img.shape[2], grid_step):
        grid_img[:, :, i:i + line_thickness] = 1
    return torch.from_numpy(grid_img[None, None, ...]).to(device)

def save_checkpoint(state, save_dir='models', filename='checkpoint.pth.tar', max_model_num=8):
    """Save ``state`` to ``save_dir + filename``, then prune the directory.

    After saving, every path matched by ``save_dir + '*'`` is sorted with
    ``natsorted``. Paths are deleted from the front of that order until at
    most ``max_model_num`` remain. ``save_dir`` is joined by plain string
    concatenation, so it should end with a path separator.
    """
    torch.save(state, save_dir + filename)
    remaining = natsorted(glob.glob(save_dir + '*'))
    while len(remaining) > max_model_num:
        os.remove(remaining.pop(0))

if __name__ == '__main__':
    # GPU configuration: list visible devices, select GPU_iden, seed, then train.
    GPU_iden = 0
    GPU_num = torch.cuda.device_count()
    print(f'Number of GPU: {GPU_num}')
    for GPU_idx in range(GPU_num):
        GPU_name = torch.cuda.get_device_name(GPU_idx)
        print(f'     GPU #{GPU_idx}: {GPU_name}')
    torch.cuda.set_device(GPU_iden)
    GPU_avai = torch.cuda.is_available()
    print(f'Currently using: {torch.cuda.get_device_name(GPU_iden)}')
    print(f'If the GPU is available? {GPU_avai}')
    torch.manual_seed(0)
    main()
MungoMeng commented 1 month ago

Here are two examples of the training images I used in our experiments. Could you please manually check whether they differ significantly from yours? (Attachment: Image.zip)