huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Static cache + torch.compile: better documentation for prefill static sequence length #29151

Closed: fxmarty closed this issue 1 month ago

fxmarty commented 6 months ago

Feature request

When using torch.compile, the prefill is recompiled for every new sequence length, which is slow. It would be nice to compile only for a fixed set of sequence lengths (e.g. 1, 2, 4, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096), chosen on the fly depending on the input length and reached with some padding.
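
A minimal sketch of the idea (the bucket sizes and the pad_to_bucket helper below are only illustrative, not an existing transformers API):

```python
import torch
import torch.nn.functional as F

# Illustrative bucket sizes; anything that keeps the number of distinct
# prefill shapes small would do.
BUCKETS = (1, 2, 4, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096)

def pad_to_bucket(input_ids, attention_mask, pad_token_id):
    """Left-pad the prompt up to the next bucket size so that torch.compile
    only ever specializes the prefill on a handful of sequence lengths."""
    seq_len = input_ids.shape[1]
    # Smallest bucket that fits the prompt; fall back to the raw length.
    bucket = next((b for b in BUCKETS if b >= seq_len), seq_len)
    pad_len = bucket - seq_len
    if pad_len == 0:
        return input_ids, attention_mask
    # Left-pad so the real tokens stay adjacent to the generated ones.
    input_ids = F.pad(input_ids, (pad_len, 0), value=pad_token_id)
    attention_mask = F.pad(attention_mask, (pad_len, 0), value=0)
    return input_ids, attention_mask
```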

Motivation

Compilation with torch.compile is prohibitively slow even with https://github.com/huggingface/transformers/pull/29114

If people want to use transformers + static cache + torch.compile, it should be FAST to run generate on new sequence lengths.

Your contribution

None for now

amyeroberts commented 6 months ago

cc @gante

gante commented 6 months ago

@fxmarty this is the same problem we have in TF and Flax. There, we nudged users towards the pad_to_multiple_of argument in the tokenizer, which I believe solves the problem 🤗

How do you suggest we let users know about this feature, other than through the docs?
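
For reference, the tokenizer-side usage looks roughly like this (checkpoint and multiple chosen just for illustration):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no padding token by default

# Rounding every prompt length up to a multiple of 64 bounds the number of
# distinct prefill shapes torch.compile has to specialize on.
inputs = tokenizer(
    ["Hello world", "A somewhat longer prompt about static caches"],
    padding=True,
    pad_to_multiple_of=64,
    return_tensors="pt",
)
print(inputs["input_ids"].shape)  # sequence length rounded up to a multiple of 64
```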

fxmarty commented 6 months ago

@gante Supporting that in the tokenizer is already good, but I am wondering whether it would make sense to support it directly in generation. Have you seen any user requests about that?

gante commented 6 months ago

@fxmarty I haven't.

I am also not a big fan of it:
a) it pushes the problem from forward to generate (i.e. forward would not see recompilations, but generate will, as it will have an input tensor with arbitrary length);
b) it hides the real behavior (padding) from the user, which may lead to issues due to behavior misunderstandings. An obvious one I can foresee is "my input has length X, I have set max_new_tokens=Y, why isn't the output length X+Y?"

pad_to_multiple_of avoids the problems I mentioned, but it is harder to discover 🤗 Still, I think it is preferable!

fxmarty commented 6 months ago

> a) it pushes the problem from forward to generate (i.e. forward would not see recompilations, but generate will, as it will have an input tensor with arbitrary length)

Not really (at least not for torch.compile), as generate is simply not compiled.

> b) it hides the real behavior (padding) from the user, which may lead to issues due to behavior misunderstandings. An obvious one I can foresee is "my input has length X, I have set max_new_tokens=Y, why isn't the output length X+Y?"

Fair enough. I think a warning could be shown in generate (e.g. when the model is an OptimizedModule) to advertise the feature, and/or we could document the usage with torch.compile.
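
Something along these lines, as a rough sketch (the helper is hypothetical; the check assumes OptimizedModule is the wrapper class that torch.compile returns):

```python
import warnings
import torch

def maybe_warn_about_recompilation(model):
    # Hypothetical check: OptimizedModule is the wrapper returned by
    # torch.compile, so this would only fire for users who compiled their model.
    if isinstance(model, torch._dynamo.eval_frame.OptimizedModule):
        warnings.warn(
            "The model is wrapped by torch.compile. Consider padding inputs to a "
            "fixed set of lengths (e.g. with the tokenizer's pad_to_multiple_of) "
            "to avoid recompiling the prefill for every new sequence length."
        )
```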

gante commented 6 months ago

> as generate is simply not compiled.

@fxmarty yet ;) Beam search has some heavy tensor operations that should be compiled, some logits processors are heavy, etc.

The difference between passing a flag to generate or to the tokenizer is small, but passing it to generate would restrict our ability to fully compile generate if we decide to go down that path at some point.

fxmarty commented 6 months ago

@gante Agreed, although @torch.compiler.disable is useful for that.
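
For example (a minimal sketch, assuming PyTorch 2.1+ where torch.compiler.disable is available):

```python
import torch

@torch.compiler.disable
def shape_dependent_bookkeeping(x):
    # Excluded from tracing: this runs in eager mode even when called from
    # compiled code, so its dynamic shapes do not trigger recompilation.
    return x[:, -1:]

@torch.compile
def step(x):
    # Dynamo inserts a graph break around the disabled call and compiles the rest.
    return shape_dependent_bookkeeping(x) * 2

print(step(torch.ones(2, 5)))
```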

gante commented 3 months ago

https://github.com/huggingface/transformers/pull/30788 -- this PR adds documentation on using pad_to_multiple_of to avoid input-shape-related recompilation

I'm assuming this issue can be closed after the PR gets merged :) In the generate refactor we will be separating the prefill step, and we can then move/enhance related documentation.