huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Issue: BART does not learn during fine-tuning for abstractive text summarization #11870

Closed. DidiDerDenker closed this issue 3 years ago.

DidiDerDenker commented 3 years ago

Information

I am currently working on abstractive text summarization. As part of this, I am trying to fine-tune BART on German text data. This works, for example, with bert-base-multilingual-cased and bert-base-german-cased. It does not work with, for example, deepset/gbert-base, deepset/gelectra-large and mbart-large-cc25: the training makes no real progress, and the loss converges to zero very quickly. Am I doing something wrong? Do I need to use other classes?

To reproduce

Here are a few code snippets to reproduce this behavior:

# Config
language = "german"
model_name = "facebook/mbart-large-cc25"
tokenizer_name = "facebook/mbart-large-cc25"
batch_size = 8

# Imports
import datasets
import transformers
import tf2tf_tud_gpu_config as config
import tf2tf_tud_gpu_helpers as helpers

# Main
tokenizer = transformers.AutoTokenizer.from_pretrained(
    config.tokenizer_name, strip_accents=False
)

if "mbart" in config.model_name:
    # mBART is already a full seq2seq model with a conditional-generation head
    tf2tf = transformers.MBartForConditionalGeneration.from_pretrained(
        config.model_name
    )
else:
    # Otherwise build an encoder-decoder model from two copies of the checkpoint
    tf2tf = transformers.EncoderDecoderModel.from_encoder_decoder_pretrained(
        config.model_name, config.model_name, tie_encoder_decoder=True
    )

train_data, val_data, test_data = helpers.load_data(
    language=config.language,
    ratio_corpus_wiki=config.ratio_corpus_wiki,
    ratio_corpus_news=config.ratio_corpus_news
)

if "mbart" in config.model_name:
    training_args = transformers.TrainingArguments(
        output_dir=config.path_output,
        logging_dir=config.path_output,
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        num_train_epochs=1,
        warmup_steps=500,
        weight_decay=0.01
    )

    trainer = transformers.Trainer(
        model=tf2tf,
        args=training_args,
        train_dataset=train_data,
        eval_dataset=val_data
    )

else:
    training_args = transformers.Seq2SeqTrainingArguments(
        predict_with_generate=True,
        evaluation_strategy="steps",
        per_device_train_batch_size=config.batch_size,
        per_device_eval_batch_size=config.batch_size,
        output_dir=config.path_output,
        warmup_steps=1000,
        save_steps=10000,
        logging_steps=2000,
        eval_steps=10000,
        save_total_limit=1,
        learning_rate=5e-5,
        adafactor=True,
        fp16=True
    )

    trainer = transformers.Seq2SeqTrainer(
        model=tf2tf,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=train_data,
        eval_dataset=val_data,
        tokenizer=tokenizer
    )

trainer.train()
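
One thing worth checking for the mbart-large-cc25 case, and which matches a loss that drops to zero very quickly: the mBART tokenizer needs its language codes set (de_DE for German), and padding tokens in the labels should be replaced with -100 so the loss ignores them; otherwise the model can minimize the loss by predicting padding without learning to summarize. Below is a minimal preprocessing sketch along these lines; the column names "text" and "summary" and the maximum lengths are assumptions, not values taken from helpers.load_data.

# Preprocessing sketch (assumed column names "text" and "summary")
tokenizer.src_lang = "de_DE"  # mBART language code for German source texts
tokenizer.tgt_lang = "de_DE"  # same language for monolingual summarization

def preprocess(batch):
    model_inputs = tokenizer(
        batch["text"], max_length=512, truncation=True, padding="max_length"
    )
    with tokenizer.as_target_tokenizer():
        targets = tokenizer(
            batch["summary"], max_length=128, truncation=True, padding="max_length"
        )
    # Mask label padding with -100 so it does not contribute to the loss
    model_inputs["labels"] = [
        [token if token != tokenizer.pad_token_id else -100 for token in seq]
        for seq in targets["input_ids"]
    ]
    return model_inputs

train_data = train_data.map(preprocess, batched=True)
val_data = val_data.map(preprocess, batched=True)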

Expected behaviour

I would like to fine-tune BART successfully, i.e. with a training loss that reflects actual learning and a model that produces useful summaries.
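
For the mBART branch specifically, it may also be worth routing it through the same Seq2SeqTrainer path as the encoder-decoder models, together with a DataCollatorForSeq2Seq that pads inputs per batch and pads labels with -100. A sketch, reusing the hyperparameters from the snippets above:

# Sketch: train the mBART model with Seq2SeqTrainer and a seq2seq data collator
data_collator = transformers.DataCollatorForSeq2Seq(tokenizer, model=tf2tf)

training_args = transformers.Seq2SeqTrainingArguments(
    output_dir=config.path_output,
    logging_dir=config.path_output,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    num_train_epochs=1,
    warmup_steps=500,
    weight_decay=0.01,
    predict_with_generate=True,
    evaluation_strategy="steps"
)

trainer = transformers.Seq2SeqTrainer(
    model=tf2tf,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_data,
    eval_dataset=val_data,
    tokenizer=tokenizer
)

trainer.train()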

sgugger commented 3 years ago

cc @patrickvonplaten

github-actions[bot] commented 3 years ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

DidiDerDenker commented 3 years ago

@patrickvonplaten Hi, unfortunately I have not been able to make any progress in the last month and would appreciate if you have a solution for the unexpected behavior. Thank you! :)

github-actions[bot] commented 3 years ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

patrickvonplaten commented 3 years ago

Hey @DidiDerDenker,

Sorry, it's very difficult for us to debug customized training runs that don't produce good results. Could you instead try the forum: https://discuss.huggingface.co?