BMEII-AI / RadImageNet

RadImageNet, a pre-trained convolutional neural networks trained solely from medical imaging to be used as the basis of transfer learning for medical imaging applications.
MIT License
340 stars 35 forks source link

Loading Pretrained Pytorch Inception_V3 Weights #16

Open asy51 opened 9 months ago

asy51 commented 9 months ago

I am getting state_dict mismatches:

class Backbone(nn.Module):
    def __init__(self, net='inceptionv3'):
        super().__init__()
        if net == 'inceptionv3':
            base_model = inception_v3()
        elif net == 'densenet121':
            base_model = densenet121()
        elif net == 'resnet50':
            base_model = resnet50()
        encoder_layers = list(base_model.children())
        self.backbone = nn.Sequential(*encoder_layers[:-1])

    def forward(self, x):
        return self.backbone(x)

net = 'inceptionv3'
backbone = Backbone(net)
backbone.load_state_dict(torch.load(RAD[net]))

Error msg:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[21], [line 18](vscode-notebook-cell:?execution_count=21&line=18)
     [16](vscode-notebook-cell:?execution_count=21&line=16) net = 'inceptionv3'
     [17](vscode-notebook-cell:?execution_count=21&line=17) backbone = Backbone(net)
---> [18](vscode-notebook-cell:?execution_count=21&line=18) backbone.load_state_dict(torch.load(RAD[net]))

...

RuntimeError: Error(s) in loading state_dict for Backbone:
    Missing key(s) in state_dict: "backbone.15.conv0.conv.weight", "backbone.15.conv0.bn.weight", "backbone.15.conv0.bn.bias", "backbone.15.conv0.bn.running_mean", "backbone.15.conv0.bn.running_var", "backbone.15.conv1.conv.weight", "backbone.15.conv1.bn.weight", "backbone.15.conv1.bn.bias", "backbone.15.conv1.bn.running_mean", "backbone.15.conv1.bn.running_var", "backbone.15.fc.weight", "backbone.15.fc.bias", "backbone.16.branch3x3_2.conv.weight", "backbone.16.branch3x3_2.bn.weight", "backbone.16.branch3x3_2.bn.bias", "backbone.16.branch3x3_2.bn.running_mean", "backbone.16.branch3x3_2.bn.running_var", "backbone.16.branch7x7x3_1.conv.weight", "backbone.16.branch7x7x3_1.bn.weight", "backbone.16.branch7x7x3_1.bn.bias", "backbone.16.branch7x7x3_1.bn.running_mean", "backbone.16.branch7x7x3_1.bn.running_var", "backbone.16.branch7x7x3_2.conv.weight", "backbone.16.branch7x7x3_2.bn.weight", "backbone.16.branch7x7x3_2.bn.bias", "backbone.16.branch7x7x3_2.bn.running_mean", "backbone.16.branch7x7x3_2.bn.running_var", "backbone.16.branch7x7x3_3.conv.weight", "backbone.16.branch7x7x3_3.bn.weight", "backbone.16.branch7x7x3_3.bn.bias", "backbone.16.branch7x7x3_3.bn.running_mean", "backbone.16.branch7x7x3_3.bn.running_var", "backbone.16.branch7x7x3_4.conv.weight", "backbone.16.branch7x7x3_4.bn.weight", "backbone.16.branch7x7x3_4.bn.bias", "backbone.16.branch7x7x3_4.bn.running_mean", "backbone.16.branch7x7x3_4.bn.running_var", "backbone.18.branch1x1.conv.weight", "backbone.18.branch1x1.bn.weight", "backbone.18.branch1x1.bn.bias", "backbone.18.branch1x1.bn.running_mean", "backbone.18.branch1x1.bn.running_var", "backbone.18.branch3x3_1.conv.weight", "backbone.18.branch3x3_1.bn.weight", "backbone.18.branch3x3_1.bn.bias", "backbone.18.branch3x3_1.bn.running_mean", "backbone.18.branch3x3_1.bn.running_var", "backbone.18.branch3x3_2a.conv.weight", "backbone.18.branch3x3_2a.bn.weight", "backbone.18.branch3x3_2a.bn.bias", "backbone.18.branch3x3_2a.bn.running_mean", "backbone.18.branch3x3_2a.bn.running_var", "backbone.18.branch3x3_2b.conv.weight", "backbone.18.branch3x3_2b.bn.weight", "backbone.18.branch3x3_2b.bn.bias", "backbone.18.branch3x3_2b.bn.running_mean", "backbone.18.branch3x3_2b.bn.running_var", "backbone.18.branch3x3dbl_1.conv.weight", "backbone.18.branch3x3dbl_1.bn.weight", "backbone.18.branch3x3dbl_1.bn.bias", "backbone.18.branch3x3dbl_1.bn.running_mean", "backbone.18.branch3x3dbl_1.bn.running_var", "backbone.18.branch3x3dbl_2.conv.weight", "backbone.18.branch3x3dbl_2.bn.weight", "backbone.18.branch3x3dbl_2.bn.bias", "backbone.18.branch3x3dbl_2.bn.running_mean", "backbone.18.branch3x3dbl_2.bn.running_var", "backbone.18.branch3x3dbl_3a.conv.weight", "backbone.18.branch3x3dbl_3a.bn.weight", "backbone.18.branch3x3dbl_3a.bn.bias", "backbone.18.branch3x3dbl_3a.bn.running_mean", "backbone.18.branch3x3dbl_3a.bn.running_var", "backbone.18.branch3x3dbl_3b.conv.weight", "backbone.18.branch3x3dbl_3b.bn.weight", "backbone.18.branch3x3dbl_3b.bn.bias", "backbone.18.branch3x3dbl_3b.bn.running_mean", "backbone.18.branch3x3dbl_3b.bn.running_var", "backbone.18.branch_pool.conv.weight", "backbone.18.branch_pool.bn.weight", "backbone.18.branch_pool.bn.bias", "backbone.18.branch_pool.bn.running_mean", "backbone.18.branch_pool.bn.running_var". 
    Unexpected key(s) in state_dict: "backbone.15.branch3x3_1.conv.weight", "backbone.15.branch3x3_1.bn.weight", "backbone.15.branch3x3_1.bn.bias", "backbone.15.branch3x3_1.bn.running_mean", "backbone.15.branch3x3_1.bn.running_var", "backbone.15.branch3x3_1.bn.num_batches_tracked", "backbone.15.branch3x3_2.conv.weight", "backbone.15.branch3x3_2.bn.weight", "backbone.15.branch3x3_2.bn.bias", "backbone.15.branch3x3_2.bn.running_mean", "backbone.15.branch3x3_2.bn.running_var", "backbone.15.branch3x3_2.bn.num_batches_tracked", "backbone.15.branch7x7x3_1.conv.weight", "backbone.15.branch7x7x3_1.bn.weight", "backbone.15.branch7x7x3_1.bn.bias", "backbone.15.branch7x7x3_1.bn.running_mean", "backbone.15.branch7x7x3_1.bn.running_var", "backbone.15.branch7x7x3_1.bn.num_batches_tracked", "backbone.15.branch7x7x3_2.conv.weight", "backbone.15.branch7x7x3_2.bn.weight", "backbone.15.branch7x7x3_2.bn.bias", "backbone.15.branch7x7x3_2.bn.running_mean", "backbone.15.branch7x7x3_2.bn.running_var", "backbone.15.branch7x7x3_2.bn.num_batches_tracked", "backbone.15.branch7x7x3_3.conv.weight", "backbone.15.branch7x7x3_3.bn.weight", "backbone.15.branch7x7x3_3.bn.bias", "backbone.15.branch7x7x3_3.bn.running_mean", "backbone.15.branch7x7x3_3.bn.running_var", "backbone.15.branch7x7x3_3.bn.num_batches_tracked", "backbone.15.branch7x7x3_4.conv.weight", "backbone.15.branch7x7x3_4.bn.weight", "backbone.15.branch7x7x3_4.bn.bias", "backbone.15.branch7x7x3_4.bn.running_mean", "backbone.15.branch7x7x3_4.bn.running_var", "backbone.15.branch7x7x3_4.bn.num_batches_tracked", "backbone.16.branch1x1.conv.weight", "backbone.16.branch1x1.bn.weight", "backbone.16.branch1x1.bn.bias", "backbone.16.branch1x1.bn.running_mean", "backbone.16.branch1x1.bn.running_var", "backbone.16.branch1x1.bn.num_batches_tracked", "backbone.16.branch3x3_2a.conv.weight", "backbone.16.branch3x3_2a.bn.weight", "backbone.16.branch3x3_2a.bn.bias", "backbone.16.branch3x3_2a.bn.running_mean", "backbone.16.branch3x3_2a.bn.running_var", "backbone.16.branch3x3_2a.bn.num_batches_tracked", "backbone.16.branch3x3_2b.conv.weight", "backbone.16.branch3x3_2b.bn.weight", "backbone.16.branch3x3_2b.bn.bias", "backbone.16.branch3x3_2b.bn.running_mean", "backbone.16.branch3x3_2b.bn.running_var", "backbone.16.branch3x3_2b.bn.num_batches_tracked", "backbone.16.branch3x3dbl_1.conv.weight", "backbone.16.branch3x3dbl_1.bn.weight", "backbone.16.branch3x3dbl_1.bn.bias", "backbone.16.branch3x3dbl_1.bn.running_mean", "backbone.16.branch3x3dbl_1.bn.running_var", "backbone.16.branch3x3dbl_1.bn.num_batches_tracked", "backbone.16.branch3x3dbl_2.conv.weight", "backbone.16.branch3x3dbl_2.bn.weight", "backbone.16.branch3x3dbl_2.bn.bias", "backbone.16.branch3x3dbl_2.bn.running_mean", "backbone.16.branch3x3dbl_2.bn.running_var", "backbone.16.branch3x3dbl_2.bn.num_batches_tracked", "backbone.16.branch3x3dbl_3a.conv.weight", "backbone.16.branch3x3dbl_3a.bn.weight", "backbone.16.branch3x3dbl_3a.bn.bias", "backbone.16.branch3x3dbl_3a.bn.running_mean", "backbone.16.branch3x3dbl_3a.bn.running_var", "backbone.16.branch3x3dbl_3a.bn.num_batches_tracked", "backbone.16.branch3x3dbl_3b.conv.weight", "backbone.16.branch3x3dbl_3b.bn.weight", "backbone.16.branch3x3dbl_3b.bn.bias", "backbone.16.branch3x3dbl_3b.bn.running_mean", "backbone.16.branch3x3dbl_3b.bn.running_var", "backbone.16.branch3x3dbl_3b.bn.num_batches_tracked", "backbone.16.branch_pool.conv.weight", "backbone.16.branch_pool.bn.weight", "backbone.16.branch_pool.bn.bias", "backbone.16.branch_pool.bn.running_mean", "backbone.16.branch_pool.bn.running_var", "backbone.16.branch_pool.bn.num_batches_tracked". 
    size mismatch for backbone.16.branch3x3_1.conv.weight: copying a param with shape torch.Size([384, 1280, 1, 1]) from checkpoint, the shape in current model is torch.Size([192, 768, 1, 1]).
    size mismatch for backbone.16.branch3x3_1.bn.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
    size mismatch for backbone.16.branch3x3_1.bn.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
    size mismatch for backbone.16.branch3x3_1.bn.running_mean: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
    size mismatch for backbone.16.branch3x3_1.bn.running_var: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
    size mismatch for backbone.17.branch1x1.conv.weight: copying a param with shape torch.Size([320, 2048, 1, 1]) from checkpoint, the shape in current model is torch.Size([320, 1280, 1, 1]).
    size mismatch for backbone.17.branch3x3_1.conv.weight: copying a param with shape torch.Size([384, 2048, 1, 1]) from checkpoint, the shape in current model is torch.Size([384, 1280, 1, 1]).
    size mismatch for backbone.17.branch3x3dbl_1.conv.weight: copying a param with shape torch.Size([448, 2048, 1, 1]) from checkpoint, the shape in current model is torch.Size([448, 1280, 1, 1]).
    size mismatch for backbone.17.branch_pool.conv.weight: copying a param with shape torch.Size([192, 2048, 1, 1]) from checkpoint, the shape in current model is torch.Size([192, 1280, 1, 1]).

torch and torchvision versions: ('2.0.0+cu117', '0.15.1+cu117') Doesn't seem like conda environment.yaml or pip requirements.txt files are available Please advise on how to load the weights! 🙏

sRassmann commented 9 months ago

I think the models were obtained with inception_v3(aux_logits=False) (i.e. without the Auxiliary Classifiers).

Using torch v 2.1.1 and torchvision v 0.16.1 running

class Backbone(torch.nn.Module):
    def __init__(self, path):
        super().__init__()
        base_model = inception_v3(pretrained=False, aux_logits=False)
        encoder_layers = list(base_model.children())
        self.backbone = nn.Sequential(*encoder_layers[:-1])

        state_dict = torch.load(path)
        new_state_dict = {}
        for k, v in state_dict.items():
            new_state_dict[k[9:]] = v

        print(self.backbone.load_state_dict(new_state_dict))  # <All keys matched successfully>

    def forward(self, x):
        return self.backbone(x)

backbone = Backbone("path/to/InceptionV3.pt")

Works for me