xi-jia / LKU-Net

The official implementation of U-Net vs Transformer: Is U-Net Outdated in Medical Image Registration?

Details about reproducing the results of LKU-Net on the OASIS dataset #5

Open AXu0511 opened 1 year ago

AXu0511 commented 1 year ago

Hello, thank you very much for your outstanding work; I have recently been reproducing your results. Did you train the half-resolution displacements directly, or did you train a full-resolution displacement first and then downsample it? Also, would you be so kind as to share the OASIS version of the code?

xi-jia commented 1 year ago

We tried both and have updated the repository with the full- and half-resolution code. The pre-trained half-resolution models are also released and linked in the README.
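For anyone following along, here is a minimal PyTorch sketch of the two options discussed above (illustrative only, not the repo's exact code; the volume shape and tensor names are assumptions). If displacements are expressed in voxels, their magnitudes must be rescaled whenever the grid resolution changes:

```python
import torch
import torch.nn.functional as F

full_shape = (160, 192, 224)                 # assumed full-resolution volume size
flow_full = torch.randn(1, 3, *full_shape)   # stand-in for a network-predicted field

# Option A: predict at full resolution, then downsample the displacement.
# Halving the grid resolution means voxel-unit displacements must be halved too.
flow_half = 0.5 * F.interpolate(flow_full, scale_factor=0.5,
                                mode='trilinear', align_corners=False)

# Option B: predict directly at half resolution, then upsample
# (and double) the field before warping at full size.
flow_half_pred = torch.randn(1, 3, *(s // 2 for s in full_shape))
flow_up = 2.0 * F.interpolate(flow_half_pred, scale_factor=2.0,
                              mode='trilinear', align_corners=False)
```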

rohitrango commented 6 months ago

Hi there,

Thanks for the work. I tried re-training on the neurite OASIS dataset, and the full-resolution training collapses after a few iterations. Here is what the training log looks like:

[training log screenshot]

I used the default parameters provided in the file. Did you face any similar problems while training? Thanks in advance for the help.

xi-jia commented 6 months ago

@rohitrango, the corrected weights can be found in command.txt: --data_labda 1.0 (not actually used in the code), --smth_labda 0.01 (different from the default in train.py), and --mask_labda 1.0 (the default in train.py). The overall loss is therefore:

loss = 1.0 * loss1 + 0.01 * loss2 + 1.0 * loss3
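A minimal sketch of that weighted sum (not the repo's exact code; loss1, loss2, and loss3 stand in for the similarity, smoothness, and mask/segmentation terms computed elsewhere):

```python
import torch

data_labda, smth_labda, mask_labda = 1.0, 0.01, 1.0

loss1 = torch.tensor(0.5)   # placeholder similarity (MSE) term
loss2 = torch.tensor(2.0)   # placeholder smoothness/regularization term
loss3 = torch.tensor(0.3)   # placeholder segmentation (mask) term

loss = data_labda * loss1 + smth_labda * loss2 + mask_labda * loss3
```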

rohitrango commented 6 months ago

Yes, I noticed that data_labda was not used, and I have added it back (set to 1.0) for now.

I've changed smth_labda and start_channels according to command.txt. I also want to use cross-correlation instead of the L2 loss (using_l2 = 2); is that fine, or does it make training unstable?

I've started training with the config you mentioned, fingers crossed!

xi-jia commented 6 months ago

@rohitrango

use cross correlation instead of L2 loss (using_l2 = 2), is that fine?

Yes. For NCC to achieve a satisfying Dice, smth_labda should be much larger than the 0.01 used for MSE. I vaguely remember we tuned a few parameter sets, such as --data_labda 1.0 --smth_labda 0.1 --mask_labda 1.0 or --data_labda 1.0 --smth_labda 0.5 --mask_labda 1.0, but the best Dice with NCC was 1-2% lower than with MSE on this specific dataset.
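For reference, here is a minimal sketch of a global NCC similarity loss (illustrative only; the repo's NCC implementation may be a local, windowed variant, and the function name is an assumption):

```python
import torch

def ncc_loss(x, y, eps=1e-5):
    """Global normalized cross-correlation, returned as 1 - NCC so that
    minimizing the loss maximizes similarity. A simplified sketch; a
    local (windowed) NCC is more common in registration."""
    x = x - x.mean()
    y = y - y.mean()
    ncc = (x * y).sum() / (x.norm() * y.norm() + eps)
    return 1.0 - ncc
```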

rohitrango commented 6 months ago

Sounds good. Thanks for the info.