cwmok / C2FViT

This is the official PyTorch implementation of "Affine Medical Image Registration with Coarse-to-Fine Vision Transformer" (CVPR 2022), written by Tony C. W. Mok and Albert C. S. Chung.
MIT License

Dice mean of the validation set decreases, and so does the Brain Dice mean, as the training step increases #9

Closed pypi20200320 closed 1 year ago

pypi20200320 commented 1 year ago

Thanks for your excellent work, but I have encountered some problems. First, I trained a C2FViT model on my own dataset using the script Train_C2FViT_pairwise.py. The NCC loss showed no downward trend, and the Dice mean of the validation set decreased. I suspected that NCC is not suitable for my dataset (which is CBCT2CT), so I replaced the NCC loss with an MSE loss. The loss then decreased, and the Dice mean of the validation set increased from 0.74 to 0.86.

It's so confusing.

Then I trained a C2FViT model in an unsupervised manner for pairwise registration on the OASIS dataset using the original script, Train_C2FViT_pairwise.py. But the Dice mean of the validation set gradually decreases, and so does the Brain Dice mean, as the training step increases. I can't figure out what the problem is, because I have not changed your code at all.

Can you figure out why this is happening?

cwmok commented 1 year ago

Hi @pypi20200320,

From my experience, there could be two reasons:

1) It is caused by the instability of NCC; see https://github.com/Project-MONAI/MONAI/discussions/3463. In that case, increasing the eps value of the NCC loss may help, e.g., eps=1e-4.
2) It is caused by the CUDA version of your PyTorch build. In my case, I encountered a similar problem when I used torch==2.0.0 with the default CUDA version. The training became stable when I switched to torch==2.0.0+cu117.
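To illustrate why the eps value in the first suggestion can matter, here is a minimal NumPy sketch of a squared NCC computed over a single window. The function `squared_ncc` and the test windows `flat` and `ramp` are hypothetical names made up for this example; this is not the NCC class used by Train_C2FViT_pairwise.py, and it is only one plausible way a small denominator can make NCC sensitive to noise.

```python
import numpy as np


def squared_ncc(fixed, moving, eps=1e-5):
    """Squared normalized cross-correlation over one window.

    Hypothetical helper for illustration only; it is not the NCC class
    used by Train_C2FViT_pairwise.py.
    """
    fixed = np.asarray(fixed, dtype=np.float64)
    moving = np.asarray(moving, dtype=np.float64)
    f = fixed - fixed.mean()   # zero-mean intensities of the fixed window
    m = moving - moving.mean() # zero-mean intensities of the moving window
    cross = np.sum(f * m)
    var_f = np.sum(f * f)
    var_m = np.sum(m * m)
    # eps keeps the denominator away from zero in near-constant windows.
    return cross * cross / (var_f * var_m + eps)


# A nearly flat window (e.g. background) with a tiny intensity ripple.
flat = 0.5 + 1e-4 * np.sin(np.arange(125))
# A window with real structure.
ramp = np.linspace(0.0, 1.0, 125)

for eps in (1e-15, 1e-4):
    print(f"eps={eps:g}  flat={squared_ncc(flat, flat, eps):.3g}  "
          f"ramp={squared_ncc(ramp, ramp, eps):.3g}")
```

With a very small eps, the nearly flat window scores close to 1 even though its ripple is far below the intensity range of the image, so such windows can dominate the loss. With eps=1e-4 its contribution drops toward 0, while the structured window is essentially unchanged.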

Hope it helps.
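For the second suggestion, it can help to confirm which PyTorch build is actually installed before and after switching. A small sketch, assuming only the standard `torch.__version__`, `torch.version.cuda`, and `torch.cuda.is_available()` attributes; the helper name `torch_build_info` is made up for this example:

```python
def torch_build_info():
    """Return the installed PyTorch version and its CUDA build, if any.

    Hypothetical helper for illustration. Returns None when PyTorch is
    not importable in the current environment.
    """
    try:
        import torch
    except ImportError:
        return None
    return {
        "torch": torch.__version__,        # e.g. a "+cu117" suffix marks a CUDA 11.7 wheel
        "cuda_build": torch.version.cuda,  # None for CPU-only builds
        "cuda_available": torch.cuda.is_available(),
    }


print(torch_build_info())
```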