alanqrwang / brainmorph

A Foundational Keypoint Model for Robust and Flexible Brain MRI Registration
MIT License

Converting deformation field to voxel index displacement field and cropping to input image size #1

junyuchen245 closed this issue 4 months ago.

junyuchen245 commented 4 months ago

Hi @alanqrwang ,

Thanks for the awesome work on BrainMorph! I started playing around with it using the LUMIR data from the Learn2Reg challenge and noticed that the Dice and TRE results aren't looking great. Could I be converting BrainMorph's normalized deformation field into a displacement field incorrectly for the challenge submissions? Any advice would be super helpful.

https://github.com/JHU-MedImage-Reg/LUMIR_L2R/blob/44a869d5d8adb7b55329a267558cd9370082528a/BrainMorph/infer_BrainMorph.py
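For reference, here is a minimal NumPy sketch of the conversion being asked about. It assumes the normalized field is a PyTorch-style sampling grid (the `torch.nn.functional.grid_sample` convention: values in [-1, 1], last axis ordered (x, y, z), `align_corners=True`). Whether BrainMorph's saved grid follows exactly this convention is an assumption here, and both function names are hypothetical.

```python
import numpy as np

def grid_to_voxel_disp(grid):
    """Normalized sampling grid (D, H, W, 3), last axis (x, y, z) -> voxel displacement (3, D, H, W)."""
    dims = grid.shape[:3]
    coords = grid[..., ::-1]  # (x, y, z) -> (z, y, x), i.e. the (D, H, W) array axis order
    voxel = (coords + 1.0) / 2.0 * (np.array(dims, dtype=np.float64) - 1.0)  # -1 -> 0, +1 -> size - 1
    identity = np.stack(np.meshgrid(*[np.arange(s) for s in dims], indexing="ij"), axis=-1)
    return np.moveaxis(voxel - identity, -1, 0)  # channels first

def voxel_disp_to_grid(disp):
    """Inverse of grid_to_voxel_disp: (3, D, H, W) voxel displacement -> (D, H, W, 3) normalized grid."""
    dims = disp.shape[1:]
    identity = np.stack(np.meshgrid(*[np.arange(s) for s in dims], indexing="ij"), axis=-1)
    voxel = identity + np.moveaxis(disp, 0, -1)
    coords = voxel / (np.array(dims, dtype=np.float64) - 1.0) * 2.0 - 1.0
    return coords[..., ::-1]  # back to (x, y, z) order

# Identity grid on a small volume: should map to (numerically) zero displacement
D, H, W = 4, 5, 6
zz, yy, xx = np.meshgrid(np.linspace(-1, 1, D), np.linspace(-1, 1, H),
                         np.linspace(-1, 1, W), indexing="ij")
identity_grid = np.stack([xx, yy, zz], axis=-1)
zero_disp = grid_to_voxel_disp(identity_grid)

# Shift every sample by one voxel along W (the x axis): displacement channel 2 becomes 1
shifted = identity_grid.copy()
shifted[..., 0] += 2.0 / (W - 1)
disp = grid_to_voxel_disp(shifted)
```

A quick sanity check like the identity and one-voxel-shift cases above catches the two most common mistakes: a flipped axis order and the wrong `align_corners` scaling.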

alanqrwang commented 4 months ago

Hi Junyu,

I'll download the data, debug your script, and get back to you ASAP.

junyuchen245 commented 4 months ago

Thanks for the prompt reply, Alan. The label maps and landmarks in the dataset are kept confidential because LUMIR focuses on unsupervised registration, and the challenge leaderboard won't be public until later this week. I've been testing internally and am considering making BrainMorph a baseline method, which is why I opened this issue to check whether I made a mistake converting the deformation fields to voxel-indexed displacements. Notably, the Dice scores for SynthMorph exceeded 0.7, while those for BrainMorph were only around 0.5 after the conversion.

If you're interested in testing BrainMorph on the LUMIR data now, please use the label maps provided here. These maps are for sanity checks by participants before they submit their validation results and correspond to the label maps for five images in the training data. Otherwise, I'll notify you once the leaderboard is public so you can test my script.
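For such sanity checks, a per-label Dice score is usually all that is needed. Below is a minimal NumPy sketch, assuming integer label maps of the same shape with 0 as background; `dice_per_label` is a hypothetical helper, not part of the LUMIR evaluation code.

```python
import numpy as np

def dice_per_label(fixed_seg, warped_seg, labels=None):
    """Dice overlap 2|A & B| / (|A| + |B|) for each foreground label."""
    if labels is None:
        labels = np.union1d(np.unique(fixed_seg), np.unique(warped_seg))
        labels = labels[labels != 0]  # assume 0 is background
    scores = {}
    for lab in labels:
        a = fixed_seg == lab
        b = warped_seg == lab
        denom = a.sum() + b.sum()
        # Label absent from both maps: Dice is undefined, report NaN
        scores[int(lab)] = 2.0 * np.logical_and(a, b).sum() / denom if denom > 0 else np.nan
    return scores

# Toy example: label 1 only half overlaps, label 2 matches exactly
fixed = np.zeros((4, 4, 4), dtype=int)
fixed[:2] = 1
fixed[3] = 2
warped = np.zeros((4, 4, 4), dtype=int)
warped[:1] = 1
warped[3] = 2
scores = dice_per_label(fixed, warped)
mean_dice = np.nanmean(list(scores.values()))
```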

alanqrwang commented 4 months ago

Hi Junyu,

Can you pull the latest commit and run your script again? I pushed this commit: https://github.com/alanqrwang/brainmorph/commit/ef00044db3aff2be983c6e519f8a6a51d2aaaf6d

I think the issue was a bug in my script: it skipped the registration whenever all of the outputs were already present on disk, so I suspect it kept reusing the first registration for all of your validation data. If you look at the Dice scores across the validation set, is the first one higher than the rest?
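A hypothetical illustration of this class of bug (not BrainMorph's actual code): if the "already done?" check looks at an output path shared by every pair, each pair after the first silently reuses the first result. Keying the output on the pair id avoids it, as does clearing the output directory before each pair, which the script below does with `shutil.rmtree`. `fake_registration` is a stand-in for the real registration call.

```python
import os
import tempfile

def register_cached_buggy(pair_id, out_dir, run_registration):
    out_path = os.path.join(out_dir, "result.txt")  # BUG: same path for every pair
    if not os.path.exists(out_path):
        run_registration(pair_id, out_path)
    with open(out_path) as f:
        return f.read()

def register_cached(pair_id, out_dir, run_registration):
    out_path = os.path.join(out_dir, f"{pair_id}.txt")  # one output file per pair
    if not os.path.exists(out_path):
        run_registration(pair_id, out_path)
    with open(out_path) as f:
        return f.read()

def fake_registration(pair_id, out_path):
    # Hypothetical stand-in: "registers" a pair by writing a tag to disk
    with open(out_path, "w") as f:
        f.write(f"warp for {pair_id}")

pairs = ["0001_0002", "0003_0004"]
with tempfile.TemporaryDirectory() as d:
    buggy = [register_cached_buggy(p, d, fake_registration) for p in pairs]
with tempfile.TemporaryDirectory() as d:
    fixed = [register_cached(p, d, fake_registration) for p in pairs]
# buggy repeats the first pair's result; fixed has one result per pair
```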

I also found a small issue with the 'H' variant models that I will look into. We didn't report these variants in the paper but released them anyway; I would recommend switching to the 'L' variant for now (I don't think the difference between them is large anyway).

Sorry for the inconvenience and let me know if this solves the issue. Here's the code that gives me reasonable results on my end:

'''
BrainMorph for LUMIR at Learn2Reg 2024
Author: Junyu Chen
        Johns Hopkins University
        jchen245@jhmi.edu
Date: 05/28/2024
'''
import os, random, glob, sys
sys.path.insert(0, '/home/alw4013/LUMIR_L2R')
import subprocess

import matplotlib.pyplot as plt
import numpy as np
from TransMorph.data import datasets
from torch.utils.data import DataLoader
import torch, shutil
import torch.nn.functional as F
import nibabel as nib

from keymorph.utils import align_img, displacement2pytorchflow

def shell_command(command):
    print("RUNNING", command)
    subprocess.run(command, shell=True)

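def standardize_flow(flow):
    # Convert BrainMorph's normalized sampling grid (D, H, W, 3) into a voxel displacement field (3, D, H, W).
    # Assumes the grid follows F.grid_sample's convention: values in [-1, 1], last axis ordered (x, y, z).
    flow = torch.from_numpy(flow[None])
    flow = flow[..., [2, 1, 0]]  # (x, y, z) -> (z, y, x) to match the (D, H, W) array axes
    flow = flow.permute(0, 4, 1, 2, 3)  # Bring channels to second dimension
    shape = flow.shape[2:]

    # Scale normalized coordinates to voxel indices (align_corners=True: -1 -> 0, +1 -> size - 1)
    for i in range(3):
        flow[:, i, ...] = (flow[:, i, ...] + 1) / 2 * (shape[i] - 1)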

    # Create an image grid for the target size
    vectors = [torch.arange(0, s) for s in shape]
    grids = torch.meshgrid(vectors, indexing='ij')
    grid = torch.stack(grids, dim=0).unsqueeze(0).to(flow.device, dtype=torch.float32)

    # Calculate displacements from the image grid
    disp = flow - grid
    return disp.cpu().detach().numpy()[0]

def nib_load(file_name):
    if not os.path.exists(file_name):
        return np.array([1])

    proxy = nib.load(file_name)
    data = proxy.get_fdata()
    proxy.uncache()
    return data

def save_nii(img, file_name, pix_dim=[1., 1., 1.]):
    x_nib = nib.Nifti1Image(img, np.eye(4))
    x_nib.header.get_xyzt_units()
    x_nib.header['pixdim'][1:4] = pix_dim
    x_nib.to_filename('{}.nii.gz'.format(file_name))

def main():
    # val_dir = 'F:/Junyu/DATA/LUMIR/'
    val_dir = '/midtier/sablab/scratch/data/Learn2Reg_2024/'
    brainmorph_in_dir = '/midtier/sablab/scratch/alw4013/brainmorph/register_input'
    brainmorph_out_dir = '/midtier/sablab/scratch/alw4013/brainmorph/register_output'
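    lumir_out_dir = '/midtier/sablab/scratch/alw4013/brainmorph/LUMIR_outputs/'
    # Create the input directory too: the NIfTI writes below fail if it does not exist
    os.makedirs(brainmorph_in_dir, exist_ok=True)
    os.makedirs(lumir_out_dir, exist_ok=True)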

    val_set = datasets.L2RLUMIRJSONDataset(base_dir=val_dir, json_path=val_dir + 'LUMIR_dataset.json',
                                           stage='validation')
    val_loader = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=4, pin_memory=True, drop_last=True)
    val_files = val_set.imgs

    '''
    Validation
    '''
    with torch.no_grad():
        for i, data in enumerate(val_loader):
            mv_id = val_files[i]['moving'].split('_')[-2]
            fx_id = val_files[i]['fixed'].split('_')[-2]
            # Start each pair with an empty output directory so results from a previous pair are never reused
            if os.path.exists(brainmorph_out_dir):
                shutil.rmtree(brainmorph_out_dir)
            x_image = data[0]
            y_image = data[1]
            x = x_image.squeeze(0).squeeze(0).detach().cpu().numpy()
            y = y_image.squeeze(0).squeeze(0).detach().cpu().numpy()
            print('start registration: {}'.format(i))
            # Write both volumes as 1 mm isotropic NIfTI files (x.nii.gz, y.nii.gz) for register.py
            save_nii(x, f'{brainmorph_in_dir}/x')
            save_nii(y, f'{brainmorph_in_dir}/y')

            os.chdir('/home/alw4013/brainmorph')
            shell_command(
                'python /home/alw4013/brainmorph/scripts/register.py'
                ' --num_keypoints 512 --variant L'
                ' --weights_dir /midtier/sablab/scratch/alw4013/brainmorph/weights/keymorph_weights_256x256x256'
                f' --moving {brainmorph_in_dir}/x.nii.gz --fixed {brainmorph_in_dir}/y.nii.gz'
                ' --list_of_aligns tps_0 --list_of_metrics mse'
                f' --save_eval_to_disk --download --save_dir {brainmorph_out_dir} --visualize'
            )

            flow = np.load(f'{brainmorph_out_dir}/0_0_fixed_moving/grid_0-fixed_0-moving-rot0-tps_0.npy')
            flow = standardize_flow(flow)
            # Center-crop the 256^3 field back to the 160x224x192 LUMIR image size
            flow = flow[:, (256-160)//2:(256-160)//2+160, (256-224)//2:(256-224)//2+224, (256-192)//2:(256-192)//2+192]
            save_nii(flow, lumir_out_dir + 'disp_{}_{}'.format(fx_id, mv_id))
            print('disp_{}_{}.nii.gz saved to {}'.format(fx_id, mv_id, lumir_out_dir))

            # Plot results
            myflow = torch.tensor(flow).float().permute(1, 2, 3, 0)[None]
            myflow = displacement2pytorchflow(myflow)
            x_moved = align_img(myflow, torch.tensor(x).float()[None, None])  # cast to float32 to match myflow
            fig, axes = plt.subplots(1, 6, figsize=(18, 3))
            axes[0].imshow(x[80])
            axes[1].imshow(y[80])
            axes[2].imshow(x_moved[0, 0, 80])
            axes[3].imshow(flow[0, 80])
            axes[4].imshow(flow[1, 80])
            axes[5].imshow(flow[2, 80])
            plt.show()

def seedBasic(seed=2021):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)

def seedTorch(seed=2021):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

if __name__ == '__main__':
    # Seed all random number generators for reproducibility
    DEFAULT_RANDOM_SEED = 12
    seedBasic(DEFAULT_RANDOM_SEED)
    seedTorch(DEFAULT_RANDOM_SEED)
    main()
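The hard-coded slices above assume the 160x224x192 LUMIR volume sits centered in BrainMorph's 256x256x256 grid. A small sketch of a helper (hypothetical name) that derives the same offsets, (256-160)//2 and so on, from the shapes instead, so the crop stays consistent if either size changes:

```python
import numpy as np

def center_crop_field(field, target_shape):
    """Center-crop the spatial axes of a (C, D, H, W) field to target_shape (D', H', W').

    Offsets are (size - target) // 2 per axis, the same as the script's hard-coded slices.
    """
    slices = [slice(None)]  # keep the channel axis untouched
    for size, target in zip(field.shape[1:], target_shape):
        if target > size:
            raise ValueError(f"cannot crop an axis of size {size} to {target}")
        start = (size - target) // 2
        slices.append(slice(start, start + target))
    return field[tuple(slices)]

# Usage on a dummy 256^3 field; np.broadcast_to avoids allocating ~200 MB
dummy = np.broadcast_to(np.float32(0), (3, 256, 256, 256))
cropped = center_crop_field(dummy, (160, 224, 192))
```

In the script this would replace the long slicing line with `flow = center_crop_field(flow, (160, 224, 192))`.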

Best, Alan

junyuchen245 commented 4 months ago

Thanks for looking into this, Alan. Unfortunately, I still haven't been able to achieve competitive results. I've attached two JSON files generated with the H and L variants using the script you provided: metrics_BrainMorph_H_512.json and metrics_BrainMorph_L_512.json.

I think I might have an explanation. The keypoints generated by BrainMorph seem to concentrate on deep brain (subcortical) structures and place less emphasis on cortical structures. This is supported by the examples I generated with BrainMorph (see below) and by the figure you posted in the repository here. If no keypoints are placed on cortical structures or outside the brain, the TPS-interpolated deformation field will likely make only minimal changes in the regions without keypoints. This could explain why the Dice scores, especially for cortical structures, were suboptimal; as the examples show, there is a noticeable mismatch between the fixed image and the deformed moving image.

Do you think this is a reasonable explanation? If so, I'm considering not using BrainMorph as a baseline for the deformable image registration challenge, because I don't want its nonlinear registration results here to undermine its strengths, especially its ability to handle large misalignments.

Example 1

Example 2
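To illustrate the explanation above, here is a small NumPy sketch of an exactly interpolating 3D thin-plate spline whose keypoints all lie in a central ("subcortical") block. The kernel U(r) = r, the absence of regularization, and the ground-truth warp (a global affine plus a 4 mm local shift at a distant "cortical" point) are all assumptions for the demo; BrainMorph's own TPS solver may differ in kernel or regularization.

```python
import numpy as np

def fit_tps(src, dst):
    """Fit a 3D thin-plate spline mapping keypoints src -> dst exactly (no regularization).

    Uses the 3D biharmonic kernel U(r) = r plus an affine term; src and dst are (N, 3).
    """
    n = src.shape[0]
    K = np.linalg.norm(src[:, None, :] - src[None, :, :], axis=-1)  # (N, N) kernel matrix
    P = np.hstack([np.ones((n, 1)), src])                            # (N, 4) affine basis
    A = np.zeros((n + 4, n + 4))
    A[:n, :n], A[:n, n:], A[n:, :n] = K, P, P.T
    b = np.zeros((n + 4, 3))
    b[:n] = dst
    params = np.linalg.solve(A, b)
    return params[:n], params[n:]  # kernel weights (N, 3), affine coefficients (4, 3)

def apply_tps(points, src, w, a):
    """Evaluate the fitted spline at arbitrary (M, 3) points."""
    U = np.linalg.norm(points[:, None, :] - src[None, :, :], axis=-1)  # (M, N)
    P = np.hstack([np.ones((points.shape[0], 1)), points])             # (M, 4)
    return U @ w + P @ a

# Hypothetical ground truth: global affine plus a 4 mm shift along z near a "cortical" point
M = np.array([[1.02, 0.0, 0.0], [0.0, 0.98, 0.01], [0.0, 0.0, 1.0]])
t = np.array([1.0, -2.0, 0.5])
bump_center = np.array([0.0, 0.0, 60.0])

def true_warp(x):
    bump = 4.0 * np.exp(-np.sum((x - bump_center) ** 2, axis=-1) / (2 * 5.0 ** 2))
    shift = np.zeros_like(x)
    shift[:, 2] = bump  # nonzero only near bump_center
    return x @ M.T + t + shift

# Keypoints only in a central 4x4x4 block (coordinates in mm), far from the bump
g = np.linspace(-15.0, 15.0, 4)
src = np.stack(np.meshgrid(g, g, g, indexing="ij"), axis=-1).reshape(-1, 3)
w, a = fit_tps(src, true_warp(src))

cortex = bump_center[None]
err_keypoints = np.abs(apply_tps(src, src, w, a) - true_warp(src)).max()
err_cortex = np.linalg.norm(apply_tps(cortex, src, w, a) - true_warp(cortex))
```

Because the keypoints never see the local shift, the fitted spline essentially recovers only the global affine part: it matches the keypoints exactly, yet `err_cortex` stays close to the full 4 mm at the point no keypoint covers.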

alanqrwang commented 4 months ago

Hi Junyu,

Yes, I think your explanation is correct, and we made the same observation (and explanation) in the paper. If you look at our results, SynthMorph and ANTs outperform BrainMorph on nonlinear registration with a low degree of initial misalignment, which seems to be your use case. BrainMorph's strengths probably show when you need a robust initial alignment; for finer-grained registration it probably isn't the best choice.

We actually spent a lot of time trying to understand why the keypoints end up in subcortical regions. Our current best explanation is that the trained model finds these regions to be the most stable and to yield the best registration. Future work might explore how to handle cortical regions.

Hope this helps!

Best, Alan

junyuchen245 commented 4 months ago

These insights are very helpful! Thanks for the great discussion as well.

Junyu