huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers

Automatic dynamic batch size selection for DataCollatorWithFlattening #33945

Open alex-hh opened 1 week ago

alex-hh commented 1 week ago

Feature request

Add a custom (batch index) sampler that automatically chooses the batch size so that each batch contains (approximately) a fixed target number of tokens.

Motivation

I'm keen to try out DataCollatorWithFlattening, but I'm unsure how to set the batch size: since no padding is added, the total number of tokens per batch is dynamic.

I'm also uncertain whether fixing the total number of tokens is itself optimal... Does optimal memory allocation require accounting for the amount of attention masking that will be applied to the batch?

Is there any recommendation on how to handle this currently?

(Edit: it seems a near-optimal solution for map-style datasets is provided by https://github.com/imoneoi/multipack_sampler/tree/master, which presumably just tries to ensure all batches are as full as possible given some maximum number of tokens. It would be nice to support similar functionality for iterable datasets: not optimal packing, but adjusting the batch size to the number of tokens in the examples should be possible.)

Your contribution

I may be able to implement something for iterable datasets if this is feasible.
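
Roughly, I'm imagining something like the untested sketch below: a greedy batcher that accumulates examples until a token budget would be exceeded, then hands each group to DataCollatorWithFlattening. The `token_budget_batches` helper, the `max_tokens` value and the toy examples are just illustrative, not a proposed API.

```python
from transformers import DataCollatorWithFlattening

def token_budget_batches(examples, max_tokens=8192):
    """Greedily group examples so that the combined number of
    input_ids per batch stays at or below max_tokens."""
    batch, batch_tokens = [], 0
    for example in examples:
        n_tokens = len(example["input_ids"])
        # Start a new batch if adding this example would exceed the budget.
        # (A single example longer than max_tokens still forms its own batch.)
        if batch and batch_tokens + n_tokens > max_tokens:
            yield batch
            batch, batch_tokens = [], 0
        batch.append(example)
        batch_tokens += n_tokens
    if batch:
        yield batch

# Toy tokenized examples standing in for a streaming/iterable dataset.
examples = ({"input_ids": list(range(10, 10 + n))} for n in (5, 7, 3, 9, 2))

collator = DataCollatorWithFlattening()
for batch_examples in token_budget_batches(examples, max_tokens=12):
    batch = collator(batch_examples)
    # Each batch is a single flattened sequence plus position_ids.
    print(len(batch_examples), batch["input_ids"].shape)
```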

snow-kartikbagalore commented 1 week ago

I want to second this. When I use this collator, each of my batches (now a single flattened 'list') has a varying length, as expected. I'm unable to decide on a batch_size value; what I really want is to keep the total length of the flattened output below some large number, say 8192. Is it possible to get this behaviour today?
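
For reference, the behaviour I'm after is roughly the untested sketch below, using a custom batch_sampler with a plain PyTorch DataLoader. The `TokenBudgetBatchSampler` class and the toy dataset are made up for illustration, and the real cap would be 8192 rather than 12.

```python
from torch.utils.data import DataLoader
from transformers import DataCollatorWithFlattening

class TokenBudgetBatchSampler:
    """Yield lists of dataset indices whose total token count stays
    at or below max_tokens (greedy, in dataset order, no reordering)."""

    def __init__(self, lengths, max_tokens=8192):
        self.lengths = lengths          # number of tokens per example
        self.max_tokens = max_tokens

    def __iter__(self):
        batch, total = [], 0
        for idx, n in enumerate(self.lengths):
            if batch and total + n > self.max_tokens:
                yield batch
                batch, total = [], 0
            batch.append(idx)
            total += n
        if batch:
            yield batch

# Toy map-style "dataset": any sequence of tokenized examples works here.
dataset = [{"input_ids": list(range(10, 10 + n))} for n in (5, 7, 3, 9, 2)]
lengths = [len(ex["input_ids"]) for ex in dataset]

loader = DataLoader(
    dataset,
    batch_sampler=TokenBudgetBatchSampler(lengths, max_tokens=12),
    collate_fn=DataCollatorWithFlattening(),
)
for batch in loader:
    print(batch["input_ids"].shape)  # one flattened packed sequence per batch
```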

ArthurZucker commented 1 week ago

I'm not really an expert on this, so I'm not sure I fully understand the motivation: flattening should, AFAIK, be more "efficient", and I don't really see a reason to limit the length of the flattened output. But if you need to pack your inputs instead of leaving them ragged, I think there is another data collator that creates the appropriate batches based on the max length you have.

alex-hh commented 1 week ago

The reason for the max length would be to prevent out-of-memory issues: the idea is that you know your model can process some total number of unmasked tokens without OOM, so packed batches must be selected to stay below that limit. And since you're not using padding, that may mean a variable number of examples per packed batch.

What's the other collator? Sounds useful.