cwmok / DIRAC

This is the official PyTorch implementation of "Unsupervised Deformable Image Registration with Absent Correspondences in Pre-operative and Post-Recurrence Brain Tumor MRI Scans" (MICCAI 2022), written by Tony C. W. Mok and Albert C. S. Chung.
MIT License

Mask loss ("OCC") cannot be backpropagated #13

Closed: junyuchen245 closed this issue 1 year ago

junyuchen245 commented 1 year ago

Hi @cwmok ,

Great and interesting work! However, I believe there may be an error: the mask loss (i.e., the OCC loss) cannot be backpropagated, because of the comparison made here: https://github.com/cwmok/DIRAC/blob/f860196479a59be73e5400c90beb9ddf31ba4493/Code/BRATS_train_DIRAC.py#L251-L252

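The comparison operator (">") is neither continuous nor differentiable. In PyTorch it returns a boolean tensor that is not part of the autograd graph, so no gradient can flow through it: https://stackoverflow.com/questions/75970265/gradient-cannot-be-back-propagated-due-to-comparison-operator-in-pytorch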

You can run print(loss_multiNCC.requires_grad, loss_occ.requires_grad) to check whether each loss supports backpropagation; it prints True False, i.e. loss_occ carries no gradient.
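A minimal sketch of that check, assuming toy tensors in place of the real training script (`diff`, `thresh`, `loss_sim` and `loss_mask` are hypothetical names, not the variables at the linked lines). A mask built with `>` is a boolean tensor, so any loss computed from it reports `requires_grad=False`:

```python
import torch

# Hypothetical stand-ins: a predicted difference map and a threshold map.
diff = torch.rand(1, 1, 8, 8, 8, requires_grad=True)
thresh = torch.full_like(diff, 0.5)  # new tensor, does not require grad

# A differentiable term, analogous to the similarity loss.
loss_sim = (diff ** 2).mean()

# ">" returns a bool tensor that is detached from autograd,
# so the mask-based loss has no gradient path back to `diff`.
mask = diff > thresh
loss_mask = mask.float().mean()

print(loss_sim.requires_grad, loss_mask.requires_grad)  # True False
```

Adding `loss_mask` to the total loss therefore changes the reported value but not the gradients that reach the network.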

For this reason, the mask loss contributes no gradient: it only shifts the value of the overall loss, acting like an arbitrary offset rather than a training signal, which makes it meaningless.

Thanks, Junyu

cwmok commented 1 year ago

Hi @junyuchen245,

Thanks for your interest in our work. You are correct: the loss is redundant here and did not work as I expected. As an alternative, I have changed the OCC loss as follows:

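        import torch
        import torch.nn.functional as F
        # Hinge penalty: zero below the threshold, linear above it, so gradients flow
        occ_xy_l = F.relu(smo_norm_diff_fw - thresh_fw) * 1000.
        occ_yx_l = F.relu(smo_norm_diff_bw - thresh_bw) * 1000.
        loss_occ = torch.mean(occ_xy_l) + torch.mean(occ_yx_l)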
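To illustrate why the hinge form avoids the problem, here is a small sketch on hypothetical 1-D inputs. The values stand in for `smo_norm_diff_fw` and `thresh_fw`, which in the real script are full difference maps; only the forward term is shown. Unlike the comparison, `F.relu` passes a gradient to every element above the threshold:

```python
import torch
import torch.nn.functional as F

# Hypothetical inputs standing in for smo_norm_diff_fw and thresh_fw.
smo_norm_diff_fw = torch.tensor([0.2, 0.8, 1.5], requires_grad=True)
thresh_fw = torch.tensor([0.5, 0.5, 0.5])

# Hinge penalty with the same 1000x weight as the snippet above.
occ_xy_l = F.relu(smo_norm_diff_fw - thresh_fw) * 1000.
loss_occ = torch.mean(occ_xy_l)
loss_occ.backward()

print(loss_occ.requires_grad)  # True
# Gradient is 0 for the element below the threshold and 1000/3
# for the two elements above it.
print(smo_norm_diff_fw.grad)
```

Elements already below the threshold receive no gradient, so the penalty only acts where the difference is too large.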

I will explore the effect of this new loss and report back here if it makes a significant difference to the results. Thank you for pointing out the mistake; I really appreciate your help.

Best regards, Tony

junyuchen245 commented 1 year ago

Thanks for the confirmation, Tony. Another approach you might consider is a 'soft' comparison such as the sigmoid: torch.sigmoid(a - b). You can also control how hard the function is with a customized sigmoid, 1/(1 + alpha^(-x)) with x = a - b, which is equivalent to torch.sigmoid(x * log(alpha)). When alpha = e, this is the standard sigmoid; as alpha increases, the function becomes sharper and approaches a hard step.
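The parameterized sigmoid can be sketched as follows. `soft_greater` is a hypothetical helper, not part of PyTorch or the DIRAC code; it relies on the identity alpha^(-x) = exp(-x ln alpha) and assumes alpha > 1 (alpha = 1 gives a constant 0.5, and alpha < 1 flips the comparison):

```python
import math
import torch

def soft_greater(a, b, alpha=math.e):
    """Hypothetical differentiable surrogate for (a > b).

    Computes 1 / (1 + alpha**-(a - b)) == sigmoid(ln(alpha) * (a - b)).
    alpha = e gives the standard sigmoid; larger alpha gives a sharper step.
    """
    return torch.sigmoid(math.log(alpha) * (a - b))

diff = torch.tensor([0.2, 0.5, 0.8], requires_grad=True)
thresh = torch.tensor(0.5)

# Compare a soft and a sharper mask on the same inputs.
for alpha in (math.e, 1e4):
    print(alpha, soft_greater(diff, thresh, alpha).detach())

# Unlike the hard ">" mask, this loss is connected to `diff` in autograd.
loss_soft = soft_greater(diff, thresh, alpha=1e4).mean()
loss_soft.backward()
print(diff.grad)  # nonzero for every element
```

A very large alpha approaches the hard comparison again and concentrates the gradient near the threshold, so alpha trades off fidelity to the step against the amount of gradient signal.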

Junyu

cwmok commented 1 year ago

That's very insightful. I will try it out. Thanks a lot.

Tony