xzluo97 / X-metric

X-Metric: An N-Dimensional Information-Theoretic Framework for Groupwise Registration and Deep Combined Computing (TPAMI 2023)
Apache License 2.0

Group-wise Registration of 3D T1 Brain MRI scans #1

Closed · BailiangJ closed this issue 11 months ago

BailiangJ commented 1 year ago

Hi @xzluo97 ,

Thank you for the great work and for kindly sharing the code.

I am now trying to run group-wise registration (FFD or SVF) on a set of three 3D T1 brain MRIs (.nii.gz files) and to save the final displacement field.

If I want to use XCoRegUN and XCoRegGT (both without segmentation), how should I start?

I am a bit lost in the code and don't know where to modify; could you give me some hints/guidance?

Thanks a lot!

xzluo97 commented 1 year ago

Hi Bailiang,

Thanks for your interest in our work!

To adapt the code to your own dataset, you may follow the code in ./core/trainers/Brainweb/BrainWebGroupRegTrainer.py to create a new trainer module. In particular, first create a PyTorch Dataset module similar to ./core/data/data_providers/Brainweb/BrainwebImageProvider.py to load your data. Then create a model file similar to ./core/models/Brainweb/BrainWebRegModel.py, which instantiates the registration method you want to use and sets the evaluation metrics. Finally, return to the trainer module to load your data and initialize the model and optimizer. Training can be done with

for j in range(model.reg.num_reg_levels):
    # Optimize one transformation level at a time, coarse to fine.
    model.reg.activate_params([j])
    opt = opts[j]
    for step in range(0 if j == 0 else cum_steps[j - 1], cum_steps[j]):

        opt.zero_grad()

        # Warp all images into the common space with the current parameters.
        warped_images = model.reg()

        loss = model.reg.loss_function(warped_images)
        loss.backward()

        opt.step()
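
Before this loop can run, the images themselves must be loaded; here is a minimal sketch of a Dataset in the spirit of ./core/data/data_providers/Brainweb/BrainwebImageProvider.py (the class name and data layout are hypothetical):

import nibabel as nib
import numpy as np
import torch
from torch.utils.data import Dataset

class GroupNiftiProvider(Dataset):
    """Loads one group of subject volumes from .nii.gz files."""

    def __init__(self, groups):
        self.groups = groups  # list of groups, each a list of .nii.gz paths

    def __len__(self):
        return len(self.groups)

    def __getitem__(self, idx):
        volumes = [nib.load(p).get_fdata(dtype=np.float32) for p in self.groups[idx]]
        # Stack to (num_subjects, 1, H, W, D); the batch dim comes from the DataLoader.
        return torch.from_numpy(np.stack(volumes)).unsqueeze(1)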

You may also check ./core/configs/ConfigBrainWeb.py to see which arguments you may need to configure the registration. For example, to use FFD you may need to set --ffd_spacing and --zero_avg_flow, while for SVF you may set --zero_avg_vec. To use XCoRegUN/XCoRegGT, you may also need to tune the parameter --num_classes. Note that for XCoRegGT you also need to load the anatomical labels for the corresponding input images.
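
For illustration, the same options expressed as keyword arguments (a sketch only; the authoritative argument names and value formats are those in ./core/configs/ConfigBrainWeb.py):

# FFD: B-spline control-point spacing plus the zero-average-flow constraint.
ffd_config = dict(transform_type=['FFD'], ffd_spacing=[5], zero_avg_flow=True)

# SVF: zero-average-velocity constraint instead.
svf_config = dict(transform_type=['SVF'], zero_avg_vec=True)

# XCoRegUN/XCoRegGT additionally need the number of tissue classes.
xcoreg_config = dict(num_classes=4)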

If you have any further questions, please feel free to contact me.

Good luck!

BailiangJ commented 1 year ago

Hi @xzluo97,

Thank you for the quick reply.

Yes, I am adopting the code from the BrainWeb example.

Regarding the parameters, I have a few questions.

  1. From _verify_hyerparameters, I can see transform_type is a multiple choice from ['TRA', 'RIG', 'AFF', 'FFD', 'DDF', 'SVF']. Could it be ['FFD', 'DDF']? I can see that 'FFD' has its own pre_ffd_level, which equals the length of pred_ffd_spacing when using isotropic spacing across dimensions, while 'DDF' is a full-size vector field. From the code I also see that using ['FFD', 'SVF'] does not mean the velocity field is parameterized by B-splines; is that correct?

  2. What are the meanings of num_reg_levels, max_levels and num_res? Are they related to the usual definition of pyramid levels in conventional methods, where one level down means the images are downsampled by half? Could you also explain the difference between reg_level and res_level?

  3. If I understand correctly from the paper, XCoRegUN does the clustering with num_classes for the appearance model, while XCoRegGT estimates the appearance model based on the input segmentation. Should I then also specify num_classes for XCoRegGT?

The code is very well structured and the variable naming is informative, which helps a lot when reading and understanding it; some explanatory comments would be a great addition.

Thanks a lot for your help and patience with my possible upcoming questions. :)

Cheers, Bailiang

xzluo97 commented 1 year ago

Hi Bailiang,

Thank you for the questions. I will answer them point-by-point as follows:

  1. Yes, the transform_type can be ['FFD', 'DDF'], which means these two types of transformation are composed together for the registration. Besides, with the method self.activate_params(levels) you can optimize them individually. You are also right that transform_type=['FFD', 'SVF'] does not mean the velocity field is parameterized by B-splines; however, you can easily set this up with the current code.
  2. num_reg_levels is the total number of transformations composed together for registration, whereas max_reg_level is the current highest registration level (higher means finer registration), which can be altered by the method self.activate_params(levels) to tell the model not to compose any higher-level transformations that have not yet been optimized when predicting the total transformation. On the other hand, num_res is indeed related to the pyramid levels you mentioned. For example, if you use transform_type=['AFF', 'FFD'], then setting num_res=[3,3,3,1,1,1] means that for the AFF transformation three resolutions are used across all dimensions of the images to compute the similarity measure, while for FFD a single resolution is used. For each coarser resolution, the original images are downsampled by a further factor of 2. More specifically, with num_res=[3,3,3] for AFF and steps=300, the images are first downsampled by a factor of 4 during step=0:100, then by a factor of 2 during step=100:200, and finally not downsampled during step=200:300 (see the sketch after this list).
  3. Yes, you are right. For XCoRegGT, you also need to specify num_classes to be the same as the number of label classes in your input segmentations.
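
To make this schedule concrete, here is a minimal sketch (a hypothetical helper, not from the repo) of the downsampling factor applied at each optimization step of one transformation level:

def downsample_factor(step, steps=300, num_res=3):
    # Split the steps into num_res equal stages: stage 0 is the coarsest.
    stage = step * num_res // steps
    # With steps=300: factor 4 for steps 0-99, 2 for 100-199, 1 for 200-299.
    return 2 ** (num_res - 1 - stage)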

Cheers, Xinzhe

BailiangJ commented 1 year ago

Hi @xzluo97 ,

Thank you for your answers. Sorry to bother you again, but I still have the following questions. I am only considering the flow-field transformations.

  1. For num_res, it only rescales the warped images to different resolutions and computes the loss there, but it does not generate corresponding flow fields at those resolutions; is that correct? To put it another way, the full-size flow field is the only one computed throughout the registration.

  2. For num_ffd_levels and pred_ffd_spacing, it is actually similar to question 1: if I have three sets of isotropic spacing [[1],[3],[5]], the transformation would be three full-size FFD flow fields, and the final predict_flow would be the composition of these three?

  3. What is the difference between the parameters label and mask? I can see that label is also taken as a weight in SVCD. So is mask the mask of the ROI, e.g. the mask of the brain area?

  4. Are images tensors of shape (B, num_subjects, H, W, D) and labels one-hot tensors of shape (B, num_subjects, num_classes, H, W, D)?

Thanks a lot for your help!

Cheers, Bailiang

BailiangJ commented 1 year ago

I am attaching my modification of the model. I would really appreciate it if you could check it when you have time and see if something is missing due to my understanding (it is running fine, except that when I use XCoRegGT I get a CUDA OOM, but that might be because my num_classes is too big).

import torch
import torch.nn as nn

# XCoRegUnRegModel, XCoRegGTRegModel, GMMRegModel, APERegModel, CTERegModel
# and utils are assumed to be imported from the repo's ./core package;
# _spatial_filter is carried over from the BrainWeb model.

class MyRegModel(nn.Module):
    def __init__(self,
                 dimension=3,
                 img_size=(160, 160, 192),
                 eps=1e-8,
                 **kwargs):
        '''
        Remaining kwargs: num_classes, num_bins, sample_rate, kernel_sigma,
        eps, momentum, min_prob.
        '''
        super(MyRegModel, self).__init__()
        self.dimension = dimension
        self.img_size = img_size
        self.eps = eps
        self.kwargs = kwargs
        self.model_type = self.kwargs.pop('model_type', None)
        assert self.model_type in ['XCoRegUn', 'XCoRegGT', 'GMM', 'APE', 'CTE']
        # Pop the smoothing parameters before instantiating the registration
        # model so they are not forwarded as unexpected kwargs.
        self.mask_sigma = self.kwargs.pop('mask_sigma', -1)
        self.prior_sigma = self.kwargs.pop('prior_sigma', -1)
        if self.model_type == 'XCoRegUn':
            Model = XCoRegUnRegModel.XCoRegUnRegModel
        elif self.model_type == 'XCoRegGT':
            Model = XCoRegGTRegModel.XCoRegGTRegModel
        elif self.model_type == 'GMM':
            Model = GMMRegModel.GMMRegModel
        elif self.model_type == 'APE':
            Model = APERegModel.APERegModel
        elif self.model_type == 'CTE':
            Model = CTERegModel.CTERegModel
        else:
            raise NotImplementedError
        self.reg = Model(self.dimension, self.img_size, eps=self.eps, num_subjects=3, **self.kwargs)

        self.num_subjects = self.reg.num_subjects

    def init_model_params(self, images, label=None):
        B = images.shape[0]
        assert images.shape[1] == self.num_subjects

        self.reg.init_reg_params(images=images)

        if label is None:
            # Without labels: uniform class prior and an all-ones ROI mask.
            prior = torch.full((B, self.reg.num_classes, *self.reg.img_size),
                               fill_value=1 / self.reg.num_classes,
                               device=images.device, dtype=images.dtype)
            mask = torch.ones(B, 1, *self.reg.img_size, dtype=images.dtype, device=images.device)
        else:
            if self.mask_sigma == -1:
                mask = torch.ones(B, 1, *self.reg.img_size, dtype=images.dtype, device=images.device)
            else:
                # Smooth the foreground labels and threshold to get the ROI mask.
                mask = self._spatial_filter(label[:, 1:].sum(dim=1, keepdim=True),
                                            utils.gauss_kernel1d(self.mask_sigma)).gt(self.eps).to(images.dtype)

            if self.prior_sigma == -1:
                prior = torch.full((B, self.reg.num_classes, *self.reg.img_size),
                                   fill_value=1 / self.reg.num_classes,
                                   device=images.device, dtype=images.dtype)
            else:
                prior = utils.compute_normalized_prob(
                    self._spatial_filter(label, utils.gauss_kernel1d(self.prior_sigma)), dim=1)
        self.reg.mask = mask
        self.reg.prior = prior

        # Guard against label=None, e.g. for XCoRegUn without segmentations.
        labels = torch.unbind(label, dim=1) if label is not None else None

        if self.model_type == 'XCoRegUn':
            self.reg.init_app_params()
        if self.model_type == 'XCoRegGT':
            self.reg.init_app_params(labels=labels)
        if self.model_type == 'GMM':
            self.reg.init_gmm_params()

For the RegModel, I just removed the part of the BrainWeb model where you add noise and generate the ground-truth misalignment.

import numpy as np

# Trainer is the repo's base trainer class.
class IterGroupRegTrainer(Trainer):
    def train(self, images, label, device='cuda:0', **kwargs):
        steps = kwargs.pop('steps', [50])
        assert len(steps) == self.net.reg.num_reg_levels
        cum_steps = np.cumsum(steps)
        if isinstance(self.lr, float):
            lr = [self.lr] * self.net.reg.num_reg_levels
        elif isinstance(self.lr, (tuple, list)):
            if len(self.lr) == 1:
                lr = list(self.lr) * self.net.reg.num_reg_levels
            else:
                assert len(self.lr) == self.net.reg.num_reg_levels
                lr = self.lr
        else:
            raise NotImplementedError

        self.writer = self._get_writer(self.save_path)

        # MyRegModel
        model = self.net.to(device)
        # model.reg -> XCoRegUn/XCoRegGT/APE/CTE
        # images =
        # label = # one-hot

        model.init_model_params(images, label)
        opts = [self._get_optimizer([model.reg.params[model.reg.reg_level_type[j]]], lr=lr[j])
                for j in range(model.reg.num_reg_levels)]

        for j in range(model.reg.num_reg_levels):
            print('reg_level',j)
            model.reg.activate_params([j])
            opt = opts[j]
            for step in range(0 if j == 0 else cum_steps[j - 1], cum_steps[j]):
                print('step',step)
                opt.zero_grad()

                warped_images = model.reg()

                loss = model.reg.loss_function(warped_images)
                loss.backward()

                opt.step()

        # model.reg.num_reg_levels = 1
        reg_flows = model.reg.predict_flows()
        # save the reg_flows
        # reg_flows maps each subject to the common space

if __name__ == '__main__':
    # configurations and dataloading
    images = ... # (B=1, num_subjects=3, 160,160,192)
    labels = ... # (B=1, num_subjects=3, num_classes=20,160,160,192)

    net = MyRegModel(
        dimension=3,
        img_size=(160,160,192),
        num_classes=4,
        model_type='XCoRegGT', # 'XCoRegGT', 'GMM', 'APE', 'CTE'
        num_bins=64,
        sample_rate=0.1,
        kernel_sigma=1,
        mask_sigma=-1,
        prior_sigma=-1,
        alpha=1,
        transform_type=['SVF'], # 'DDF', 'FFD'
        # pred_ffd_spacing=,
        # pred_ffd_iso=,
        group2ref=False,
        # zero_avg_flow=True,
        zero_avg_vec=True,
        norm_img=False,
        eps=1e-8,
    )

    trainer = IterGroupRegTrainer(net,
                                  verbose=0,
                                  save_path=save_path,
                                  optimizer_name='Adam',
                                  learning_rate=0.1,
                                  weight_decay=0.0,
                                  scheduler_name='CyclicLR',  # None, 'OneCyclicLR'
                                  max_lr=1e-4,
                                  base_lr=1e-5,
                                  logger=logger,)

    trainer.train(
        images=images,
        label=labels,  # the `labels` tensor defined above
        device=device,
        steps=[10,],
        num_workers=8,
    )

Here is my running script. Do you have any suggestions for the parameter settings, or which parameters should I pay more attention to when I want more accurate results? I am registering three T1 brain MRIs (taken at different times) of the same patient.

Also, there is one tiny bug I found when running the code: since torch.unbind removes the dimension we specify, in GroupRegModel.forward() the input image to the SpatialTransformer lacks the channel dimension that grid_sample needs to run properly. For example, after torch.unbind it gives a tuple of tensors of shape (B, H, W, D), but SpatialTransformer requires images of shape (B, 1, H, W, D). :)
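
A minimal snippet illustrating the shape behavior in question (sizes are illustrative):

import torch

x = torch.rand(1, 3, 160, 160, 192)     # (B, num_subjects, H, W, D), no channel axis
parts = torch.unbind(x, dim=1)          # tuple of 3 tensors, each (B, H, W, D)
print(parts[0].shape)                   # torch.Size([1, 160, 160, 192])

# With an explicit channel axis, i.e. (B, num_subjects, 1, H, W, D),
# unbind over dim=1 yields (B, 1, H, W, D), which grid_sample accepts.
y = torch.rand(1, 3, 1, 160, 160, 192)
print(torch.unbind(y, dim=1)[0].shape)  # torch.Size([1, 1, 160, 160, 192])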

Thanks a lot for your help!

Cheers, Bailiang

xzluo97 commented 1 year ago

Hi Bailiang,

Thank you for the questions. For the first 4 questions you posted, I will answer them point-by-point as follows:

  1. Yes, that's true. num_res only downsamples the warped images, not the transformation fields.
  2. Yes, the FFDs are full-size and are composed together.
  3. The mask parameter confines the ROI from which samples are drawn for density estimation. This is typically used for images with a complex background. For brain images you may not need to set it, because the background is much simpler.
  4. The images are of shape (B, num_subjects, 1, H, W, D), and the labels of shape (B, num_subjects, num_classes, H, W, D); see the snippet after this list.
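
A minimal sketch of dummy inputs with these shapes (sizes and class count are illustrative):

import torch

B, S, C, H, W, D = 1, 3, 4, 160, 160, 192   # batch, subjects, classes, spatial dims
images = torch.rand(B, S, 1, H, W, D)       # (B, num_subjects, 1, H, W, D)

# One-hot labels of shape (B, num_subjects, num_classes, H, W, D):
seg = torch.randint(0, C, (B, S, 1, H, W, D))
labels = torch.zeros(B, S, C, H, W, D).scatter_(2, seg, 1.0)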

Cheers, Xinzhe

xzluo97 commented 1 year ago

Regarding the attached code for your own data, I am glad to see it runs well. However, you may not need to set the parameter scheduler_name, as it may alter the step size unexpectedly. And since you run the registration with SVF transformations, you may need to tune the parameter alpha for the velocity-field regularization.

Cheers, Xinzhe

BailiangJ commented 1 year ago

@xzluo97 Thank you so much for all the information,

the images should be of shape (B, num_subjects, 1, H, W, D)

I see, that's the reason why I had problems with the SpatialTransformer.

Yes, I have found that BendingEnergy on the velocity field gives staircase artifacts in the boundary regions of the resampled images, so I also added a MembraneEnergy term to GroupRegModel._get_regularization:

def _get_regularization(self):
    r = 0
    for j in self.activated_reg_levels:
        # Regularize only the non-linear levels; TRA/AFF/RIG need no smoothness penalty.
        if self.reg_level_type[j] not in ['TRA', 'AFF', 'RIG']:
            r += self.bending_energy(self.params[self.reg_level_type[j]]) * self.group_num
            r += self.membrane_energy(self.params[self.reg_level_type[j]]) * self.group_num

    return r
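
For reference, a minimal first-order (membrane) penalty on a dense field of shape (B, C, H, W, D), assuming unit spacing; the repo's own membrane_energy may differ in normalization:

import torch

def membrane_energy(flow):
    # First spatial derivatives along H, W, D (central differences).
    grads = torch.gradient(flow, dim=(2, 3, 4))
    # Mean squared first derivatives, summed over the three directions.
    return sum((g ** 2).mean() for g in grads)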

For the LR scheduler, would it help if I used a normal exponential decay, or is it fine to use a constant LR? (Or I can just try it out myself and see the results.)
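
For reference, the exponential decay I have in mind is the standard PyTorch one; a sketch with illustrative stand-ins for model.reg.params and model.reg.loss_function:

import torch

params = [torch.zeros(10, requires_grad=True)]  # stand-in for the transform parameters
opt = torch.optim.Adam(params, lr=0.1)
scheduler = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.99)

for step in range(100):
    opt.zero_grad()
    loss = (params[0] ** 2).sum()  # stand-in objective
    loss.backward()
    opt.step()
    scheduler.step()  # multiply the LR by gamma after every step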

Thanks again for your help. I think our conversation here will also be beneficial to other researchers who want to run these algorithms! :)

Cheers, Bailiang

xzluo97 commented 1 year ago

Hi Bailiang,

I have only used a constant LR in my experiments and didn't try an LR scheduler, so I'm afraid I cannot give you more advice on this. But I think using an LR scheduler may help avoid local optima, depending on the landscape of the similarity measure w.r.t. the misalignment.

Glad to see our conversation helps you. Good luck with your experiments!

Best, Xinzhe