junyuchen245 / TransMorph_Transformer_for_Medical_Image_Registration

TransMorph: Transformer for Unsupervised Medical Image Registration (PyTorch)
MIT License

TransMorph-Affine produces the same transformation matrix for all images during the validation #63

Closed by brumomento 1 year ago

brumomento commented 1 year ago

During validation, the network outputs the same transformation matrix for every image pair (and all the images are transformed in exactly the same way), no matter what the images are. This happened when I tried adapting the model to work with 2D images. I also tried running the original TransMorph-Affine with the provided sample data, and the outcome is the same: every image results in the same transformation matrix.

Is this supposed to happen? I checked the data handling, and the images that go into the network are indeed different on every iteration. I also printed the scalar sum of the network's output tensor after each forward pass, and the values are identical across consecutive iterations.

What is interesting is that this does not happen during the training loop: the transformation matrices are similar but not identical, and the tensor values differ between images.
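
A minimal sketch of this comparison (assuming the model returns the warped image, the affine matrix, and its inverse matrices, as in the snippet further below; `model`, `x_in_a`, and `x_in_b` are placeholders, not code from the repository):

```python
import torch

# Feed two different moving/fixed pairs through the affine model and compare
# the outputs; if the bug is present, both lines below report identical values.
model.eval()
with torch.no_grad():
    out_a, mat_a, _ = model(x_in_a)  # first image pair
    out_b, mat_b, _ = model(x_in_b)  # a different image pair
    print('output sums:', out_a.sum().item(), out_b.sum().item())
    print('matrices identical:', torch.allclose(mat_a, mat_b))
```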

junyuchen245 commented 1 year ago

Hi @brumomento

Which version of PyTorch are you currently using? A while back, I encountered a convergence problem with versions above 1.9.1 on Windows machines (Linux seems to be fine), but I haven't had the chance to thoroughly investigate the cause. You might consider downgrading to an earlier version to see if that resolves the issue.

brumomento commented 1 year ago

Hello @junyuchen245

Thank you for answering. I am using 2.0.1+cu118 on Windows, so this might be the issue. I will test it with an older version and keep you in the loop.

junyuchen245 commented 1 year ago

Sounds good. You might also want to tune the learning rate. You should be able to observe changes in the affine matrix across iterations.
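A rough sketch of what I mean (the optimizer, learning-rate values, loader, and loss below are illustrative placeholders, not the exact training script):

```python
import torch

# Lower the learning rate and log the affine matrix every few iterations to
# confirm that it actually changes during training.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)  # e.g. try 1e-4 down to 1e-5

for it, (moving, fixed) in enumerate(train_loader):
    x_in = torch.cat((moving, fixed), dim=1).cuda()
    warped, mat, inv_mats = model(x_in)
    loss = sim_loss(warped, fixed.cuda())  # placeholder similarity loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if it % 50 == 0:
        # The printed matrix should visibly change across iterations.
        print(f'iter {it}, loss {loss.item():.4f}')
        print(mat.detach().squeeze().cpu().numpy())
```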

brumomento commented 1 year ago

Hello @junyuchen245

I have tested it both on Linux and with PyTorch 1.9.0, and the issue remains. Could you print out the transformation matrix in the validation loop of your code with a proper dataset to see what happens?

I am doing it like this:

```python
ct_aff, mat, inv_mats = model(x_in)
print('Mat: ', mat)
```
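
A minimal sketch of this check over a full validation set (assuming the model takes a concatenated moving/fixed pair; `val_loader` and the preprocessing are placeholders, not the repository's exact code):

```python
import torch

# Print the predicted affine matrix for every validation case.
model.eval()
with torch.no_grad():
    for i, (moving, fixed) in enumerate(val_loader):
        x_in = torch.cat((moving, fixed), dim=1).cuda()
        ct_aff, mat, inv_mats = model(x_in)
        print(f'case {i} Mat:\n{mat.squeeze().cpu().numpy()}')
```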

junyuchen245 commented 1 year ago

Hi @brumomento

This is strange. I'll look into it once I get a chance. Thanks.

junyuchen245 commented 1 year ago

Hi @brumomento

The issue has been resolved; it was primarily due to the initialization and normalization of the features. I realize now that initializing the weights with all zeros was not optimal, as this can often cause convergence problems, leading to the model getting trapped in local minima and failing to update the weights. I have updated the models and scripts accordingly. Additionally, I've prepared a toy example using a subset of the IXI dataset (53 randomly selected images, 50 for training and 3 for validation). The updated scripts for effectively training the affine model can be found here. Thank you for pointing this out again. Please feel free to reach out if you have any additional questions.
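
In short, the idea is to avoid an all-zero initialization of the layer that regresses the affine parameters. A rough sketch of that idea (the `AffineHead` class and its 3x4 parameterization are illustrative only, not the exact TransMorph-Affine code; please use the updated scripts for the real implementation):

```python
import torch
import torch.nn as nn

class AffineHead(nn.Module):
    """Illustrative regression head for affine parameters."""
    def __init__(self, in_features: int, n_params: int = 12):
        super().__init__()
        self.fc = nn.Linear(in_features, n_params)
        # Small random weights break the symmetry that an all-zero init creates...
        nn.init.trunc_normal_(self.fc.weight, std=1e-4)
        # ...while the bias encodes the identity 3x4 affine matrix [I | 0],
        # so the model starts at the identity transform but can still update.
        with torch.no_grad():
            self.fc.bias.copy_(torch.eye(3, 4).flatten())

    def forward(self, feat: torch.Tensor) -> torch.Tensor:
        # Returns a (B, 3, 4) affine matrix per image pair.
        return self.fc(feat).view(-1, 3, 4)
```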

Btw, I've included a log showing the model's training across several epochs. You'll notice that the model now produces a different matrix for each validation case. logfile.log

Junyu

brumomento commented 1 year ago

Hello @junyuchen245

I was trying to debug the code as well and found that the weight initialisation was indeed the issue, but I was far from figuring out how to fix it. Thank you for updating the scripts!