Project-MONAI / MONAI

AI Toolkit for Healthcare Imaging
https://monai.io/
Apache License 2.0

`bias_downsample=False` ResNet constructor #6811

Closed wyli closed 1 year ago

wyli commented 1 year ago

Currently, the `bias_downsample=False` argument conflicts with pretraining, as it is hard-coded to `not pretrained` in the ResNet constructor:

model: ResNet = ResNet(block, layers, block_inplanes, bias_downsample=not pretrained, **kwargs)
if pretrained:
    # Author of paper zipped the state_dict on googledrive,
    # so would need to download, unzip and read (2.8gb file for a ~150mb state dict).
    # Would like to load dict from url but need somewhere to save the state dicts.
    raise NotImplementedError(
        "Currently not implemented. You need to manually download weights provided by the paper's author"
        " and load them to the model with `state_dict`. See https://github.com/Tencent/MedicalNet"
    )
return model

When manually loading MedicalNet weights, the downsample bias terms raise errors because they are not present in the loaded weights. It is also not possible to drop `bias_downsample` by setting `pretrained=True`, since that raises the `NotImplementedError` above. So, can you please remove the hard-coding from the model constructor in the source code?

Originally posted by @acerdur in https://github.com/Project-MONAI/MONAI/issues/5477#issuecomment-1660043864
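
For illustration, removing the hard-coding could look like the following minimal sketch, based on the helper excerpted above (argument names are illustrative; this is not the actual merged change):

def _resnet(arch, block, layers, block_inplanes, pretrained, progress, **kwargs):
    # let callers pass bias_downsample explicitly; `not pretrained` stays only as a default
    bias_downsample = kwargs.pop("bias_downsample", not pretrained)
    model: ResNet = ResNet(block, layers, block_inplanes, bias_downsample=bias_downsample, **kwargs)
    if pretrained:
        ...  # pretrained branch unchanged
    return model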

surajpaib commented 1 year ago

Hi @wyli, I can take a look at this!

Possibly two solutions exist here:

  1. We remove the hard-coded bias_downsample, since the pretrained model can never actually be loaded through this call at the moment. If the Med3D weights are loaded here later on, we could revert this.

  2. Should the pretrained flag be removed completely? If this way of loading models is not a priority, then removing the flag from the constructor would reduce confusion.

Let me know how to proceed.

wyli commented 1 year ago

thanks @surajpaib, I remember the ability to load pre-trained weights is a popular feature for this network, so perhaps option 1 is better as it preserves the potential usage. (Including a more step-by-step guide about loading the pretrained weights would be even better if you have time...)

wyli commented 1 year ago

comments from https://github.com/Project-MONAI/MONAI/pull/5477#issuecomment-1667607577

Hi @acerdur @wyli, since there were no error logs provided in the above comments, I'm assuming the issue comes only from the shortcut_type used.

Details: The sole purpose of passing bias_downsample=False is to match the MedicalNet and official PyTorch implementations of ResNet, which set bias=False in the downsampling layer. For the downsampling layer, there are two variants in MONAI:

shortcut_type='B'  # uses a conv1x1 as the downsampling layer
shortcut_type='A'  # uses an avgpool1x1 as the downsampling layer
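
To make the difference concrete, here is a simplified, self-contained sketch of the two variants (loosely following MONAI's 3D ResNet; the channel and stride values are placeholders for illustration):

import torch
import torch.nn as nn
import torch.nn.functional as F

in_planes, planes, stride = 64, 128, 2  # placeholder values
bias_downsample = False  # what MedicalNet / torchvision use for shortcut B

# shortcut_type='B': a learned 1x1x1 convolution; the bias_downsample flag
# controls whether this conv carries a bias term
downsample_b = nn.Sequential(
    nn.Conv3d(in_planes, planes, kernel_size=1, stride=stride, bias=bias_downsample),
    nn.BatchNorm3d(planes),
)

# shortcut_type='A': parameter-free average pooling with zero-padded channels,
# so there is no bias (or any weight) to mismatch against a checkpoint
def downsample_a(x):
    out = F.avg_pool3d(x, kernel_size=1, stride=stride)
    pad = torch.zeros(out.size(0), planes - out.size(1), *out.shape[2:], device=out.device)
    return torch.cat([out, pad], dim=1)

x = torch.randn(1, in_planes, 8, 8, 8)
print(downsample_b(x).shape, downsample_a(x).shape)  # both torch.Size([1, 128, 4, 4, 4])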

AFAIU the error you are facing comes from the shortcut layer. To correctly load the pretrained weights of MedicalNet, you should initialize the model with the correct architecture, i.e. shortcut_type='A' for the MedicalNet resnet_18:

from collections import OrderedDict

import torch
from monai.networks import nets

# MONAI ResNet18, matching the MedicalNet resnet_18 architecture
net = nets.resnet18(pretrained=False, spatial_dims=3, n_input_channels=1, num_classes=2, shortcut_type='A')
wt_path = 'resnet_18.pth'  # path to weights from Google Drive of Tencent
pretrained_weights = torch.load(wt_path, map_location='cpu')

# strip the 'module.' prefix (added by DataParallel) so the keys match
weights = OrderedDict()
for k, v in pretrained_weights['state_dict'].items():
    weights[k.replace('module.', '')] = v

net.load_state_dict(weights, strict=False)  # _IncompatibleKeys(missing_keys=['fc.weight', 'fc.bias'], unexpected_keys=[])

The only incompatible keys are for the last linear layer, which is expected for fine-tuning, since that layer is not provided (nor inferable) in the MedicalNet weights.
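
As a follow-up sketch (not part of the original comment): since `fc` is the only randomly initialized part after loading, a typical fine-tuning setup would replace it to match the target task, e.g.:

import torch.nn as nn

# replace the classifier head for a hypothetical 4-class downstream task
net.fc = nn.Linear(net.fc.in_features, 4)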

Related: #6811

surajpaib commented 1 year ago

Hi @wyli, in reference to the comment above, it looks like different arguments need to be passed to the ResNet constructor depending on the configuration of the pre-trained resnet to be loaded, as specified by MedicalNet: https://github.com/Tencent/MedicalNet/blob/18c8bb6cd564eb1b964bffef1f4c2283f1ae6e7b/README.md?plain=1#L25

resnet_10_23dataset.pth: --model resnet --model_depth 10 --resnet_shortcut B
resnet_18_23dataset.pth: --model resnet --model_depth 18 --resnet_shortcut A
resnet_34_23dataset.pth: --model resnet --model_depth 34 --resnet_shortcut A
resnet_50_23dataset.pth: --model resnet --model_depth 50 --resnet_shortcut B

For all the configs with shortcut B, bias_downsample would need to be set to False; for A, it doesn't matter, as the shortcut is parameter-free average pooling. But removing the hard-coded bias_downsample would still be essential!
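
A sketch of what that mapping could look like once bias_downsample is exposed (this assumes the hard-coded value has been removed so the keyword is forwarded, which is exactly what this issue proposes):

from monai.networks import nets

# map each MedicalNet checkpoint to the matching MONAI constructor call;
# bias_downsample as a keyword only works after the hard-coding is removed
medicalnet_configs = {
    'resnet_10_23dataset.pth': (nets.resnet10, dict(shortcut_type='B', bias_downsample=False)),
    'resnet_18_23dataset.pth': (nets.resnet18, dict(shortcut_type='A')),
    'resnet_34_23dataset.pth': (nets.resnet34, dict(shortcut_type='A')),
    'resnet_50_23dataset.pth': (nets.resnet50, dict(shortcut_type='B', bias_downsample=False)),
}

builder, cfg = medicalnet_configs['resnet_50_23dataset.pth']
net = builder(pretrained=False, spatial_dims=3, n_input_channels=1, num_classes=2, **cfg)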

P.S.: I'm looking a bit more at this implementation, and it seems like we could also implement the pretrained model download (using something like gdown) and loading (maybe we first download to a tmp directory, unzip, and load_state_dict). Were there any foreseen challenges with doing this?
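
Roughly, that flow could look like the following sketch (the Google Drive file id and archive layout are placeholders, not the real MedicalNet ones):

import os
import tempfile
import zipfile

import gdown  # third-party: pip install gdown
import torch

def load_medicalnet_state_dict(file_id='<GDRIVE_FILE_ID>'):
    # download the zipped weights to a temporary directory, unzip, and load
    tmp_dir = tempfile.mkdtemp()
    archive = os.path.join(tmp_dir, 'weights.zip')
    gdown.download(f'https://drive.google.com/uc?id={file_id}', archive, quiet=False)
    with zipfile.ZipFile(archive) as zf:
        zf.extractall(tmp_dir)
    # assumes the archive contains resnet_18.pth; adjust for other depths
    return torch.load(os.path.join(tmp_dir, 'resnet_18.pth'), map_location='cpu')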

acerdur commented 1 year ago

Hi both,

thank you very much for handling this! Losing shortcut B functionality for pre-trained weights would be sub-optimal.

The last suggestion from @surajpaib would also be very handy! I can suggest using a hidden cache directory, as torch.hub does for pre-trained models. To load the weights from the linked implementation, one just has to remove the module. prefix from the state_dict keys, and then it goes very smoothly up to the last linear layer.
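
For instance (a sketch with illustrative paths, not an API MONAI provides):

import os

import torch

# reuse torch.hub's cache location so repeated runs skip the download
cache_dir = os.path.join(torch.hub.get_dir(), 'medicalnet')
os.makedirs(cache_dir, exist_ok=True)
wt_path = os.path.join(cache_dir, 'resnet_18.pth')
if not os.path.exists(wt_path):
    ...  # download and unzip into cache_dir, e.g. as sketched above
state = torch.load(wt_path, map_location='cpu')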

Again, thanks for changing the original issue.