YipengHu / label-reg

(This repo is no longer up-to-date. Any updates will be at https://github.com/DeepRegNet/DeepReg/) A demo of the re-factored label-driven registration code, based on "Weakly-supervised convolutional neural networks for multimodal image registration"
Apache License 2.0

training for larger deformations on public datasets #13

Closed mattiaspaul closed 5 years ago

mattiaspaul commented 5 years ago

First of all, thanks a lot for making your research code publicly available. The installation is flawless, and running inference on the provided examples works fine. However, training the models on new datasets has so far yielded disappointing results. For example, several PhD students in my group have tried to run the code for whole-heart MR-to-CT registration on the public MMWHS dataset, without achieving any meaningful improvement over an initial affine alignment.

1) Could you check the training code for bugs, or is it the case that larger deformations cannot be estimated sufficiently well with LabelReg? We have spent a lot of time trying to resolve these issues with different preprocessing and settings, without success, and feel it is important to get feedback from your side on how best to train on new data.

2) In addition, there is an issue with the 'composite' option (both global and local transforms), which throws the following error:

Traceback (most recent call last):
  File "label-reg/training.py", line 35, in <module>
    image_fixed=input_fixed_image)
  File "~/label-reg/labelreg/networks.py", line 13, in build_network
    return CompositeNet(**kwargs)
  File "~/label-reg/labelreg/networks.py", line 89, in __init__
    image_moving=global_net.warp_image(),
TypeError: warp_image() missing 1 required positional argument: 'input_' 
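For context on what the 'composite' option has to compute: assuming the global network outputs an affine matrix [A | t] and the local network predicts a dense displacement field (DDF) u on the globally warped image, the final warp samples the moving image at A(x + u(x)) + t. The sketch below is a hypothetical NumPy illustration of that composition, not label-reg's CompositeNet code; the function name and array layouts are assumptions.

```python
import numpy as np

def composite_coordinates(grid, affine, ddf):
    """Sampling coordinates for a global (affine) warp followed by a local warp.

    grid:   (D, H, W, 3) voxel coordinates of the fixed image
    affine: (3, 4) matrix [A | t] of the (assumed) global transform
    ddf:    (D, H, W, 3) displacement field predicted by the local network
    The local net sees the globally warped image I_g(y) = I_m(A y + t), so the
    final image is I_g(x + u(x)) = I_m(A (x + u(x)) + t).
    """
    local = grid + ddf                 # apply the local displacement first
    A, t = affine[:, :3], affine[:, 3]
    return local @ A.T + t             # then map into moving-image space
```

Note that the order matters: applying the affine first and adding u afterwards gives a different transform whenever A is not the identity.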

3) The provided inference function fails to correctly warp multi-channel inputs (as required for label images). We have written a bug fix for this using PyTorch's grid_sample function and would be happy to provide it to the interested community. Thanks a lot for your time.
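The essential requirement for warping label images is that one displacement field is applied identically to every channel (e.g. each channel of a one-hot label map). The fix described above uses PyTorch's grid_sample; to keep this self-contained, the sketch below swaps in plain NumPy trilinear interpolation instead. It assumes a (D, H, W, C) layout and a displacement field in voxel units, both of which are assumptions rather than label-reg's conventions.

```python
import numpy as np

def warp_multichannel(volume, ddf):
    """Warp every channel of `volume` (D, H, W, C) with one displacement field.

    ddf: (D, H, W, 3) displacement in voxels; output voxel x samples the volume
    at x + ddf(x) using trilinear interpolation with border clamping.
    """
    D, H, W, C = volume.shape
    grid = np.stack(np.meshgrid(np.arange(D), np.arange(H), np.arange(W),
                                indexing="ij"), axis=-1).astype(float)
    upper = np.array([D - 1, H - 1, W - 1], dtype=float)
    coords = np.clip(grid + ddf, 0.0, upper)      # replicate-border behaviour
    lo = np.floor(coords).astype(int)
    hi = np.minimum(lo + 1, upper.astype(int))
    frac = coords - lo                            # fractional offsets per axis
    out = np.zeros_like(volume, dtype=float)
    # Accumulate the 8 corner contributions; the weights are shared by all channels
    for dz in (0, 1):
        for dy in (0, 1):
            for dx in (0, 1):
                z = hi[..., 0] if dz else lo[..., 0]
                y = hi[..., 1] if dy else lo[..., 1]
                x = hi[..., 2] if dx else lo[..., 2]
                w = ((frac[..., 0] if dz else 1 - frac[..., 0])
                     * (frac[..., 1] if dy else 1 - frac[..., 1])
                     * (frac[..., 2] if dx else 1 - frac[..., 2]))
                out += w[..., None] * volume[z, y, x]
    return out
```

Because the interpolation weights sum to one, warped one-hot labels still sum to one per voxel, and an argmax over channels recovers a hard label map.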

YipengHu commented 5 years ago

Thanks for the feedback, Mattias (@mattiaspaul),

First of all, I am very interested in generalising this to other applications! So far it has "worked" on the prostate example in the publication and on a brain image data set.

1 - I will make an effort to check for potential bugs. Admittedly, this is not the code I use every day for my work, so it has been tested less. Did you notice anything obvious that could be causing, or hinting at, bugs? Otherwise, can you give some more details? When you say "meaningful improvement", do you mean no improvement in registration during training or during testing? How did you pre-process and split your data? How many types of labels did you use?

2 - Thanks! Bug fixed, see #14. I'd like to note that I stopped using the composite network, as it didn't give convincingly better results after our preliminary conference work and seemingly uses more memory. As shown in the MedIA paper, I just use the local_net.

3 - Thanks again! Issue created; I will look into it. I did something similar in the dev branch but didn't come across this, though...

YipengHu commented 5 years ago

Oh, I don't think it is because the deformation is large. For the prostate work, I didn't use any initial alignment.

mattiaspaul commented 5 years ago

Dear Yipeng, thanks a lot for your answers, and sorry for the criticism: the code works reasonably well on some datasets, just not as well as expected, so any guidance would be appreciated. To have a reproducible experiment on publicly available data, I ran more tests on the Medical Segmentation Decathlon (https://drive.google.com/drive/folders/1HqEgzS8BV2c7xYNrZdEAnrHk7osJJ--2), Task 2 (Heart MRI). This contains 20 labelled images (with a single structure), which can easily be pre-processed into 160^3 volumes and split into 15 training and 5 test images. To enable pair-wise inter-subject registration, I changed line 75 in training.py to

    case_indices2 = case_indices[1:]+case_indices[:1]

and accordingly used

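    # Moving and fixed images come from the same reader; the fixed cases are the
    # cyclically shifted indices (case_indices2), giving inter-subject pairs
    trainFeed = {ph_moving_image: reader_moving_image.get_data(case_indices),
                 ph_fixed_image: reader_moving_image.get_data(case_indices2),
                 ph_moving_label: reader_moving_label.get_data(case_indices, label_indices),
                 ph_fixed_label: reader_moving_label.get_data(case_indices2, label_indices),
                 # ... any further entries unchanged from training.py
                 }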
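The edit to case_indices pairs each moving case in a mini-batch with the next case as the fixed image, wrapping around at the end. The sketch below shows the resulting pairs; the 15/5 split by index order is a hypothetical stand-in for however the data were actually split.

```python
def make_pairs(case_indices):
    """Pair each moving case with the next case as fixed (cyclic shift by one).

    Mirrors the edit case_indices2 = case_indices[1:] + case_indices[:1].
    Requires a Python list: for a NumPy array, `+` would add elementwise.
    """
    fixed = case_indices[1:] + case_indices[:1]
    return list(zip(case_indices, fixed))

# Hypothetical 15/5 split of the 20 labelled Task 2 cases (indices 0..19)
all_cases = list(range(20))
train_cases, test_cases = all_cases[:15], all_cases[15:]

print(make_pairs([3, 7, 11]))  # [(3, 7), (7, 11), (11, 3)]
```

If case_indices is a NumPy array, np.roll(case_indices, -1) is the equivalent shift. With a mini-batch of size 1 the shift pairs a case with itself, so the batch size must be at least 2 for this to produce inter-subject pairs.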

I had to reduce the number of (base) feature channels to 16 because, with a batch size of 3, training already took 24 GB of GPU RAM, and I don't have access to a card with more than 32 GB.
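A quick back-of-the-envelope calculation suggests why memory runs out so fast at this resolution. Assuming float32 activations and that the first block runs at full 160^3 resolution, a single 16-channel feature map for a batch of 3 already needs about 0.79 GB; a 3D encoder-decoder keeps many such tensors (plus their gradients) for backpropagation. The helper name is hypothetical.

```python
def feature_map_bytes(shape, channels, batch, bytes_per_value=4):
    """Bytes needed to store one feature map (float32 by default) for a mini-batch."""
    voxels = 1
    for s in shape:
        voxels *= s
    return voxels * channels * batch * bytes_per_value

print(feature_map_bytes((160, 160, 160), channels=16, batch=1))        # 262144000
print(feature_map_bytes((160, 160, 160), channels=16, batch=3) / 1e9)  # about 0.79 GB
```

Since every term is linear, halving the base channels halves the memory of each full-resolution layer, which is why reducing them to 16 was the lever available here.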

mattiaspaul commented 5 years ago

The results on the remaining 5 test scans (20 pair-wise registrations) were as follows:

No registration: avg. Dice 52.2%
LabelReg: avg. Dice 73.2% (stddev of Jacobian = 0.71)

This is a considerable improvement. However, with our unsupervised discrete registration algorithm deeds (https://github.com/mattiaspaul/deedsBCV) I achieved an avg. Dice of 87.0% (stddev of Jacobian = 0.45, run-time on CPU < 15 s). So it would be great if you could run the same experiment with your private code and check whether it produces comparable results, and/or suggest more suitable settings. Thanks.
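For readers reproducing these numbers, a minimal NumPy sketch of the two metrics follows. It assumes "stddev of Jacobian" means the standard deviation of the Jacobian determinant of x + u(x) over the volume (a common smoothness measure) and that the 20 registrations are all ordered (moving, fixed) pairs of the 5 test scans; the exact evaluation protocol used in the thread may differ.

```python
import itertools
import numpy as np

def dice(a, b):
    """Dice overlap 2|A and B| / (|A| + |B|) of two binary masks (1.0 if both empty)."""
    a, b = a.astype(bool), b.astype(bool)
    denom = a.sum() + b.sum()
    return 2.0 * np.logical_and(a, b).sum() / denom if denom else 1.0

def jacobian_det_std(ddf):
    """Std of the Jacobian determinant of x -> x + u(x); ddf is (D, H, W, 3) in voxels."""
    # grads[i][j] = d u_i / d x_j via finite differences (np.gradient)
    grads = [np.gradient(ddf[..., i]) for i in range(3)]
    jac = np.empty(ddf.shape[:3] + (3, 3))
    for i in range(3):
        for j in range(3):
            jac[..., i, j] = grads[i][j] + (1.0 if i == j else 0.0)
    return float(np.linalg.det(jac).std())

# 5 test scans -> every ordered (moving, fixed) pair with moving != fixed
pairs = list(itertools.permutations(range(5), 2))
print(len(pairs))  # 20
```

A lower Jacobian standard deviation indicates a smoother deformation, so the deeds result is better on both accuracy and regularity in this comparison.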

YipengHu commented 5 years ago

Criticisms are always most welcome! ;)

We have actually tried this on two different data sets with fewer than 50 training volumes and, so far, had little success at all. So I'm glad you got it to ~70%. I think the reasons are:

Last diagnostic question: did your Dice get very high (~95%) during training? That would be a good indication that the code is working but the network is over-fitting.
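The reasoning behind this question can be written down as a rough rule: a high training Dice combined with a much lower test Dice points to over-fitting rather than a bug. In the sketch below, the 95% level comes from the comment above, but the 10-point gap threshold and the function name are hypothetical rules of thumb, not anything from label-reg.

```python
def diagnose(train_dice, test_dice, high=0.95, max_gap=0.10):
    """Rough reading of mean train/test Dice (fractions in [0, 1]).

    `high` follows the ~95% mentioned above; `max_gap` is a hypothetical threshold.
    """
    gap = train_dice - test_dice
    if train_dice >= high and gap > max_gap:
        return "over-fitting: the code learns but does not generalise"
    if train_dice < high:
        return "training Dice not high: possible bug or under-fitting"
    return "generalises"
```

With fewer than 50 training volumes, the first outcome would be consistent with the explanation offered above that small data sets are the limiting factor.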

YipengHu commented 5 years ago

I'm closing this, though the search for a large public data set for a demo continues...