huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

['encoder.version', 'decoder.version'] are unexpected when loading a pretrained BART model #6652

Closed: stas00 closed this 4 years ago

stas00 commented 4 years ago

Using the example from the BART docs: https://huggingface.co/transformers/model_doc/bart.html#bartforconditionalgeneration

from transformers import BartTokenizer, BartForConditionalGeneration
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
TXT = "My friends are <mask> but they eat too many carbs."

model = BartForConditionalGeneration.from_pretrained('facebook/bart-large')
input_ids = tokenizer([TXT], return_tensors='pt')['input_ids']
logits = model(input_ids)[0]

masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
probs = logits[0, masked_index].softmax(dim=0)
values, predictions = probs.topk(5)

print(tokenizer.decode(predictions).split())

gives:

Some weights of the model checkpoint at facebook/bart-large were not used 
when initializing BartForConditionalGeneration: 
['encoder.version', 'decoder.version']

- This IS expected if you are initializing BartForConditionalGeneration from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BartForConditionalGeneration from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
test:9: UserWarning: This overload of nonzero is deprecated:
        nonzero()
Consider using one of the following signatures instead:
        nonzero(*, bool as_tuple) (Triggered internally at  /opt/conda/conda-bld/pytorch_1597302504919/work/torch/csrc/utils/python_arg_parser.cpp:864.)
  masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
['good', 'great', 'all', 'really', 'very']

Well, there is one more issue: the example uses a deprecated nonzero() overload. Since PyTorch 1.5 there is a strange, undocumented requirement to pass the as_tuple argument: https://github.com/pytorch/pytorch/issues/43425
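A minimal sketch of the non-deprecated call, assuming a toy `input_ids` tensor and a made-up mask id in place of the real tokenizer output. Passing `as_tuple` explicitly selects one of the two supported overloads listed in the warning above, which should avoid the UserWarning. With `as_tuple=True` you get one index tensor per dimension; with `as_tuple=False` you get a 2-D tensor of indices.

```python
import torch

# Illustrative values only: a 1-row batch of token ids and a made-up mask id
# standing in for tokenizer.mask_token_id.
MASK_ID = 99
input_ids = torch.tensor([[0, 11, 12, MASK_ID, 13, 2]])

is_mask = input_ids[0] == MASK_ID  # 1-D boolean tensor

# as_tuple=True returns a tuple with one index tensor per dimension;
# the input is 1-D, so take element [0] and convert the single index to int.
masked_index = is_mask.nonzero(as_tuple=True)[0].item()

# as_tuple=False returns a (num_matches, 1) tensor; .item() works on a single element.
masked_index_alt = is_mask.nonzero(as_tuple=False).item()

print(masked_index, masked_index_alt)
```

Both forms raise on `.item()` if the text contains zero or several `<mask>` tokens, which is the same behaviour as the original one-liner.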

We have `authorized_missing_keys = [r"final_logits_bias", r"encoder\.version", r"decoder\.version"]` (https://github.com/huggingface/transformers/blob/master/src/transformers/modeling_bart.py#L942), which correctly filters missing_keys. Should there also be an authorized_unexpected_keys that would similarly clean up unexpected_keys?
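A minimal sketch of what the proposed filtering could look like. `filter_authorized` and `authorized_unexpected_keys` are hypothetical names, and this assumes the same regex-search style of matching that the existing authorized_missing_keys patterns are written for. Because `re.search` matches anywhere in the key, the same two patterns would also cover the `model.encoder.version` / `model.decoder.version` variant from `facebook/bart-large-mnli`.

```python
import re

# Hypothetical helper: keep only the keys that match none of the regex patterns.
def filter_authorized(keys, patterns):
    return [k for k in keys if not any(re.search(p, k) for p in patterns)]

# Hypothetical attribute, modelled on authorized_missing_keys.
authorized_unexpected_keys = [r"encoder\.version", r"decoder\.version"]

unexpected_keys = [
    "encoder.version",
    "decoder.version",
    "model.encoder.version",   # the prefixed variant from the MNLI checkpoint
    "model.decoder.version",
    "some_new_layer.weight",   # made-up key that should still be reported
]

remaining = filter_authorized(unexpected_keys, authorized_unexpected_keys)
print(remaining)  # ['some_new_layer.weight']
```

Any genuinely unexpected key still survives the filter, so the warning would only be silenced for the known-harmless entries.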

(Note: I re-edited this issue once I understood it better, to save readers' time; the edit history is there if someone needs it.)

I also found another variant of it, for ['model.encoder.version', 'model.decoder.version']:

tests/test_modeling_bart.py::BartModelIntegrationTests::test_mnli_inference Some weights of the model checkpoint at facebook/bart-large-mnli were not used when initializing BartForSequenceClassification: ['model.encoder.version', 'model.decoder.version']
- This IS expected if you are initializing BartForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BartForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
PASSED

sshleifer commented 4 years ago

Yeah, I think the clean solution is authorized_extra_keys, but I could also just reconvert the models. We could also leave the warning. What do you think, @sgugger?

stas00 commented 4 years ago

IMHO, that warning makes the library look somewhat amateurish, as it makes the user wonder whether something is wrong, for absolutely no reason.

Since I'm the one who is bothered by it, if I can help resolve this, please don't hesitate to delegate it to me.

sgugger commented 4 years ago

The cleanest would be to reconvert the models and remove the keys we don't need, I think. Adding authorized_extra_keys works too, but using it too much could have unexpected consequences that lead to bugs, so I'd only go down that road if there is no other option.

LysandreJik commented 4 years ago

The simplest and cleanest way would probably be to remove these two entries from the state dict, wouldn't it? If you reconvert the checkpoint, you should check that it is exactly the same as the previous one, which sounds like more of a pain and more error-prone than simply doing:

!wget https://cdn.huggingface.co/facebook/bart-large/pytorch_model.bin

import torch

# load the downloaded state dict on CPU, drop the two unused keys, save a cleaned copy
weights = torch.load('pytorch_model.bin', map_location='cpu')
del weights['encoder.version']
del weights['decoder.version']
torch.save(weights, 'new_pytorch_model.bin')
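To address the "check that it is exactly the same" concern for this delete-the-keys route, here is a minimal sketch on a toy state dict. The key names follow the warnings above, but the tensor values are made up, and `strip_version_keys` is a hypothetical helper, not a transformers function. It removes any key ending in `encoder.version` or `decoder.version` (covering both the plain and the `model.`-prefixed variants), saves the result, reloads it, and checks that every remaining tensor is bit-identical to the original.

```python
import os
import tempfile

import torch

# Hypothetical helper: delete the *.version entries in place and return their names.
def strip_version_keys(state_dict, suffixes=("encoder.version", "decoder.version")):
    removed = [k for k in state_dict if k.endswith(suffixes)]
    for k in removed:
        del state_dict[k]
    return removed

# Toy checkpoint with made-up values.
original = {
    "model.encoder.version": torch.tensor([3.0]),
    "model.decoder.version": torch.tensor([3.0]),
    "model.shared.weight": torch.randn(4, 3),
}

cleaned = dict(original)  # shallow copy, so `original` keeps all keys
removed = strip_version_keys(cleaned)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "new_pytorch_model.bin")
    torch.save(cleaned, path)
    reloaded = torch.load(path, map_location="cpu")

# Only the version keys are gone, and every surviving tensor is unchanged.
same = set(reloaded) == set(original) - set(removed) and all(
    torch.equal(reloaded[k], original[k]) for k in reloaded
)
print(removed, same)
```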

sshleifer commented 4 years ago

Done. Also converted weights to fp16.
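For reference, one way an fp16 conversion of a state dict can be done. This is a sketch under the assumption that only floating-point tensors should be cast, not necessarily how the hosted checkpoint was actually converted; `to_fp16` is a hypothetical helper and the toy tensors are made up. Casting float32 to float16 halves the bytes per floating-point weight, while integer or boolean buffers keep their dtype.

```python
import torch

# Hypothetical helper: cast floating-point tensors to float16, leave the rest as-is.
def to_fp16(state_dict):
    return {
        k: v.half() if torch.is_floating_point(v) else v
        for k, v in state_dict.items()
    }

# Toy state dict with made-up entries.
sd = {"weight": torch.randn(2, 2), "position_ids": torch.arange(4)}
half = to_fp16(sd)

print(half["weight"].dtype, half["position_ids"].dtype)
```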