cwmok / C2FViT

This is the official PyTorch implementation of "Affine Medical Image Registration with Coarse-to-Fine Vision Transformer" (CVPR 2022), written by Tony C. W. Mok and Albert C. S. Chung.

About the AffineCOMTransform #16

Closed xiaorugao999 closed 3 months ago

xiaorugao999 commented 5 months ago

I want to express my sincere appreciation for the excellent work you've done so far. Before we move forward, I have several inquiries I'd like to make:

  1. Regarding the constraints on the translation parameter range, I could not find any reference to them in the code. Could you clarify how the (-50%, 50%) range mentioned in the paper is implemented?

  2. About the loss function, it appears that only the final level is used in the loss. Is there a need to also involve the initial two levels?

  3. My final point touches upon the position embedding. I see that multiple methods have been implemented in the model file. Can you provide guidance on which method would be the most appropriate to choose?

Thank you in advance for your responses and continuous assistance.

cwmok commented 5 months ago

Hi @xiaorugao999,

Regarding the constraints on the translation parameter range, I could not find any reference to them in the code. Could you clarify how the (-50%, 50%) range mentioned in the paper is implemented?

It is implemented in the model. As shown in the code, we use a Tanh() activation to restrict the output affine values to [-1, 1], which is equivalent to (-50%, 50%) for translation (the sampling grid is normalized to [-1, 1] in PyTorch).
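To illustrate the idea, here is a minimal sketch (the variable names and shapes are placeholders, not the repository's exact affine head): Tanh bounds the raw translation prediction to [-1, 1], and because F.affine_grid / F.grid_sample work in coordinates normalized to [-1, 1], a translation of ±1 shifts the image by half of its extent, i.e. ±50%.

```python
import torch
import torch.nn.functional as F

# Illustrative sketch (not the exact repository code): Tanh bounds the raw
# translation prediction to [-1, 1]. Because F.affine_grid / F.grid_sample use
# coordinates normalized to [-1, 1], a translation of 1.0 shifts the image by
# half of its extent, i.e. the (-50%, 50%) range described in the paper.
raw_translation = torch.randn(1, 3)          # unbounded network output
translation = torch.tanh(raw_translation)    # now within [-1, 1]

theta = torch.zeros(1, 3, 4)
theta[:, :, :3] = torch.eye(3)               # identity linear part
theta[:, :, 3] = translation                 # bounded translation

moving = torch.rand(1, 1, 32, 32, 32)        # dummy moving image (N, C, D, H, W)
grid = F.affine_grid(theta, moving.shape, align_corners=True)
warped = F.grid_sample(moving, grid, align_corners=True)
```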

About the loss function, it appears that only the final level is used in the loss. Is there a need to also involve the initial two levels?

Good point! I did try combining the outputs of the initial two levels, but the results were not as good as using only the final output.

My final point touches upon the position embedding. I see that multiple methods have been implemented in the model file. Can you provide guidance on which method would be the most appropriate to choose?

After submitting this work to CVPR, I tried different position embeddings and found that the position embedding doesn't contribute much to the model. C2F_ViT_stage_peg is the most effective one, but the difference is not statistically significant. Surprisingly, the model also works without any position embedding.

xiaorugao999 commented 5 months ago

Dear @cwmok,

First and foremost, I would like to express my heartfelt appreciation for your swift and effective assistance in resolving my previous inquiries. Your insightful responses have been tremendously beneficial to my understanding and progress.

Moving forward, I still have a couple of minor questions I was hoping you could help clarify:

  1. When training on my dataset, should I maintain a consistent origin and spacing for each pair of fixed and moving images?

  2. Regarding the combination of transformation matrices, I'm somewhat puzzled by the code sequence: output_rigid_m = torch.mm(to_center_matrix, torch.mm(self.rotation_m, torch.mm(reversed_to_center_matrix, self.translation_m))). This appears as though translation is performed before rotation. Why is it not implemented as output_rigid_m = torch.mm(reversed_to_center_matrix, torch.mm(self.translation_m, torch.mm(to_center_matrix, self.rotation_m)))?

I hope these queries aren't too bothersome. Your expertise and guidance are deeply appreciated and highly valued.

I eagerly look forward to your reply.

Once again, thank you for your support and assistance.

Warm Regards

cwmok commented 5 months ago

Hi @xiaorugao999,

When training on my dataset, should I maintain a consistent origin and spacing for each pair of fixed and moving images?

It is not necessary but a good practice to do so.

Regarding the combination of transformation matrices, I'm somewhat puzzled by the code sequence: output_rigid_m = torch.mm(to_center_matrix, torch.mm(self.rotation_m, torch.mm(reversed_to_center_matrix, self.translation_m))). This appears as though translation is performed before rotation. Why is it not implemented as output_rigid_m = torch.mm(reversed_to_center_matrix, torch.mm(self.translation_m, torch.mm(to_center_matrix, self.rotation_m)))?

To my understanding, the execution order of the transformations goes from left to right, not from right to left. That is, in output_rigid_m = torch.mm(to_center_matrix, torch.mm(self.rotation_m, torch.mm(reversed_to_center_matrix, self.translation_m))), the execution order is to_center_matrix -> rotation_m -> reversed_to_center_matrix -> translation_m. You can test it out with toy code.
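As a rough illustration of such a toy test (the 2D matrices and values below are made up, not the repository's 4x4 ones), one can compose the matrices in the two orders and apply each composite to the same homogeneous test point to see how the results differ:

```python
import math
import torch

# Toy comparison (illustrative 2D example, not the repository's 4x4 matrices):
# compose centre/rotation/translation matrices in two different orders and
# apply each composite to the same homogeneous point to see how they differ.
def trans(tx, ty):
    return torch.tensor([[1.0, 0.0, tx], [0.0, 1.0, ty], [0.0, 0.0, 1.0]])

def rot(theta):
    c, s = math.cos(theta), math.sin(theta)
    return torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

to_center = trans(0.3, 0.2)            # hypothetical centre of mass
back = trans(-0.3, -0.2)
R = rot(0.5)
T = trans(0.1, 0.0)

order_a = torch.mm(to_center, torch.mm(R, torch.mm(back, T)))   # ordering as in the code
order_b = torch.mm(back, torch.mm(T, torch.mm(to_center, R)))   # alternative ordering

p = torch.tensor([0.4, -0.1, 1.0])     # homogeneous test point
print(torch.mv(order_a, p))
print(torch.mv(order_b, p))
```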

xiaorugao999 commented 5 months ago

Dear @cwmok,

I would like to extend my heartfelt thanks for your prompt and detailed answers to my previous questions. Your responses have been instrumental in resolving my long-standing issues and have significantly improved my understanding of this project.

I still have one relatively minor query I wish to explore further. During the process of translating to the rotation center, I noticed that we set 'to_center_matrix[0, 3] = center_mass_x'. My question concerns why we don't set 'to_center_matrix[0, 3] = - center_mass_x'.

I eagerly await your response and would like to reiterate my gratitude for your previous assistance. Best regards

cwmok commented 5 months ago

Hi @xiaorugao999,

Sorry for the late reply. I had the same thought as you when I was developing the code. I managed to figure it out with a toy experiment, i.e., trying both to_center_matrix[0, 3] = center_mass_x and to_center_matrix[0, 3] = -center_mass_x to see which one shows the expected behaviour in the initial alignment.

I found that 'to_center_matrix[0, 3] = center_mass_x' works, but 'to_center_matrix[0, 3] = -center_mass_x' does not. You can always play around with it by setting up new toy experiments.
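For reference, a sketch of what such a toy experiment might look like (illustrative 2D matrices and values, not the repository's exact code): with a pure rotation about the centre of mass, that point should stay fixed, so one can apply both sign conventions to the centre-of-mass point and see which composite keeps it in place.

```python
import math
import torch

# Illustrative toy experiment (2D, not the repository's exact matrices): a pure
# rotation about the centre of mass should leave that point unchanged. Compare
# the two sign conventions for to_center_matrix.
def trans(tx, ty):
    return torch.tensor([[1.0, 0.0, tx], [0.0, 1.0, ty], [0.0, 0.0, 1.0]])

def rot(theta):
    c, s = math.cos(theta), math.sin(theta)
    return torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

cx, cy = 0.3, -0.2                     # hypothetical centre of mass
R = rot(0.7)
com = torch.tensor([cx, cy, 1.0])      # centre of mass in homogeneous coordinates

for sign in (1.0, -1.0):
    to_center = trans(sign * cx, sign * cy)    # to_center_matrix[0, 3] = +/- center_mass_x
    back = trans(-sign * cx, -sign * cy)       # reversed matrix with the opposite sign
    composed = torch.mm(to_center, torch.mm(R, back))
    print(sign, torch.mv(composed, com))       # sign = +1 maps the centre of mass to itself
```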

xiaorugao999 commented 4 months ago

Hi @cwmok,

Thank you for your response. I have learned a lot from your excellent work. I have a small question: when training with my own data, I sometimes observe that the transformation parameters predicted on the training set vary during training; however, on the validation and test sets, every case seems to produce identical transformation parameters. What could be causing this issue? Thank you very much.

cwmok commented 4 months ago

@xiaorugao999

I am not sure about it due to the very limited information provided. I have never encountered a similar issue before. My wild guess is that the preprocessing/dataloader for the validation set may be the issue. Double-check the validation data (save the validation data as .nii.gz from the PyTorch tensor right before it goes into the network).
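For what it's worth, a minimal sketch of that kind of sanity check, assuming a (N, C, D, H, W) tensor and using nibabel (function and file names are placeholders):

```python
import nibabel as nib
import numpy as np

def dump_tensor_as_nifti(tensor, path):
    """Save a torch tensor (e.g. (N, C, D, H, W)) as .nii.gz for visual inspection."""
    volume = tensor.detach().cpu().numpy().squeeze()   # drop batch/channel dims
    nib.save(nib.Nifti1Image(volume.astype(np.float32), affine=np.eye(4)), path)

# e.g. inside the validation loop, right before the forward pass:
# dump_tensor_as_nifti(fixed_img, "debug_fixed.nii.gz")
# dump_tensor_as_nifti(moving_img, "debug_moving.nii.gz")
```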

Hope the suggestion helps.