warmspringwinds / pytorch-segmentation-detection

Image Segmentation and Object Detection in Pytorch
MIT License

RuntimeError: Error(s) in loading_state_dict for VGG #22

Open ar13pit opened 5 years ago

ar13pit commented 5 years ago

Using your fork of torchvision with the default PyTorch installation for Linux (Python 3.6, CUDA 10), I ran into two problems:

  1. The init_weights argument was missing from the VGG class.

  2. After fixing (1), loading the pretrained weights raised the following error:

In [1]: from torchvision import models
In [2]: model = models.vgg16(pretrained=True, fully_conv=True)

RuntimeError                              Traceback (most recent call last)
<ipython-input-2-802ee77a237c> in <module>
----> 1 model = models.vgg16(pretrained=True, fully_conv=True)

~/repositories/github/pytorch-segmentation-detection/vision/torchvision/models/vgg.py in vgg16(pretrained, **kwargs)
    164     model = VGG(make_layers(cfg['D']), **kwargs)
    165     if pretrained:
--> 166         model.load_state_dict(model_zoo.load_url(model_urls['vgg16']))
    167     return model
    168 

~/.local/lib/python3.6/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
    767         if len(error_msgs) > 0:
    768             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
--> 769                                self.__class__.__name__, "\n\t".join(error_msgs)))
    770 
    771     def _named_members(self, get_members_fn, prefix='', recurse=True):

RuntimeError: Error(s) in loading state_dict for VGG:
    size mismatch for classifier.0.weight: copying a param with shape torch.Size([4096, 25088]) from checkpoint, the shape in current model is torch.Size([4096, 512, 7, 7]).
    size mismatch for classifier.3.weight: copying a param with shape torch.Size([4096, 4096]) from checkpoint, the shape in current model is torch.Size([4096, 4096, 1, 1]).
    size mismatch for classifier.6.weight: copying a param with shape torch.Size([1000, 4096]) from checkpoint, the shape in current model is torch.Size([1000, 4096, 1, 1]).
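
Each mismatch in the error above has the same number of elements on both sides (for example, 4096 x 25088 = 4096 x 512 x 7 x 7), which suggests the checkpoint's fully connected weights might only need reshaping to fit the fully convolutional layers. Here is a minimal sketch that checks this using only the shapes reported in the error. Note that plan_reshapes is a hypothetical helper written for illustration, not part of torchvision or this repository.

```python
from math import prod

# Shapes copied from the error message above.
checkpoint_shapes = {
    "classifier.0.weight": (4096, 25088),
    "classifier.3.weight": (4096, 4096),
    "classifier.6.weight": (1000, 4096),
}
model_shapes = {
    "classifier.0.weight": (4096, 512, 7, 7),
    "classifier.3.weight": (4096, 4096, 1, 1),
    "classifier.6.weight": (1000, 4096, 1, 1),
}


def plan_reshapes(checkpoint, model):
    """Hypothetical helper: map each mismatched key to the model's shape.

    A key is included only when the element counts agree, i.e. when a
    plain reshape (no new or dropped values) would be enough.
    """
    plan = {}
    for name, src in checkpoint.items():
        dst = model.get(name)
        if dst is None or src == dst:
            continue  # nothing to do for missing or already-matching keys
        if prod(src) != prod(dst):
            raise ValueError(f"{name}: {src} cannot be reshaped to {dst}")
        plan[name] = dst
    return plan


plan = plan_reshapes(checkpoint_shapes, model_shapes)
for name, shape in plan.items():
    print(name, "->", shape)
```

If the plan covers every mismatched key, one possible workaround (an assumption, not a confirmed fix from the maintainer) is to reshape those entries of the downloaded state dict before calling load_state_dict, for example with state_dict[name] = state_dict[name].view(shape) for each planned entry.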