billhhh / MnasNet-pytorch-pretrained

A pytorch pretrained model of MnasNet

Error when loading state_dict #1

Closed chakkritte closed 5 years ago

chakkritte commented 5 years ago

Hello @billhhh

I used your code and pre-trained model, but loading the state_dict fails with this error:

RuntimeError: Error(s) in loading state_dict for MnasNet:
    Missing key(s) in state_dict: "features.0.0.weight", "features.0.1.weight", "features.0.1.running_mean", "features.0.1.bias", "features.0.1.running_var", "features.1.0.weight", "features.1.1.weight", "features.1.1.running_mean", "features.1.1.bias", "features.1.1.running_var", "features.1.3.weight", "features.1.4.weight", "features.1.4.running_mean", "features.1.4.bias", "features.1.4.running_var", "features.2.conv.0.weight", "features.2.conv.1.weight", "features.2.conv.1.running_mean", "features.2.conv.1.bias", "features.2.conv.1.running_var", "features.2.conv.3.weight", "features.2.conv.4.weight", "features.2.conv.4.running_mean", "features.2.conv.4.bias", "features.2.conv.4.running_var", "features.2.conv.6.weight", "features.2.conv.7.weight", "features.2.conv.7.running_mean", "features.2.conv.7.bias", "features.2.conv.7.running_var", "features.3.conv.0.weight", "features.3.conv.1.weight", "features.3.conv.1.running_mean", "features.3.conv.1.bias", "features.3.conv.1.running_var", "features.3.conv.3.weight", "features.3.conv.4.weight", "features.3.conv.4.running_mean", "features.3.conv.4.bias", "features.3.conv.4.running_var", "features.3.conv.6.weight", "features.3.conv.7.weight", "features.3.conv.7.running_mean", "features.3.conv.7.bias", "features.3.conv.7.running_var", "features.4.conv.0.weight", "features.4.conv.1.weight", "features.4.conv.1.running_mean", "features.4.conv.1.bias", "features.4.conv.1.running_var", "features.4.conv.3.weight", "features.4.conv.4.weight", "features.4.conv.4.running_mean", "features.4.conv.4.bias", "features.4.conv.4.running_var", "features.4.conv.6.weight", "features.4.conv.7.weight", "features.4.conv.7.running_mean", "features.4.conv.7.bias", "features.4.conv.7.running_var", "features.5.conv.0.weight", "features.5.conv.1.weight", "features.5.conv.1.running_mean", "features.5.conv.1.bias", "features.5.conv.1.running_var", "features.5.conv.3.weight", "features.5.conv.4.weight", "features.5.conv.4.running_mean", 
"features.5.conv.4.bias", "features.5.conv.4.running_var", "features.5.conv.6.weight", "features.5.conv.7.weight", "features.5.conv.7.running_mean", "features.5.conv.7.bias", "features.5.conv.7.running_var", "features.6.conv.0.weight", "features.6.conv.1.weight", "features.6.conv.1.running_mean", "features.6.conv.1.bias", "features.6.conv.1.running_var", "features.6.conv.3.weight", "features.6.conv.4.weight", "features.6.conv.4.running_mean", "features.6.conv.4.bias", "features.6.conv.4.running_var", "features.6.conv.6.weight", "features.6.conv.7.weight", "features.6.conv.7.running_mean", "features.6.conv.7.bias", "features.6.conv.7.running_var", "features.7.conv.0.weight", "features.7.conv.1.weight", "features.7.conv.1.running_mean", "features.7.conv.1.bias", "features.7.conv.1.running_var", "features.7.conv.3.weight", "features.7.conv.4.weight", "features.7.conv.4.running_mean", "features.7.conv.4.bias", "features.7.conv.4.running_var", "features.7.conv.6.weight", "features.7.conv.7.weight", "features.7.conv.7.running_mean", "features.7.conv.7.bias", "features.7.conv.7.running_var", "features.8.conv.0.weight", "features.8.conv.1.weight", "features.8.conv.1.running_mean", "features.8.conv.1.bias", "features.8.conv.1.running_var", "features.8.conv.3.weight", "features.8.conv.4.weight", "features.8.conv.4.running_mean", "features.8.conv.4.bias", "features.8.conv.4.running_var", "features.8.conv.6.weight", "features.8.conv.7.weight", "features.8.conv.7.running_mean", "features.8.conv.7.bias", "features.8.conv.7.running_var", "features.9.conv.0.weight", "features.9.conv.1.weight", "features.9.conv.1.running_mean", "features.9.conv.1.bias", "features.9.conv.1.running_var", "features.9.conv.3.weight", "features.9.conv.4.weight", "features.9.conv.4.running_mean", "features.9.conv.4.bias", "features.9.conv.4.running_var", "features.9.conv.6.weight", "features.9.conv.7.weight", "features.9.conv.7.running_mean", "features.9.conv.7.bias", "features.9.conv.7.running_var", 
"features.10.conv.0.weight", "features.10.conv.1.weight", "features.10.conv.1.running_mean", "features.10.conv.1.bias", "features.10.conv.1.running_var", "features.10.conv.3.weight", "features.10.conv.4.weight", "features.10.conv.4.running_mean", "features.10.conv.4.bias", "features.10.conv.4.running_var", "features.10.conv.6.weight", "features.10.conv.7.weight", "features.10.conv.7.running_mean", "features.10.conv.7.bias", "features.10.conv.7.running_var", "features.11.conv.0.weight", "features.11.conv.1.weight", "features.11.conv.1.running_mean", "features.11.conv.1.bias", "features.11.conv.1.running_var", "features.11.conv.3.weight", "features.11.conv.4.weight", "features.11.conv.4.running_mean", "features.11.conv.4.bias", "features.11.conv.4.running_var", "features.11.conv.6.weight", "features.11.conv.7.weight", "features.11.conv.7.running_mean", "features.11.conv.7.bias", "features.11.conv.7.running_var", "features.12.conv.0.weight", "features.12.conv.1.weight", "features.12.conv.1.running_mean", "features.12.conv.1.bias", "features.12.conv.1.running_var", "features.12.conv.3.weight", "features.12.conv.4.weight", "features.12.conv.4.running_mean", "features.12.conv.4.bias", "features.12.conv.4.running_var", "features.12.conv.6.weight", "features.12.conv.7.weight", "features.12.conv.7.running_mean", "features.12.conv.7.bias", "features.12.conv.7.running_var", "features.13.conv.0.weight", "features.13.conv.1.weight", "features.13.conv.1.running_mean", "features.13.conv.1.bias", "features.13.conv.1.running_var", "features.13.conv.3.weight", "features.13.conv.4.weight", "features.13.conv.4.running_mean", "features.13.conv.4.bias", "features.13.conv.4.running_var", "features.13.conv.6.weight", "features.13.conv.7.weight", "features.13.conv.7.running_mean", "features.13.conv.7.bias", "features.13.conv.7.running_var", "features.14.conv.0.weight", "features.14.conv.1.weight", "features.14.conv.1.running_mean", "features.14.conv.1.bias", 
"features.14.conv.1.running_var", "features.14.conv.3.weight", "features.14.conv.4.weight", "features.14.conv.4.running_mean", "features.14.conv.4.bias", "features.14.conv.4.running_var", "features.14.conv.6.weight", "features.14.conv.7.weight", "features.14.conv.7.running_mean", "features.14.conv.7.bias", "features.14.conv.7.running_var", "features.15.conv.0.weight", "features.15.conv.1.weight", "features.15.conv.1.running_mean", "features.15.conv.1.bias", "features.15.conv.1.running_var", "features.15.conv.3.weight", "features.15.conv.4.weight", "features.15.conv.4.running_mean", "features.15.conv.4.bias", "features.15.conv.4.running_var", "features.15.conv.6.weight", "features.15.conv.7.weight", "features.15.conv.7.running_mean", "features.15.conv.7.bias", "features.15.conv.7.running_var", "features.16.conv.0.weight", "features.16.conv.1.weight", "features.16.conv.1.running_mean", "features.16.conv.1.bias", "features.16.conv.1.running_var", "features.16.conv.3.weight", "features.16.conv.4.weight", "features.16.conv.4.running_mean", "features.16.conv.4.bias", "features.16.conv.4.running_var", "features.16.conv.6.weight", "features.16.conv.7.weight", "features.16.conv.7.running_mean", "features.16.conv.7.bias", "features.16.conv.7.running_var", "features.17.conv.0.weight", "features.17.conv.1.weight", "features.17.conv.1.running_mean", "features.17.conv.1.bias", "features.17.conv.1.running_var", "features.17.conv.3.weight", "features.17.conv.4.weight", "features.17.conv.4.running_mean", "features.17.conv.4.bias", "features.17.conv.4.running_var", "features.17.conv.6.weight", "features.17.conv.7.weight", "features.17.conv.7.running_mean", "features.17.conv.7.bias", "features.17.conv.7.running_var", "features.18.0.weight", "features.18.1.weight", "features.18.1.running_mean", "features.18.1.bias", "features.18.1.running_var", "classifier.1.weight", "classifier.1.bias". 

Could you please share the full source code you used for training?

thank you

huxianer commented 5 years ago

@billhhh @chakkritte Yes, I have run into the same problem, and I think your code is incomplete. Could you share the complete code? Thanks!

billhhh commented 5 years ago

@chakkritte @huxianer The source code is complete; it is the code I used for training. Check which model you are using: if it takes 299 input, please change input_size in train.py from 224 to 299.

I think I have run into this problem before; changing the pretrained model name might help.

huxianer commented 5 years ago

@billhhh I am just using my own training dataset. When I run the script, the loss does not change, and the rest of the output is incorrect.

billhhh commented 5 years ago

It is because of DataParallel.

chakkritte commented 5 years ago

Hi

@billhhh the weight names in your pre-trained model look like module.features.0.0.weight, but in your source code they are features.0.0.weight.
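A quick way to confirm this kind of mismatch is to compare the two sets of key names. The sketch below uses plain string lists as stand-ins for `model.state_dict().keys()` and the keys of the loaded checkpoint. The helper name `diff_keys` and the sample keys are illustrative assumptions, not part of the repository.

```python
def diff_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, in the spirit of load_state_dict's report."""
    model_set = set(model_keys)
    ckpt_set = set(checkpoint_keys)
    missing = sorted(model_set - ckpt_set)      # expected by the model, absent from the file
    unexpected = sorted(ckpt_set - model_set)   # present in the file, unknown to the model
    return missing, unexpected


# Stand-ins for the real key lists (assumed sample values)
model_keys = ["features.0.0.weight", "classifier.1.bias"]
checkpoint_keys = ["module.features.0.0.weight", "module.classifier.1.bias"]

missing, unexpected = diff_keys(model_keys, checkpoint_keys)
print("missing:", missing)
print("unexpected:", unexpected)
```

If every unexpected key is a missing key with `module.` in front, the checkpoint was most likely saved from a model wrapped in DataParallel.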

I have already fixed this problem using the approach from https://discuss.pytorch.org/t/solved-keyerror-unexpected-key-module-encoder-embedding-weight-in-state-dict/1686/4.

@huxianer I think this solution will help you.

Put the code below before load_state_dict:

from collections import OrderedDict

import torch

# original file, saved from a model wrapped in DataParallel
state_dict = torch.load('myfile.pth.tar')
# create a new OrderedDict whose keys do not contain the `module.` prefix
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    # strip the leading `module.` (7 characters) only when it is present
    name = k[7:] if k.startswith('module.') else k
    new_state_dict[name] = v
# load params
model.load_state_dict(new_state_dict)
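
The same idea can be wrapped in a small helper and checked on a plain dict before touching a real checkpoint. This is a minimal sketch: `strip_prefix` is a hypothetical name, and the integer values stand in for tensors.

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with `prefix` removed from keys that start with it."""
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        # Keys without the prefix are copied unchanged
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        cleaned[new_key] = value
    return cleaned


# Assumed sample data: one prefixed key, one unprefixed key
saved = OrderedDict([("module.features.0.0.weight", 1), ("classifier.1.bias", 2)])
cleaned = strip_prefix(saved)
print(list(cleaned.keys()))
```

With a real checkpoint, the result would be passed to `model.load_state_dict(...)` in place of the original dict. Another option is to save `model.module.state_dict()` at training time, which avoids the prefix in the first place.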