EricFillion / happy-transformer

Happy Transformer makes it easy to fine-tune and perform inference with NLP Transformer models.
http://happytransformer.com
Apache License 2.0

The state dictionary of the model you are training to load is corrupted. Are you sure it was properly saved? #261

Closed: kenanEkici closed this issue 2 years ago

kenanEkici commented 3 years ago

Hi,

After training BERT with happy_tc = HappyTextClassification(model_type="BERT", model_name="bert-base-uncased", num_labels=2), saving my model with happy_tc.save("model/") and loading it immediately after with happy_tc = HappyTextClassification(load_path="model"), I get the following error:

loading configuration file https://huggingface.co/distilbert-base-uncased/resolve/main/config.json from cache at /root/.cache/huggingface/transformers/23454919702d26495337f3da04d1655c7ee010d5ec9d77bdb9e399e00302c0a1.91b885ab15d631bf9cee9dc9d25ece0afd932f2f5130eba28f2055b2220c0333
Model config DistilBertConfig {
  "activation": "gelu",
  "architectures": [
    "DistilBertForMaskedLM"
  ],
  "attention_dropout": 0.1,
  "dim": 768,
  "dropout": 0.1,
  "hidden_dim": 3072,
  "initializer_range": 0.02,
  "max_position_embeddings": 512,
  "model_type": "distilbert",
  "n_heads": 12,
  "n_layers": 6,
  "pad_token_id": 0,
  "qa_dropout": 0.1,
  "seq_classif_dropout": 0.2,
  "sinusoidal_pos_embds": false,
  "tie_weights_": true,
  "transformers_version": "4.10.2",
  "vocab_size": 30522
}

loading weights file model/pytorch_model.bin

ValueError                                Traceback (most recent call last)
in ()
----> 1 happy_tc = HappyTextClassification(load_path="model")

3 frames
/usr/local/lib/python3.7/dist-packages/transformers/modeling_utils.py in _load_state_dict_into_model(cls, model, state_dict, pretrained_model_name_or_path, ignore_mismatched_sizes, _fast_init)
   1502             if any(key in expected_keys_not_prefixed for key in loaded_keys):
   1503                 raise ValueError(
-> 1504                     "The state dictionary of the model you are training to load is corrupted. Are you sure it was "
   1505                     "properly saved?"
   1506                 )

ValueError: The state dictionary of the model you are training to load is corrupted. Are you sure it was properly saved?

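
The check that raises this error is visible in the traceback above. Below is a simplified, pure-Python sketch of that check, based on transformers around versions 4.10 to 4.15 (the real method does more and may differ between versions). The key names are illustrative and abbreviated. It suggests why loading fails here: the log shows a DistilBERT config (distilbert-base-uncased) being used, so the model built expects parameters under the "distilbert" prefix, while the saved BERT checkpoint stores them under "bert." and also contains "classifier.*" head keys.

```python
from typing import List

# Illustrative, abbreviated parameter names (real models have many more keys).
# A BERT checkpoint saved with a sequence-classification head:
BERT_CHECKPOINT_KEYS = [
    "bert.embeddings.word_embeddings.weight",
    "bert.pooler.dense.weight",
    "classifier.weight",
    "classifier.bias",
]
# Keys a DistilBertForSequenceClassification model expects to find:
DISTILBERT_EXPECTED_KEYS = [
    "distilbert.embeddings.word_embeddings.weight",
    "pre_classifier.weight",
    "pre_classifier.bias",
    "classifier.weight",
    "classifier.bias",
]


def check_state_dict_keys(
    base_model_prefix: str, expected_keys: List[str], loaded_keys: List[str]
) -> None:
    """Simplified sketch of the prefix check in transformers'
    _load_state_dict_into_model (roughly v4.10-4.15); not the real code."""
    has_prefix_module = any(k.startswith(base_model_prefix) for k in loaded_keys)
    expects_prefix_module = any(k.startswith(base_model_prefix) for k in expected_keys)

    # No checkpoint key starts with the model's base prefix, so the checkpoint
    # is treated as a bare base model saved without that prefix.
    if expects_prefix_module and not has_prefix_module:
        # Head-only parameters of the model being built (e.g. "classifier.*").
        expected_keys_not_prefixed = [
            k for k in expected_keys if not k.startswith(base_model_prefix)
        ]
        # A bare base model should not contain head parameters; if it does,
        # the checkpoint is reported as corrupted.
        if any(k in expected_keys_not_prefixed for k in loaded_keys):
            raise ValueError(
                "The state dictionary of the model you are training to load is corrupted. "
                "Are you sure it was properly saved?"
            )


if __name__ == "__main__":
    # Mismatch: DistilBERT model, BERT checkpoint -> ValueError.
    try:
        check_state_dict_keys("distilbert", DISTILBERT_EXPECTED_KEYS, BERT_CHECKPOINT_KEYS)
    except ValueError as err:
        print("mismatch:", err)
    # Matching architecture: BERT model, BERT checkpoint -> no error.
    check_state_dict_keys("bert", BERT_CHECKPOINT_KEYS, BERT_CHECKPOINT_KEYS)
    print("match: no error")
```

In this sketch the checkpoint itself is fine; the error comes from pairing it with a model class whose base prefix it does not use, which is why the message about corruption is misleading in this situation.
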
akshatshah21 commented 2 years ago

Firstly, thanks for this great wrapper!

I'm also getting the same error. Versions:

happytransformer   2.4.0
huggingface-hub    0.2.1
tokenizers         0.10.3
torch              1.10.1+cpu
transformers       4.15.0

I'm simply loading this model from Hugging Face and then saving it without any changes:

classifier = HappyTextClassification("BERT", "Hate-speech-CNERG/dehatebert-mono-english")
classifier.save("model")

This creates the model directory:

model
├── config.json
├── pytorch_model.bin
├── special_tokens_map.json
├── tokenizer_config.json
├── tokenizer.json
└── vocab.txt

And I'm loading it as follows:

classifier = HappyTextClassification(load_path="model")

Which throws this error:

Traceback (most recent call last):
  File "example.py", line 18, in <module>
    classifier = HappyTextClassification(load_path="model")
  File "/home/akshat/Projects/discord-bot/.venv/lib/python3.8/site-packages/happytransformer/happy_text_classification.py", line 34, in __init__
    model = AutoModelForSequenceClassification.from_pretrained(load_path, config=config)
  File "/home/akshat/Projects/discord-bot/.venv/lib/python3.8/site-packages/transformers/models/auto/auto_factory.py", line 441, in from_pretrained
    return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
  File "/home/akshat/Projects/discord-bot/.venv/lib/python3.8/site-packages/transformers/modeling_utils.py", line 1457, in from_pretrained
    model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_state_dict_into_model(
  File "/home/akshat/Projects/discord-bot/.venv/lib/python3.8/site-packages/transformers/modeling_utils.py", line 1602, in _load_state_dict_into_model
    raise ValueError(
ValueError: The state dictionary of the model you are training to load is corrupted. Are you sure it was properly saved?

Any guidance would be helpful!

EricFillion commented 2 years ago

Thanks for pointing this out. Please provide the model type and name to the HappyTextClassification class when loading a model. So, for @akshatshah21's case, you would use the following code:

happy_tc = HappyTextClassification("BERT", "Hate-speech-CNERG/dehatebert-mono-english")

happy_tc.save("model/")

happy_tc = HappyTextClassification("BERT", "Hate-speech-CNERG/dehatebert-mono-english", load_path="model/")

I'll update the documentation shortly to include this info.
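
Before loading, it can help to confirm which architecture the saved directory actually contains. The helper below is hypothetical (it is not part of Happy Transformer or transformers). It assumes save() writes a standard transformers config.json with a "model_type" field, as in the directory listing above, and that Happy Transformer's model_type strings (e.g. "BERT") match that field case-insensitively. On a mismatch it raises a readable error instead of the state-dictionary message.

```python
import json
from pathlib import Path


def check_saved_model_type(load_path: str, happy_model_type: str) -> str:
    """Hypothetical helper: compare the "model_type" recorded in
    <load_path>/config.json with the model_type you plan to pass to
    Happy Transformer. Assumes the two match case-insensitively."""
    config_path = Path(load_path) / "config.json"
    with config_path.open(encoding="utf-8") as f:
        saved_type = json.load(f)["model_type"]
    if saved_type.lower() != happy_model_type.lower():
        raise ValueError(
            f"{config_path} describes a {saved_type!r} model, but model_type "
            f"{happy_model_type!r} was requested; pass the original model type "
            "and name when loading."
        )
    # Return the recorded type so callers can log or reuse it.
    return saved_type
```

For example, check_saved_model_type("model/", "BERT") could be run before HappyTextClassification("BERT", "Hate-speech-CNERG/dehatebert-mono-english", load_path="model/"). The model name still has to be supplied by the caller; the helper only checks the type.
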

akshatshah21 commented 2 years ago

Yes, this works. Thank you!