Cadene / pretrained-models.pytorch

Pretrained ConvNets for pytorch: NASNet, ResNeXt, ResNet, InceptionV4, InceptionResnetV2, Xception, DPN, etc.
BSD 3-Clause "New" or "Revised" License

Error: 'DenseNet' object has no attribute 'logits' #113

Open iWangLin opened 5 years ago

iWangLin commented 5 years ago

I used `pretrainedmodels.densenet121` to train on my own dataset, but when I load the model I trained, I get the error `'DenseNet' object has no attribute 'logits'`. I want to know why.

```python
import torch
import pretrainedmodels

model_daibao = torch.load('densenet121_daibao_model_1')
```

```
AttributeError                            Traceback (most recent call last)
in ()
      1 import pretrainedmodels
----> 2 model_daibao = torch.load('densenet121_daibao_model_1')
      3 model_kouzhao = torch.load('densenet121_kouzhao_model_2')
      4 model_yanjing = torch.load('densenet121_yanjing_model_2')
      5 model_maozi = torch.load('densenet_maozi_model_1')

/usr/local/lib/python3.6/dist-packages/torch/serialization.py in load(f, map_location, pickle_module)
    365     deserialized_objects = {}
    366 
--> 367     if map_location is None:
    368         restore_location = default_restore_location
    369     elif isinstance(map_location, dict):

/usr/local/lib/python3.6/dist-packages/torch/serialization.py in _load(f, map_location, pickle_module)
    536     if protocol_version != PROTOCOL_VERSION:
    537         raise RuntimeError("Invalid protocol version: %s" % protocol_version)
--> 538 
    539     _sys_info = pickle_module.load(f)
    540     unpickler = pickle_module.Unpickler(f)

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in __getattr__(self, name)
    533         elif params is not None and name in params:
    534             if value is not None:
--> 535                 raise TypeError("cannot assign '{}' as parameter '{}' "
    536                                 "(torch.nn.Parameter or None expected)"
    537                                 .format(torch.typename(value), name))

AttributeError: 'DenseNet' object has no attribute 'logits'
```
Marco2018 commented 5 years ago

I have the same problem with ResNet, and I don't know how to solve it either.

Cadene commented 5 years ago

Hi,

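You should save your model's state dictionary (`model.state_dict()`) instead of the whole model object, then load it back into a freshly built model with `load_state_dict()`: https://pytorch.org/tutorials/beginner/saving_loading_models.html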
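Saving the `state_dict` rather than pickling the whole module can be sketched as follows. This is a minimal sketch, not code from this repository: `save_weights`, `load_weights`, and `build_toy_model` are hypothetical helpers, and the small `nn.Sequential` stands in for the fine-tuned network. An in-memory buffer replaces a file path so the round trip is self-contained.

```python
import io

import torch
import torch.nn as nn


def save_weights(model: nn.Module, f) -> None:
    """Serialize only the parameters and buffers (the state_dict), not the module object."""
    torch.save(model.state_dict(), f)


def load_weights(build_model, f) -> nn.Module:
    """Rebuild the architecture with `build_model`, then copy the saved tensors into it."""
    model = build_model()
    state_dict = torch.load(f, map_location="cpu")
    model.load_state_dict(state_dict)  # strict by default: keys and shapes must match
    model.eval()  # switch dropout/batch-norm layers to inference behavior
    return model


def build_toy_model() -> nn.Module:
    # Hypothetical stand-in for the real network (e.g. a fine-tuned densenet121).
    return nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))


if __name__ == "__main__":
    trained = build_toy_model()
    buffer = io.BytesIO()  # a file path such as "model.pth" works the same way
    save_weights(trained, buffer)
    buffer.seek(0)  # rewind before reading the buffer back
    restored = load_weights(build_toy_model, buffer)
    x = torch.randn(3, 8)
    trained.eval()
    print(torch.allclose(trained(x), restored(x)))  # True
```

For the model in this issue, `build_model` would have to recreate exactly the architecture that was trained, for example by calling `pretrainedmodels.densenet121(num_classes=1000, pretrained=None)` and then replacing `last_linear` with the same `nn.Linear` head used during fine-tuning (an assumption about how the model was adapted). Because the checkpoint then holds only tensors, loading it does not depend on unpickling the `DenseNet` object itself.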