huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

mBART is not saving (learned) position embeddings #9525

Closed. juand-r closed this issue 3 years ago.

juand-r commented 3 years ago

Environment info

Who can help

@patrickvonplaten

Information

I am fine-tuning mBART-large on MLSUM (Spanish, and also Russian), and I noticed two things.

First, the mBART model includes:

keys_to_never_save = [
        "model.encoder.embed_positions.weight",
        "model.decoder.embed_positions.weight",
    ]

and likewise for keys_to_ignore_on_load_missing. I suppose this was done in response to issue #7296. This would be fine if the mBART position embeddings were static, but they appear to be learned: the mBART configuration shows static_position_embeddings = False.
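
To double-check that these weights really are trainable parameters (and therefore need to be saved), here is a quick inspection sketch; it assumes a transformers version from around the time of this issue (~4.1.x), where these private attribute names exist:

# Quick check that mBART's position embeddings are learned parameters
# (a sketch, assuming ~4.1.x-era transformers and the attribute names quoted above).
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained("facebook/mbart-large-cc25")

# The keys the model skips when saving / treats as acceptable-to-miss when loading:
print(getattr(model, "_keys_to_ignore_on_save", None))
print(getattr(model, "_keys_to_ignore_on_load_missing", None))

# The position embeddings are ordinary trainable weights, not fixed sinusoids:
pos = model.model.encoder.embed_positions
print(type(pos).__name__, pos.weight.requires_grad)  # a learned embedding, requires_grad=True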

Second, I can load and save the mBART model correctly if I set the following before fine-tuning:

mbart_model._keys_to_ignore_on_load_missing = None
mbart_model._keys_to_ignore_on_save = None

The problem arises when using:

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

mbart_tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
mbart_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/mbart-large-cc25")

The task I am working on is:

Abstractive summarization.

To reproduce

Steps to reproduce the behavior:

  1. Load the model: mbart_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/mbart-large-cc25")
  2. Fine-tune the mBART model with load_best_model_at_end=True.
  3. Save the fine-tuned model and load it back from disk, and observe that the in-memory and reloaded models differ (and that the texts they generate differ); see the sketch after this list.
  4. Workaround: setting mbart_model._keys_to_ignore_on_load_missing = None and mbart_model._keys_to_ignore_on_save = None fixes the problem (the full model is saved, and the checkpoints are correct).
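
The mismatch in step 3 can be seen even without fine-tuning, because the skipped weights are simply re-initialized when the checkpoint is loaded back. A minimal sketch (the output directory name is a placeholder, and the mismatch is only expected on affected versions):

# Save and reload, then compare the encoder position embeddings.
# (Sketch; "mbart-check" is a placeholder directory.)
import torch
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained("facebook/mbart-large-cc25")
model.save_pretrained("mbart-check")
reloaded = AutoModelForSeq2SeqLM.from_pretrained("mbart-check")

same = torch.allclose(
    model.model.encoder.embed_positions.weight,
    reloaded.model.encoder.embed_positions.weight,
)
print(same)  # False on affected versions: the embeddings were never written to disk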

Expected behavior

The model's position embeddings and generated outputs should be exactly the same after saving it and loading from disk.

patrickvonplaten commented 3 years ago

Hey @juand-r,

Thanks for the issue! I think this problem should be solved by now: we have done some major refactoring of MBart and removed the _keys_to_ignore_on_save for MBart. Can you check whether the error persists on current master? We will probably do a release tomorrow, so the fix should be included in the next pip version :-)
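
A quick way to check this on a given install (the attribute name is the one from the workaround above; on a fixed version it should no longer hide the position embeddings):

# Sanity check on the installed version: once the refactor is in, the save-ignore
# list should no longer contain the position-embedding keys.
import transformers
from transformers import MBartForConditionalGeneration

print(transformers.__version__)
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")
print(getattr(model, "_keys_to_ignore_on_save", None))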

juand-r commented 3 years ago

Thanks, @patrickvonplaten !

I just checked: the error is gone with version 4.2.1.

ozcangundes commented 3 years ago

Hey @juand-r ,

I am also trying to fine-tune mBART on a non-English corpus. Is there any sample script I can follow for this task?

juand-r commented 3 years ago

Hi @ozcangundes,

This could be helpful: https://github.com/GEM-benchmark/GEM-baseline-models/blob/main/examples/mbart_large_mlsum_ru.ipynb
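
If a shorter starting point helps, here is a rough, generic sketch of fine-tuning mBART on MLSUM (Spanish) with Seq2SeqTrainer. It is not the notebook's code; the hyperparameters, max lengths, and output directory are placeholders, and it assumes a reasonably recent 4.x transformers plus the datasets library:

# Generic mBART fine-tuning sketch for MLSUM (Spanish); placeholders throughout.
from datasets import load_dataset
from transformers import (
    DataCollatorForSeq2Seq,
    MBartForConditionalGeneration,
    MBartTokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

tokenizer = MBartTokenizer.from_pretrained(
    "facebook/mbart-large-cc25", src_lang="es_XX", tgt_lang="es_XX"
)
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")

# MLSUM's Spanish config has "text" (article) and "summary" columns.
raw = load_dataset("mlsum", "es")

def preprocess(batch):
    model_inputs = tokenizer(batch["text"], max_length=512, truncation=True)
    # Since src_lang == tgt_lang for same-language summarization, the summaries
    # can be tokenized the same way as the articles.
    labels = tokenizer(batch["summary"], max_length=128, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized = raw.map(preprocess, batched=True, remove_columns=raw["train"].column_names)

args = Seq2SeqTrainingArguments(
    output_dir="mbart-mlsum-es",
    per_device_train_batch_size=4,
    learning_rate=3e-5,
    num_train_epochs=3,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    predict_with_generate=True,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["validation"],
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
    tokenizer=tokenizer,
)
trainer.train()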
