kuangliu / pytorch-retinanet

RetinaNet in PyTorch

RuntimeError: Error(s) in loading state_dict for RetinaNet: Unexpected key(s) in state_dict #61

Closed sunshine-zkf closed 5 years ago

sunshine-zkf commented 5 years ago

When I run python train.py, I get the following error:

==> Preparing data..
Traceback (most recent call last):
  File "train.py", line 50, in <module>
    net.load_state_dict(torch.load('./model/net.pth'))  # load model parameters
  File "/home/sunshine_zkf/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 721, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for RetinaNet:
	Unexpected key(s) in state_dict: "fpn.toplayer3.weight", "fpn.toplayer3.bias", "fpn.bn1.num_batches_tracked", "fpn.layer1.0.bn1.num_batches_tracked", "fpn.layer1.0.bn2.num_batches_tracked", "fpn.layer1.0.bn3.num_batches_tracked", "fpn.layer1.0.downsample.1.num_batches_tracked", "fpn.layer1.1.bn1.num_batches_tracked", "fpn.layer1.1.bn2.num_batches_tracked", "fpn.layer1.1.bn3.num_batches_tracked", "fpn.layer1.2.bn1.num_batches_tracked", "fpn.layer1.2.bn2.num_batches_tracked", "fpn.layer1.2.bn3.num_batches_tracked", "fpn.layer2.0.bn1.num_batches_tracked", "fpn.layer2.0.bn2.num_batches_tracked", "fpn.layer2.0.bn3.num_batches_tracked", "fpn.layer2.0.downsample.1.num_batches_tracked", "fpn.layer2.1.bn1.num_batches_tracked", "fpn.layer2.1.bn2.num_batches_tracked", "fpn.layer2.1.bn3.num_batches_tracked", "fpn.layer2.2.bn1.num_batches_tracked", "fpn.layer2.2.bn2.num_batches_tracked", "fpn.layer2.2.bn3.num_batches_tracked", "fpn.layer2.3.bn1.num_batches_tracked", "fpn.layer2.3.bn2.num_batches_tracked", "fpn.layer2.3.bn3.num_batches_tracked", "fpn.layer3.0.bn1.num_batches_tracked", "fpn.layer3.0.bn2.num_batches_tracked", "fpn.layer3.0.bn3.num_batches_tracked", "fpn.layer3.0.downsample.1.num_batches_tracked", "fpn.layer3.1.bn1.num_batches_tracked", "fpn.layer3.1.bn2.num_batches_tracked", "fpn.layer3.1.bn3.num_batches_tracked", "fpn.layer3.2.bn1.num_batches_tracked", "fpn.layer3.2.bn2.num_batches_tracked", "fpn.layer3.2.bn3.num_batches_tracked", "fpn.layer3.3.bn1.num_batches_tracked", "fpn.layer3.3.bn2.num_batches_tracked", "fpn.layer3.3.bn3.num_batches_tracked", "fpn.layer3.4.bn1.num_batches_tracked", "fpn.layer3.4.bn2.num_batches_tracked", "fpn.layer3.4.bn3.num_batches_tracked", "fpn.layer3.5.bn1.num_batches_tracked", "fpn.layer3.5.bn2.num_batches_tracked", "fpn.layer3.5.bn3.num_batches_tracked", "fpn.layer4.0.bn1.num_batches_tracked", "fpn.layer4.0.bn2.num_batches_tracked", "fpn.layer4.0.bn3.num_batches_tracked", "fpn.layer4.0.downsample.1.num_batches_tracked", "fpn.layer4.1.bn1.num_batches_tracked", "fpn.layer4.1.bn2.num_batches_tracked", "fpn.layer4.1.bn3.num_batches_tracked", "fpn.layer4.2.bn1.num_batches_tracked", "fpn.layer4.2.bn2.num_batches_tracked", "fpn.layer4.2.bn3.num_batches_tracked".

How can I fix it? Can you give me some advice? Thank you @kuangliu

abdelrahman-0 commented 2 years ago

In my case the issue was caused by incompatible pretrained network weights. You can fix it in either of two ways:

1) Do not use the pretrained network:

    model = torchvision.models.detection.retinanet_resnet50_fpn(pretrained=False)

or

2) Load the network with pretrained=True and keep the default num_classes (91). If you're using a custom dataset (e.g. 10 classes), you can then swap out the classification head for a new one:

    model = torchvision.models.detection.retinanet_resnet50_fpn(pretrained=True)
    in_channels = model.backbone.out_channels
    num_anchors_per_location = model.anchor_generator.num_anchors_per_location()[0]
    model.head.classification_head = torchvision.models.detection.retinanet.RetinaNetClassificationHead(
        in_channels, num_anchors_per_location, num_classes=10)
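For the original error in this thread (a checkpoint with keys the model does not define), it may help to compare the two key sets before loading. Below is a minimal sketch, not the repository's own code: the helper names diff_keys and keep_known are hypothetical, and plain dicts stand in for the real state dicts so the example runs without PyTorch.

```python
from typing import Any, Dict, Iterable, List, Mapping, Tuple


def diff_keys(model_keys: Iterable[str],
              ckpt_keys: Iterable[str]) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected), using the same terms as PyTorch's error."""
    model_set, ckpt_set = set(model_keys), set(ckpt_keys)
    missing = sorted(model_set - ckpt_set)     # model expects, checkpoint lacks
    unexpected = sorted(ckpt_set - model_set)  # checkpoint has, model lacks
    return missing, unexpected


def keep_known(ckpt: Mapping[str, Any],
               model_keys: Iterable[str]) -> Dict[str, Any]:
    """Drop checkpoint entries whose keys the model does not define."""
    allowed = set(model_keys)
    return {k: v for k, v in ckpt.items() if k in allowed}


# Toy stand-ins (assumption): with a real model these would be
# net.state_dict().keys() and torch.load('./model/net.pth').
model_keys = ["fpn.bn1.weight", "fpn.bn1.bias"]
ckpt = {
    "fpn.bn1.weight": 1.0,
    "fpn.bn1.bias": 0.0,
    "fpn.bn1.num_batches_tracked": 0,
    "fpn.toplayer3.weight": 2.0,
}

missing, unexpected = diff_keys(model_keys, ckpt)
filtered = keep_known(ckpt, model_keys)
print("missing:", missing)
print("unexpected:", unexpected)
```

With a real model, the filtered dict could then be passed to net.load_state_dict(filtered), or the unfiltered checkpoint could be loaded with strict=False. Treat either as a workaround rather than a fix. As far as I know, the num_batches_tracked entries are BatchNorm buffers that newer PyTorch versions save and older ones do not expect, so dropping them is usually harmless. The fpn.toplayer3.* entries, however, are trained weights, and their presence suggests the checkpoint came from a different model definition than the one being loaded.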