huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Support dynamic batch size #34406

Open fzyzcjy opened 1 month ago

fzyzcjy commented 1 month ago

Feature request

Hi, thanks for the library! During training I notice that, when a micro-batch contains too few tokens, throughput is quite poor (i.e. the average time per token is high). However, I cannot simply increase the batch size, because the training data mixes long (e.g. 2000-token) and short (e.g. 500-token) sequences: a batch size that makes short sequences run fast makes long sequences OOM.

Therefore, I am proposing a dynamic (micro) batch size. For example, suppose batch_size=16. Before this proposal, we would fix e.g. micro_batch_size=2 & grad_accum=8. After this proposal, a micro-batch of short sequences would hold 4 samples, while a micro-batch of long sequences would hold only 2; once the accumulated micro-batches sum to 16 samples, we compute the loss and consider the step done.
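A minimal sketch of the packing idea (not part of the transformers API; the function name `dynamic_micro_batches` and the `max_tokens_per_micro_batch` budget are hypothetical): group sample indices into micro-batches whose total token count stays under a budget, so short sequences pack more samples per micro-batch than long ones, while gradient accumulation still stops once the global batch of 16 samples is reached.

```python
from typing import Iterator, List

def dynamic_micro_batches(
    lengths: List[int],                       # token length of each sample
    max_tokens_per_micro_batch: int = 4096,   # hypothetical token budget
) -> Iterator[List[int]]:
    """Yield lists of sample indices; each list is one micro-batch."""
    micro_batch, tokens = [], 0
    for idx, length in enumerate(lengths):
        # Start a new micro-batch when adding this sample would exceed the budget.
        if micro_batch and tokens + length > max_tokens_per_micro_batch:
            yield micro_batch
            micro_batch, tokens = [], 0
        micro_batch.append(idx)
        tokens += length
    if micro_batch:
        yield micro_batch

# Example: with a 4096-token budget, 500-token sequences pack 8 per
# micro-batch, while 2000-token sequences pack only 2.
lengths = [500] * 8 + [2000] * 4
for mb in dynamic_micro_batches(lengths):
    print(mb, sum(lengths[i] for i in mb))
```

In this sketch the micro-batch size varies per step, but the loss would still be averaged over the full 16-sample batch before the optimizer step, so the effective batch size stays fixed.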

Motivation

(see above)

Your contribution

I am happy to submit a PR.

LysandreJik commented 1 month ago

cc @muellerzr WDYT?