facebookresearch / simsiam

PyTorch implementation of SimSiam: https://arxiv.org/abs/2011.10566

AssertionError: assert set(msg.missing_keys) == {"fc.weight", "fc.bias"} #8

Closed · etetteh closed this issue 3 years ago

etetteh commented 3 years ago

My pretraining phase went well, but now when I try loading the checkpoint to train the classifier it breaks. The following is what I am doing, which is based on the code in this repo:

import torch
import torchvision

model = torchvision.models.__dict__['resnet18']()
for name, param in model.named_parameters():
    if name not in ['fc.weight', 'fc.bias']:
        param.requires_grad = False
# init the fc layer
model.fc.weight.data.normal_(mean=0.0, std=0.01)
model.fc.bias.data.zero_()
### Load checkpoint
checkpoint = torch.load('./simsiam_malaria/resnet18_checkpoint.pth.tar', map_location="cpu")
state_dict = checkpoint['state_dict']  # weights are saved under the 'state_dict' key
for k in list(state_dict.keys()):
    print(k)
encoder.bn1.weight
encoder.bn1.bias
encoder.bn1.running_mean
encoder.bn1.running_var
encoder.bn1.num_batches_tracked
encoder.layer1.0.conv1.weight
encoder.layer1.0.bn1.weight
encoder.layer1.0.bn1.bias
encoder.layer1.0.bn1.running_mean
encoder.layer1.0.bn1.running_var
encoder.layer1.0.bn1.num_batches_tracked
encoder.layer1.0.conv2.weight
encoder.layer1.0.bn2.weight
encoder.layer1.0.bn2.bias
encoder.layer1.0.bn2.running_mean
encoder.layer1.0.bn2.running_var
encoder.layer1.0.bn2.num_batches_tracked
encoder.layer1.1.conv1.weight
encoder.layer1.1.bn1.weight
encoder.layer1.1.bn1.bias
encoder.layer1.1.bn1.running_mean
encoder.layer1.1.bn1.running_var
encoder.layer1.1.bn1.num_batches_tracked
encoder.layer1.1.conv2.weight
encoder.layer1.1.bn2.weight
encoder.layer1.1.bn2.bias
encoder.layer1.1.bn2.running_mean
encoder.layer1.1.bn2.running_var
encoder.layer1.1.bn2.num_batches_tracked
encoder.layer2.0.conv1.weight
encoder.layer2.0.bn1.weight
encoder.layer2.0.bn1.bias
encoder.layer2.0.bn1.running_mean
encoder.layer2.0.bn1.running_var
encoder.layer2.0.bn1.num_batches_tracked
encoder.layer2.0.conv2.weight
encoder.layer2.0.bn2.weight
encoder.layer2.0.bn2.bias
encoder.layer2.0.bn2.running_mean
encoder.layer2.0.bn2.running_var
encoder.layer2.0.bn2.num_batches_tracked
encoder.layer2.0.downsample.0.weight
encoder.layer2.0.downsample.1.weight
encoder.layer2.0.downsample.1.bias
encoder.layer2.0.downsample.1.running_mean
encoder.layer2.0.downsample.1.running_var
encoder.layer2.0.downsample.1.num_batches_tracked
encoder.layer2.1.conv1.weight
encoder.layer2.1.bn1.weight
encoder.layer2.1.bn1.bias
encoder.layer2.1.bn1.running_mean
encoder.layer2.1.bn1.running_var
encoder.layer2.1.bn1.num_batches_tracked
encoder.layer2.1.conv2.weight
encoder.layer2.1.bn2.weight
encoder.layer2.1.bn2.bias
encoder.layer2.1.bn2.running_mean
encoder.layer2.1.bn2.running_var
encoder.layer2.1.bn2.num_batches_tracked
encoder.layer3.0.conv1.weight
encoder.layer3.0.bn1.weight
encoder.layer3.0.bn1.bias
encoder.layer3.0.bn1.running_mean
encoder.layer3.0.bn1.running_var
encoder.layer3.0.bn1.num_batches_tracked
encoder.layer3.0.conv2.weight
encoder.layer3.0.bn2.weight
encoder.layer3.0.bn2.bias
encoder.layer3.0.bn2.running_mean
encoder.layer3.0.bn2.running_var
encoder.layer3.0.bn2.num_batches_tracked
encoder.layer3.0.downsample.0.weight
encoder.layer3.0.downsample.1.weight
encoder.layer3.0.downsample.1.bias
encoder.layer3.0.downsample.1.running_mean
encoder.layer3.0.downsample.1.running_var
encoder.layer3.0.downsample.1.num_batches_tracked
encoder.layer3.1.conv1.weight
encoder.layer3.1.bn1.weight
encoder.layer3.1.bn1.bias
encoder.layer3.1.bn1.running_mean
encoder.layer3.1.bn1.running_var
encoder.layer3.1.bn1.num_batches_tracked
encoder.layer3.1.conv2.weight
encoder.layer3.1.bn2.weight
encoder.layer3.1.bn2.bias
encoder.layer3.1.bn2.running_mean
encoder.layer3.1.bn2.running_var
encoder.layer3.1.bn2.num_batches_tracked
encoder.layer4.0.conv1.weight
encoder.layer4.0.bn1.weight
encoder.layer4.0.bn1.bias
encoder.layer4.0.bn1.running_mean
encoder.layer4.0.bn1.running_var
encoder.layer4.0.bn1.num_batches_tracked
encoder.layer4.0.conv2.weight
encoder.layer4.0.bn2.weight
encoder.layer4.0.bn2.bias
encoder.layer4.0.bn2.running_mean
encoder.layer4.0.bn2.running_var
encoder.layer4.0.bn2.num_batches_tracked
encoder.layer4.0.downsample.0.weight
encoder.layer4.0.downsample.1.weight
encoder.layer4.0.downsample.1.bias
encoder.layer4.0.downsample.1.running_mean
encoder.layer4.0.downsample.1.running_var
encoder.layer4.0.downsample.1.num_batches_tracked
encoder.layer4.1.conv1.weight
encoder.layer4.1.bn1.weight
encoder.layer4.1.bn1.bias
encoder.layer4.1.bn1.running_mean
encoder.layer4.1.bn1.running_var
encoder.layer4.1.bn1.num_batches_tracked
encoder.layer4.1.conv2.weight
encoder.layer4.1.bn2.weight
encoder.layer4.1.bn2.bias
encoder.layer4.1.bn2.running_mean
encoder.layer4.1.bn2.running_var
encoder.layer4.1.bn2.num_batches_tracked
encoder.fc.0.weight
encoder.fc.1.weight
encoder.fc.1.bias
encoder.fc.1.running_mean
encoder.fc.1.running_var
encoder.fc.1.num_batches_tracked
encoder.fc.3.weight
encoder.fc.4.weight
encoder.fc.4.bias
encoder.fc.4.running_mean
encoder.fc.4.running_var
encoder.fc.4.num_batches_tracked
encoder.fc.6.weight
encoder.fc.6.bias
encoder.fc.7.running_mean
encoder.fc.7.running_var
encoder.fc.7.num_batches_tracked
predictor.0.weight
predictor.1.weight
predictor.1.bias
predictor.1.running_mean
predictor.1.running_var
predictor.1.num_batches_tracked
predictor.3.weight
predictor.3.bias

Now, when I run the following:

for k in list(state_dict.keys()):
    # retain only encoder up to before the embedding layer
    if k.startswith('module.encoder') and not k.startswith('module.encoder.fc'):
        # remove prefix
        state_dict[k[len("module.encoder."):]] = state_dict[k]
    # delete renamed or unused k
    del state_dict[k]
msg = model.load_state_dict(state_dict, strict=False)
assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}

I get the output:

AssertionError                            Traceback (most recent call last)
<ipython-input-7-3429f0c9d366> in <module>
      8     del state_dict[k]
      9 msg = model.load_state_dict(state_dict, strict=False)
---> 10 assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}

AssertionError: 

The del state_dict[k] line ends up deleting every single key.

endernewton commented 3 years ago

When loading the model, did you specify resnet-18 as well? I have not run resnet-18 myself, but if you specify the architecture during pre-training, it has to be specified during linear eval as well.

etetteh commented 3 years ago

Okay. I have a separate script for the evaluation, in which I specified the architecture and loaded the pretrained weights. However, the renaming of the pretrained keys is what causes the issue.


dikapiliao commented 2 years ago

Hello, how did you solve this problem? I have run into the same issue.