Hi @wyli, I can take a look at this!
Two possible solutions exist here:

1. We remove the hardcoded `bias_downsample`, since the pretrained model would never actually be loaded in this call (currently). If the Med3D weights are loaded later on, we could revert this. (See the sketch after this list.)
2. Should the `pretrained` flag be removed completely? If this way of loading models is not a priority, then removing the flag from the constructor reduces any confusion.
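A minimal sketch of what option 1 could look like, assuming the factory currently derives `bias_downsample` from the `pretrained` flag (the function name and body below are illustrative, not MONAI's exact code):

```python
from monai.networks.nets.resnet import ResNet, ResNetBlock, get_inplanes


def resnet18_sketch(pretrained: bool = False, **kwargs):
    # Option 1: let the caller control bias_downsample instead of hard-coding
    # it from the `pretrained` flag; defaulting to True matches torchvision.
    bias_downsample = kwargs.pop("bias_downsample", True)
    model = ResNet(ResNetBlock, [2, 2, 2, 2], get_inplanes(),
                   bias_downsample=bias_downsample, **kwargs)
    if pretrained:
        # automatic weight download/loading would go here once implemented
        raise NotImplementedError
    return model
```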
Let me know how to proceed.
thanks @surajpaib, I remember the ability to load pre-trained weights is a popular feature for this network, perhaps option 1 is better as it shows the potential usage. (Including a more step-by-step guide about loading the pretrained weights would be even better if you have time...)
comments from https://github.com/Project-MONAI/MONAI/pull/5477#issuecomment-1667607577
Hi @acerdur @wyli Since there were no error logs provided in the above comments, I'm assuming the issue comes only from the `shortcut_type` used.

Details: The sole purpose of passing `bias_downsample=False` is to match the MedNet and the official PyTorch implementations of ResNet, which set `bias=False` in the downsampling layer. As for the downsampling layer, there are two variants in MONAI:

```python
shortcut_type='B'  # uses a conv1x1 as the downsampling layer
shortcut_type='A'  # uses an avgpool1x1 as the downsampling layer
```

AFAIU the error you are facing comes from the shortcut layer. To correctly load the pretrained weights of MedNet, you should initialize the model with the correct architecture, i.e. `shortcut_type='A'` for MedNet:

```python
from collections import OrderedDict

import torch
from monai.networks import nets

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# MONAI ResNet18
net = nets.resnet18(pretrained=False, spatial_dims=3, n_input_channels=1,
                    num_classes=2, shortcut_type='A')

wt_path = 'resnet_18.pth'  # path to weights from the Google Drive of Tencent
pretrained_weights = torch.load(f=wt_path, map_location=device)

# strip the 'module.' prefix so the keys match MONAI's layer names
weights = OrderedDict()
for k, v in pretrained_weights['state_dict'].items():
    weights.update({k.replace('module.', ''): v})

net.load_state_dict(weights, strict=False)
# _IncompatibleKeys(missing_keys=['fc.weight', 'fc.bias'], unexpected_keys=[])
```
The only pair of incompatible keys are for the last linear layer, which is expected for finetuning, and not provided/inferable in the MedNet weights.
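If the goal is fine-tuning, one way to handle those missing keys is to re-create the final layer for the downstream task (a small sketch; `num_task_classes` is a placeholder, not part of the example above):

```python
import torch.nn as nn

num_task_classes = 2  # placeholder: number of classes in the downstream task
net.fc = nn.Linear(net.fc.in_features, num_task_classes)
```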
Related: #6811
Hi @wyli
In reference to the comment above, it looks like different arguments need to be passed to the `ResNet` constructor based on the configuration of the pre-trained resnet to be loaded, as specified by MedicalNet:

https://github.com/Tencent/MedicalNet/blob/18c8bb6cd564eb1b964bffef1f4c2283f1ae6e7b/README.md?plain=1#L25
```
resnet_10_23dataset.pth: --model resnet --model_depth 10 --resnet_shortcut B
resnet_18_23dataset.pth: --model resnet --model_depth 18 --resnet_shortcut A
resnet_34_23dataset.pth: --model resnet --model_depth 34 --resnet_shortcut A
resnet_50_23dataset.pth: --model resnet --model_depth 50 --resnet_shortcut B
```
For all the configs with shortcut `B`, `bias_downsample` would need to be set to `False`; for `A`, it wouldn't matter, as the shortcut type is average pooling. But removing the hard-coded `bias_downsample` would still be essential! A sketch of the resulting mapping is below.
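To make that concrete, a hypothetical lookup table could live next to the loading code (the dict and its name are illustrative; the depths and shortcut types come from the README excerpt above):

```python
# checkpoint file -> matching MONAI constructor arguments (illustrative sketch);
# for the 'A' entries the bias_downsample value is arbitrary, since the
# average-pooling shortcut has no bias term
MEDICALNET_CONFIGS = {
    "resnet_10_23dataset.pth": dict(model_depth=10, shortcut_type="B", bias_downsample=False),
    "resnet_18_23dataset.pth": dict(model_depth=18, shortcut_type="A", bias_downsample=True),
    "resnet_34_23dataset.pth": dict(model_depth=34, shortcut_type="A", bias_downsample=True),
    "resnet_50_23dataset.pth": dict(model_depth=50, shortcut_type="B", bias_downsample=False),
}
```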
P.S.: I'm looking a bit more at this implementation, and it seems like we could also implement the pretrained model download (using something like `gdown`) and loading (maybe we first download to a tmp directory, unzip and `load_state_dict`). Were there some foreseen challenges with doing this?
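A rough sketch of that idea, assuming `gdown` is installed; the Google Drive file id is a placeholder rather than the real Tencent link:

```python
import os
import tempfile

import gdown
import torch


def download_and_load(net, gdrive_file_id):
    # download into a temporary directory, then load_state_dict (sketch;
    # an unzip step would be needed if the checkpoint ships as an archive)
    with tempfile.TemporaryDirectory() as tmp_dir:
        wt_path = os.path.join(tmp_dir, "weights.pth")
        gdown.download(id=gdrive_file_id, output=wt_path, quiet=False)
        state = torch.load(wt_path, map_location="cpu")["state_dict"]
        state = {k.replace("module.", ""): v for k, v in state.items()}
        net.load_state_dict(state, strict=False)
    return net
```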
Hi both,
thank you very much for handling this! Losing shortcut `B` functionality for pre-trained weights would be sub-optimal.
The last suggestion from @surajpaib would also be very handy! I can suggest using a hidden cache directory, as `torch.hub` does for pre-trained models.
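For example, something like this could reuse `torch.hub`'s cache location (an assumed layout, not an existing MONAI mechanism):

```python
import os

import torch

# cache next to torch.hub's own checkpoints, e.g. ~/.cache/torch/medicalnet
cache_dir = os.path.join(torch.hub.get_dir(), "medicalnet")
os.makedirs(cache_dir, exist_ok=True)

wt_path = os.path.join(cache_dir, "resnet_18_23dataset.pth")
if not os.path.exists(wt_path):
    ...  # download only on a cache miss, e.g. via gdown as sketched above
```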
To load the weights from the linked implementation, one should just remove the `module.` prefix from the `state_dict` keys, and then it goes very smoothly until the last linear layer.
Again, thanks for changing the original issue.
Currently, the `bias_downsample=False` argument contradicts pretraining, as it is hard-coded to `not pretrained` in the `ResNet` constructor:
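The hard-coded expression in question looks roughly like this (a sketch based on the description above, not a verbatim copy of MONAI's source):

```python
# inside MONAI's resnet factory functions (illustrative sketch):
model = ResNet(block, layers, get_inplanes(), bias_downsample=not pretrained, **kwargs)
```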
When manually loading MedicalNet weights, the downsample bias terms raise errors, as they are not present in the loaded weights. It is also not possible to remove `bias_downsample` by setting `pretrained=True`; this raises `NotImplementedError`. So, can you please remove the hard-coding from the model constructor in the source code?

Originally posted by @acerdur in https://github.com/Project-MONAI/MONAI/issues/5477#issuecomment-1660043864