junyuchen245 / TransMorph_Transformer_for_Medical_Image_Registration

TransMorph: Transformer for Unsupervised Medical Image Registration (PyTorch)
MIT License

Deformation field close to zero. #25

Closed: kvttt closed this issue 2 years ago

kvttt commented 2 years ago

Hi Junyu,

Sorry to bother you again. Last time you suggested that I look at the diffeomorphic variants, so lately I have been looking specifically at VoxelMorph-diff among the baseline models you included. However, after training converges, the resulting deformation fields are very close to zero: the value at each voxel falls roughly in the range 0.001 to 0.02. When I visualize the deformation field in ITK-SNAP (see the attached screenshot), the grid lines look straight, which is drastically different from the deformation fields produced by the non-diffeomorphic variants. I am wondering whether this is normal.

[Screenshot, 2022-06-07: VoxelMorph-diff deformation field visualized in ITK-SNAP]

At first, I suspected that I was saving the stationary velocity field (SVF) instead of the deformation field, but that is not the case. I also suspected that something was wrong with the scaling-and-squaring integration step. However, the deformation field is still close to zero even after I replaced the for-loop with the VecInt function here: https://github.com/voxelmorph/voxelmorph/blob/a746f77098962da1be9e6a03dacc3ef6c90d5244/voxelmorph/torch/layers.py#L51-L68.
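For reference, the scaling-and-squaring integration mentioned above can be sketched in NumPy as follows. This is a minimal 2D illustration, not the VoxelMorph or TransMorph code: it assumes the velocity field is stored in voxel units with shape (2, H, W) and that samples falling outside the grid are clamped to the border. It follows the same structure as VoxelMorph's VecInt (divide the field by 2**nsteps, then repeatedly add the field warped by itself); the helper names are hypothetical.

```python
import numpy as np


def warp_bilinear(field, disp):
    """Resample `field` (C, H, W) at x + disp(x) using bilinear interpolation.

    Assumption: `disp` has shape (2, H, W) in voxel units (disp[0] along rows,
    disp[1] along columns); out-of-range sample positions are clamped to the border.
    """
    _, h, w = field.shape
    rows, cols = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    r = np.clip(rows + disp[0], 0, h - 1)
    c = np.clip(cols + disp[1], 0, w - 1)
    r0 = np.floor(r).astype(int)
    c0 = np.floor(c).astype(int)
    r1 = np.minimum(r0 + 1, h - 1)
    c1 = np.minimum(c0 + 1, w - 1)
    fr = r - r0  # fractional offsets used as interpolation weights
    fc = c - c0
    return ((1 - fr) * (1 - fc) * field[:, r0, c0]
            + (1 - fr) * fc * field[:, r0, c1]
            + fr * (1 - fc) * field[:, r1, c0]
            + fr * fc * field[:, r1, c1])


def scaling_and_squaring(velocity, nsteps=7):
    """Integrate a stationary velocity field (2, H, W) into a displacement field."""
    disp = velocity / (2 ** nsteps)          # scaling: start from a very small step
    for _ in range(nsteps):                  # squaring: compose the map with itself
        disp = disp + warp_bilinear(disp, disp)
    return disp
```

For a spatially constant velocity the result is the same constant translation, and for a small, smooth velocity the displacement stays close to the velocity itself. Note that the sampler interprets displacements as voxel offsets; a velocity field stored in normalized coordinates would need a sampler that uses the same convention.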

After all these attempts to understand what causes this behavior, my working assumption is that diffeomorphic variants such as VoxelMorph-diff tend to produce near-zero deformation fields. I would really appreciate any insight on this, or a confirmation of my observation.

Again, I enjoyed reading your paper and your code!

junyuchen245 commented 2 years ago

Hi @kvttt ,

Apologies for the delayed response. I borrowed the VoxelMorph-diff implementation from https://github.com/uncbiag/easyreg. The deformation field output by the network is actually normalized: https://github.com/junyuchen245/TransMorph_Transformer_for_Medical_Image_Registration/blob/47517d0cc5b7b8b56c77b6b3efa60d59abba3432/IXI/Baseline_registration_methods/VoxelMorph-diff/models.py#L493-L494. To visualize the deformation before normalization, you can output the field from this line: https://github.com/junyuchen245/TransMorph_Transformer_for_Medical_Image_Registration/blob/47517d0cc5b7b8b56c77b6b3efa60d59abba3432/IXI/Baseline_registration_methods/VoxelMorph-diff/models.py#L493
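To judge whether values of 0.001 to 0.02 are plausible as normalized displacements, they can be converted to voxel units. The exact normalization at lines 493-494 is not reproduced here. The sketch below assumes each axis of size N is mapped linearly onto a normalized range of width `span`, with the end points at voxel centers 0 and N - 1 (span=2.0 for the [-1, 1] convention of PyTorch's grid_sample with align_corners=True, span=1.0 for a [0, 1] range). The function name `normalized_to_voxel` is hypothetical.

```python
import numpy as np


def normalized_to_voxel(disp, span=2.0):
    """Convert a displacement field from normalized to voxel units.

    disp: array of shape (ndim, *spatial); channel k holds the displacement along axis k.
    span: width of the assumed normalized range, 2.0 for [-1, 1] or 1.0 for [0, 1].
    """
    spatial = disp.shape[1:]
    # Assumption: the range end points map to voxel centers 0 and N - 1,
    # so one normalized unit spans (N - 1) / span voxels along an axis of size N.
    factors = np.array([(n - 1) / span for n in spatial], dtype=float)
    # Scale each displacement channel by the factor of its own axis.
    return disp * factors.reshape((-1,) + (1,) * len(spatial))
```

For example, with span=2.0 a value of 0.02 along an axis of 160 voxels corresponds to 0.02 × 159 / 2 ≈ 1.6 voxels, a modest but non-negligible displacement rather than an essentially zero one.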

Junyu

kvttt commented 2 years ago

Thanks for the clear explanation! In addition, should I also use the unnormalized deformation when calculating losses (e.g., Grad3d loss)?

junyuchen245 commented 2 years ago

Hi @kvttt ,

You don't need to. All the loss functions were calculated on the normalized field (the same training framework as in https://github.com/uncbiag/easyreg). Using the unnormalized field to compute the losses is feasible, but you would have to retune the hyperparameters, so I don't recommend it.
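One way to see why switching to the unnormalized field affects the hyperparameters: the warped image, and therefore the similarity term, is the same whichever units the field is stored in, but a finite-difference smoothness penalty scales with the field. Assuming the unnormalized field equals the normalized one multiplied by a single constant c, an L1 gradient penalty grows by a factor of c and an L2 penalty by c², so the regularization weight would have to shrink by the same factor to keep the balance. The functions below are hypothetical NumPy stand-ins written in the spirit of a Grad3d-style loss, not the repository's implementation.

```python
import numpy as np


def grad_loss(disp, penalty="l1"):
    """Mean finite-difference smoothness penalty of a field with shape (C, *spatial).

    Hypothetical re-implementation for illustration: forward differences along each
    spatial axis, absolute (l1) or squared (l2), averaged over axes.
    """
    terms = []
    for axis in range(1, disp.ndim):
        d = np.abs(np.diff(disp, axis=axis))
        if penalty == "l2":
            d = d * d
        terms.append(d.mean())
    return sum(terms) / len(terms)


def rescaled_weight(weight, scale, penalty="l1"):
    """Weight that keeps weight * grad_loss unchanged when the field is multiplied by `scale`."""
    power = 2 if penalty == "l2" else 1
    return weight / scale ** power
```

With a random field u and c = 80, grad_loss(c * u, "l1") equals c * grad_loss(u, "l1"), and the L2 version scales by c², which is why a weight tuned on the normalized field would not carry over unchanged.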

Junyu

kvttt commented 2 years ago

Hi @junyuchen245,

Thank you for the quick reply. I just tried visualizing the unnormalized deformation and it works! As for computing the loss on the unnormalized deformation, I think retuning the hyperparameters would definitely work. The trick I am currently using is to compensate for the difference between the normalized and unnormalized deformations with the loss_mult argument here:

https://github.com/junyuchen245/TransMorph_Transformer_for_Medical_Image_Registration/blob/dfa24a47a564a000aa9b4eea95a6e83a24568359/TransMorph/losses.py#L546-L548

I think both approaches should yield essentially the same result. Again, thank you so much for the explanation!
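Whether a single loss_mult reproduces the normalized-field loss exactly depends on how the field was normalized. If all axes share one scale factor c, an L1 gradient penalty on the unnormalized field is exactly c times the penalty on the normalized field, so loss_mult = 1/c compensates. If each axis is normalized by its own size (for example factors (N - 1) / 2 of 79.5, 95.5 and 111.5 for a 160 × 192 × 224 volume), the ratio becomes a field-dependent weighted average of the per-axis factors, and one scalar only approximates it. The sketch below assumes the unnormalized field equals the normalized one with each displacement channel multiplied by its own factor; the helper names are hypothetical.

```python
import numpy as np


def grad_l1(disp):
    """Mean absolute forward difference of a (C, *spatial) field, averaged over spatial axes."""
    return np.mean([np.abs(np.diff(disp, axis=a)).mean() for a in range(1, disp.ndim)])


def loss_mult_ratio(disp_norm, factors):
    """Ratio grad_l1(unnormalized) / grad_l1(normalized) for per-channel scale factors.

    If all factors equal c, the ratio is exactly c, so loss_mult = 1 / c restores the
    normalized-field loss. Otherwise the ratio depends on the field itself.
    """
    scale = np.asarray(factors, dtype=float).reshape((-1,) + (1,) * (disp_norm.ndim - 1))
    return grad_l1(disp_norm * scale) / grad_l1(disp_norm)
```

For a field whose only non-zero channel is the first one, the ratio with factors (79.5, 95.5, 111.5) is exactly 79.5, while a field varying in all channels gives a value somewhere between 79.5 and 111.5. Setting loss_mult to the reciprocal of a typical factor therefore keeps the penalty at the right order of magnitude, but the per-axis weighting of the regularizer is then not strictly identical to training on the normalized field.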