abhishekkrthakur / bert-sentiment


Getting issue while loading model #1

Closed: vijender412 closed this issue 4 years ago

vijender412 commented 4 years ago

Getting the issue below while loading the model on my local system. The model was trained on Colab.

Traceback (most recent call last):
  File "app.py", line 74, in <module>
    MODEL.load_state_dict(torch.load(config.MODEL_PATH, map_location=torch.device('cpu'))) #New created
  File "C:\Users\Vijender\Downloads\bert_sentiment\lib\site-packages\torch\nn\modules\module.py", line 830, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for BERTBaseUncased:
        Missing key(s) in state_dict: "bert.embeddings.word_embeddings.weight", "bert.embeddings.position_embeddings.weight", "bert.embeddings.token_type_embeddings.weight", "bert.embeddings.LayerNorm.weight", "bert.embeddings.LayerNorm.bias", "bert.encoder.layer.0.attention.self.query.weight", "bert.encoder.layer.0.attention.self.query.bias",
abhishekkrthakur commented 4 years ago

Did you use DataParallel to train the model?

vijender412 commented 4 years ago

@abhishekkrthakur Yes, I used DataParallel while training.

abhishekkrthakur commented 4 years ago

Does your bert_base_path have the bert-base-uncased model files?

vijender412 commented 4 years ago

@abhishekkrthakur Fixed, the issue was with DataParallel. Locally I was not wrapping the model with "MODEL = nn.DataParallel(MODEL)"; it works now with that in place. Can you help me understand the use of DataParallel and whether it is also required after training? Thanks for the quick reply. Good to close the issue.
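For context, a minimal sketch of that workaround (BERTBaseUncased and config.MODEL_PATH are taken from the traceback above; the import paths are assumptions): wrapping the model in nn.DataParallel before load_state_dict makes the expected keys carry the same "module." prefix that a checkpoint saved from a DataParallel-wrapped model uses.

import torch
import torch.nn as nn

import config                      # assumed: repo config module with MODEL_PATH
from model import BERTBaseUncased  # assumed import path for the model class

# Build the model and wrap it exactly as during training, so the wrapper's
# state_dict keys ("module.bert....") match the keys stored in the checkpoint.
MODEL = BERTBaseUncased()
MODEL = nn.DataParallel(MODEL)

MODEL.load_state_dict(torch.load(config.MODEL_PATH, map_location=torch.device("cpu")))
MODEL.eval()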

abhishekkrthakur commented 4 years ago

DataParallel is used only when you have multiple GPUs during training. If you used it in training, you have to use it in inference as well, but there are other ways too (one is sketched below).

Closing this issue for now. :)
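One such alternative (a sketch, not taken from this thread): strip the "module." prefix that DataParallel adds to every state_dict key, then load the cleaned dict into a plain, unwrapped model.

import torch

import config                      # assumed: repo config module with MODEL_PATH
from model import BERTBaseUncased  # assumed import path for the model class

MODEL = BERTBaseUncased()

# Checkpoints saved from a DataParallel-wrapped model store keys such as
# "module.bert.embeddings.word_embeddings.weight"; drop the prefix so the
# keys match the plain (unwrapped) model.
state_dict = torch.load(config.MODEL_PATH, map_location=torch.device("cpu"))
state_dict = {
    (k[len("module."):] if k.startswith("module.") else k): v
    for k, v in state_dict.items()
}

MODEL.load_state_dict(state_dict)
MODEL.eval()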

nipunsadvilkar commented 4 years ago

@abhishekkrthakur: Can you give any leads on how to load a model trained with DataParallel on GPU onto a CPU? As per the PyTorch docs I tried the following, but it still raises the above RuntimeError:

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)  # placeholder model class from the PyTorch docs
model.load_state_dict(torch.load(PATH, map_location=device))
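A likely explanation (not confirmed in this thread): map_location only controls which device the tensors are loaded onto; it does not rename the "module."-prefixed keys that DataParallel writes into the checkpoint, so the plain model still reports the same missing keys. Either of the approaches sketched above, renaming the keys before load_state_dict or wrapping the model in nn.DataParallel first (which also works on a CPU-only machine, where the wrapper only adds the key prefix), can be combined with map_location=torch.device('cpu').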