huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Dynamic batch size for Seq2SeqTrainer #10512

Closed (clang88 closed this issue 3 years ago)

clang88 commented 3 years ago

🚀 Feature request

In Fairseq it is possible to forgo a constant batch size in favor of a dynamic batch size with --max_tokens. This ensures that a batch never contains more than N = max_tokens tokens: Fairseq fills each batch by adding samples until it reaches max_tokens, or just falls short of it.

I believe @sshleifer has implemented this for finetune.py here: #7030

Is it possible to add "--max_tokens_per_batch N" as a trainer argument to Seq2SeqTrainer?
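To illustrate the requested behavior, here is a minimal sketch of the bucketing logic (not Fairseq's actual implementation; `lengths` is assumed to be the token count of each tokenized example, and the budget counts the padded batch size):

```python
def make_max_token_batches(lengths, max_tokens):
    """Group example indices into batches whose padded size stays within max_tokens."""
    # Sort by length so each batch holds examples of similar length (less padding).
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    batches, batch, longest = [], [], 0
    for idx in order:
        longest_if_added = max(longest, lengths[idx])
        # With padding, a batch costs (number of examples) * (longest example) tokens.
        if batch and (len(batch) + 1) * longest_if_added > max_tokens:
            batches.append(batch)
            batch, longest_if_added = [], lengths[idx]
        batch.append(idx)
        longest = longest_if_added
    if batch:
        batches.append(batch)
    return batches
```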

Motivation

This would be an invaluable help when training/fine-tuning large models on data sequences (like sentences) of varying length. Long sequences/sentences can lead to OOM errors with a fixed batch size.

LysandreJik commented 3 years ago

Pinging @patil-suraj and @sgugger

patil-suraj commented 3 years ago

Hi @clang88

The goal of the example scripts is to keep them minimal and simple, so I'm not sure if we want to support this immediately.

For now, you could use the --group_by_length argument, which groups sequences of similar length together so that lengths within a batch vary less and the number of padding tokens is minimized.

Also, to train large models, I would recommend taking a look at the fairscale/deepspeed integration. Check this blog post for how to use fairscale/deepspeed with the Trainer.
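For reference, both of these are regular training arguments; roughly something like this (a sketch only: output_dir and the DeepSpeed config path are placeholders, and DeepSpeed runs are additionally started with the deepspeed launcher as described in the blog post):

```python
from transformers import Seq2SeqTrainingArguments

# Sketch: enable length grouping and (optionally) the DeepSpeed integration.
training_args = Seq2SeqTrainingArguments(
    output_dir="output",            # placeholder
    per_device_train_batch_size=8,  # the batch size itself stays fixed with group_by_length
    group_by_length=True,           # bucket samples of similar length to reduce padding
    fp16=True,
    deepspeed="ds_config.json",     # placeholder path to a DeepSpeed configuration file
)
```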

@sgugger do you think we could add a MaxTokensSampler for this, in case we want to support it?

clang88 commented 3 years ago

Hi @patil-suraj,

thank you for the quick reply!

I will take a look at the fairscale/deepspeed integration!

As for --group_by_length: this will only work correctly if I use the trainer with a data_collator, am I correct? I have already been experimenting with that approach, but am having some trouble with custom_metrics during the evaluation phase. For whatever reason, the labels passed to the function by the trainer appear to be padded with the default of -100, even though I am passing a label_pad_token_id of 0 (for mT5) or 1 (for mBART) to the collator. I am aware this is a whole other issue, but maybe you are aware of a potential solution?
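(For context, the example scripts seem to handle this by mapping the -100s back to the pad token before decoding inside the metrics function, roughly like the sketch below, which assumes `tokenizer` is the mT5/mBART tokenizer and that the actual metric is computed on the decoded strings:)

```python
import numpy as np

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    # The trainer pads label batches with -100 so they are ignored by the loss;
    # replace -100 with the real pad token id before decoding.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    # ...compute the actual metric (BLEU, ROUGE, ...) on the decoded strings...
    return {"num_eval_samples": len(decoded_preds)}
```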

That said, I am sure max_tokens_per_batch would be a great asset, as group_by_length does not fix the underlying issue of batches with very long sentences going OOM. For now I am just truncating my dataset with max_length, but that clearly leads to less than ideal performance of the fine-tuned model.

sgugger commented 3 years ago

do you think we could add a MaxTokensSampler for this, in case we want to support it

It's a whole batch sampler that would be needed, since this results in the batch size not being constant, and a distributed batch sampler version would be needed as well. This is a lot of work for a very specific use case, so we could accept a PR on an example in a research project first.
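To give an idea of the scope, such a batch sampler might look roughly like this (a sketch only, reusing make_max_token_batches from the sketch further up and ignoring the distributed variant entirely):

```python
from torch.utils.data import Sampler

class MaxTokensBatchSampler(Sampler):
    """Sketch: yields variable-sized lists of indices whose padded size fits max_tokens."""

    def __init__(self, lengths, max_tokens):
        # Reuses the bucketing helper sketched in the feature request above.
        self.batches = make_max_token_batches(lengths, max_tokens)

    def __iter__(self):
        return iter(self.batches)

    def __len__(self):
        return len(self.batches)

# Usage sketch: passed as batch_sampler, so the DataLoader has no fixed batch_size at all.
# loader = torch.utils.data.DataLoader(
#     dataset, batch_sampler=MaxTokensBatchSampler(lengths, max_tokens=4096),
#     collate_fn=data_collator,
# )
```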

Of course, there is still the possibility of a user reusing the implementation from FAIR, as was done in the old finetune script.

github-actions[bot] commented 3 years ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed, please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.