scaelles / DEXTR-PyTorch

Deep Extreme Cut http://www.vision.ee.ethz.ch/~cvlsegmentation/dextr
GNU General Public License v3.0

Fixed load_pretrained_ms function to load the pretrained weights properly #18

Closed. ahmedhshahin closed this 4 years ago

ahmedhshahin commented 5 years ago

Hello,

Thank you for sharing your code, it is very informative and useful. I found a potential bug and would like to contribute a fix.

The bug is in the load_pretrained_ms function in deeplab_resnet.py. This function is supposed to load the pretrained model and, when we work with more than 3 input channels, adapt it by copying the weights of the third channel of the pretrained model to the new channels. The logic of the function is correct, but for some reason deepcopy does not actually copy the layer weights in PyTorch. You can check this with some debugging, as I did, by printing a statistic such as the mean of the layer weights in both the pretrained model and the model we construct and copy the weights to. After calling the function the weights remain unchanged, which means the pretrained weights are not loaded.

My fix iterates over the state_dict of each model instead of iterating over the modules. I verified the solution, and the weights now change as expected.
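To illustrate the state_dict idea, here is a minimal sketch, not the exact code in this PR. The function name `load_pretrained_into` is hypothetical. It assumes the two models use the same parameter names, and that the only shape mismatch is a convolution whose new version has extra input channels. Those extra channels are filled with the last pretrained input channel (the third one for an RGB model).

```python
import torch
import torch.nn as nn


def load_pretrained_into(model: nn.Module, pretrained: nn.Module) -> None:
    """Copy weights from `pretrained` into `model` via their state_dicts.

    Hypothetical helper for illustration only. Entries with matching shapes
    are copied as-is. A 4-D conv weight with extra input channels in `model`
    gets the pretrained channels first, then repeats of the last pretrained
    input channel for the remaining ones.
    """
    src = pretrained.state_dict()
    new_state = {}
    for name, dst_tensor in model.state_dict().items():
        if name not in src:
            # Entry exists only in the new model: keep its current values.
            new_state[name] = dst_tensor
            continue
        src_tensor = src[name]
        if src_tensor.shape == dst_tensor.shape:
            new_state[name] = src_tensor.clone()
        elif (
            src_tensor.dim() == 4
            and dst_tensor.dim() == 4
            and src_tensor.shape[0] == dst_tensor.shape[0]
            and src_tensor.shape[2:] == dst_tensor.shape[2:]
            and dst_tensor.shape[1] > src_tensor.shape[1]
        ):
            n_extra = dst_tensor.shape[1] - src_tensor.shape[1]
            # Repeat the last pretrained input channel for each new channel.
            extra = src_tensor[:, -1:].repeat(1, n_extra, 1, 1)
            new_state[name] = torch.cat([src_tensor, extra], dim=1)
        else:
            # Unexpected mismatch: leave the new model's values untouched.
            new_state[name] = dst_tensor
    model.load_state_dict(new_state)


if __name__ == "__main__":
    # Toy stand-ins for the real networks: 3-channel vs. 4-channel input.
    pretrained = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
    model = nn.Sequential(nn.Conv2d(4, 8, 3), nn.BatchNorm2d(8))
    before = model[0].weight.mean().item()
    load_pretrained_into(model, pretrained)
    # Comparing a statistic before and after shows whether weights changed.
    print(before, model[0].weight.mean().item())
```

Printing the mean before and after the call is the same kind of check I used to confirm whether the weights actually changed.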

UPDATE: I found another bug, in the FixedResize class in custom_transforms.py. The original implementation deletes the original-resolution image and gt from the sample. This raises an error at test time, because we need the original gt to crop the prediction back to the original resolution. I added an if condition to avoid deleting the gt from the sample. This is only needed at test time and with a test_batch_size of 1, so I pass a flag to handle that.
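The flag idea can be sketched as follows. This is a simplified, hypothetical stand-in and not the actual FixedResize code: the class name `FixedResizeSketch`, the flag name `keep_unlisted`, and the NumPy nearest-neighbour helper are all assumptions made to keep the example self-contained.

```python
import numpy as np


def nearest_resize(arr: np.ndarray, size: tuple) -> np.ndarray:
    """Nearest-neighbour resize of an (H, W) or (H, W, C) array to `size`.

    Stand-in for the repository's real resize helper, for illustration only.
    """
    out_h, out_w = size
    rows = np.arange(out_h) * arr.shape[0] // out_h
    cols = np.arange(out_w) * arr.shape[1] // out_w
    return arr[rows][:, cols]


class FixedResizeSketch:
    """Resize the listed sample entries; optionally keep the unlisted ones.

    `resolutions` maps sample keys to (height, width). With
    `keep_unlisted=False`, entries not in `resolutions` are deleted, which is
    the behaviour described as the bug. With `keep_unlisted=True` they stay
    in the sample at their original resolution.
    """

    def __init__(self, resolutions: dict, keep_unlisted: bool = False):
        self.resolutions = resolutions
        self.keep_unlisted = keep_unlisted

    def __call__(self, sample: dict) -> dict:
        for key in list(sample.keys()):
            if key in self.resolutions:
                sample[key] = nearest_resize(sample[key], self.resolutions[key])
            elif not self.keep_unlisted:
                del sample[key]
        return sample


if __name__ == "__main__":
    sample = {"crop_image": np.zeros((10, 12, 3)), "gt": np.ones((20, 30))}
    test_tf = FixedResizeSketch({"crop_image": (4, 4)}, keep_unlisted=True)
    out = test_tf(dict(sample))
    print(sorted(out.keys()), out["gt"].shape)
```

Since the entries that are kept have different sizes from one image to the next, they cannot be stacked into a batch, which is why the flag should only be used with a test_batch_size of 1.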

Hope that helps, and thank you again for your contribution!