huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

KV cache optimization with paged attention #27303

Open liangan1 opened 10 months ago

liangan1 commented 10 months ago

Feature request

Paged attention has been adopted by many serving engines, e.g., vLLM and TensorRT-LLM.

Motivation

The KV cache is used to reduce computation in the decoder layers, but it also brings memory overhead. For example, when we use beam search, the kv_cache has to be reordered according to the latest beam indices, and the current key/value has to be concatenated with the kv_cache in the attention layer to obtain the entire context for scaled dot-product attention. When the sequence is very long, this memory overhead becomes the performance bottleneck.
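To illustrate the core idea (this is not vLLM's or transformers' API; names and sizes below are made up): paged attention keeps one shared pool of fixed-size KV blocks and maps each sequence's logical token positions to physical blocks through a block table, so beams can share prefix blocks instead of reordering or copying a contiguous cache.

```python
# Illustrative sketch only. Block/pool sizes and the BlockTable class are assumptions.
import torch

BLOCK_SIZE = 16           # tokens per physical KV block (assumption)
NUM_BLOCKS = 1024         # size of the shared physical pool (assumption)
NUM_HEADS, HEAD_DIM = 8, 64

# One shared pool of physical KV blocks, reused across all sequences/beams.
k_pool = torch.empty(NUM_BLOCKS, BLOCK_SIZE, NUM_HEADS, HEAD_DIM)
v_pool = torch.empty_like(k_pool)
free_blocks = list(range(NUM_BLOCKS))


class BlockTable:
    """Maps one sequence's logical token positions to physical block ids."""

    def __init__(self):
        self.block_ids = []
        self.num_tokens = 0

    def next_slot(self):
        """Return (block_id, offset) for the next token, allocating a block on demand."""
        if self.num_tokens % BLOCK_SIZE == 0:
            self.block_ids.append(free_blocks.pop())
        block_id = self.block_ids[-1]
        offset = self.num_tokens % BLOCK_SIZE
        self.num_tokens += 1
        return block_id, offset

    def fork(self) -> "BlockTable":
        """Beam search: a new beam copies the small table, not the KV data.
        (A real implementation also needs copy-on-write for the last, partial block.)"""
        child = BlockTable()
        child.block_ids = list(self.block_ids)
        child.num_tokens = self.num_tokens
        return child


# Writing one new token's key/value goes through the table instead of a concat:
table = BlockTable()
block_id, offset = table.next_slot()
k_pool[block_id, offset] = torch.randn(NUM_HEADS, HEAD_DIM)
v_pool[block_id, offset] = torch.randn(NUM_HEADS, HEAD_DIM)
```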

Your contribution

No PR yet

amyeroberts commented 10 months ago

cc @gante (I think this is closest to your work - sorry if wrong!)

liangan1 commented 10 months ago

@jgong5

gante commented 10 months ago

Hi @liangan1 👋

We are close to introducing a new cache abstraction (https://github.com/huggingface/transformers/pull/26681). I believe that, after this PR is merged, paged attention could be added directly on top of it :)

Would you be interested in adding it to transformers?

liangan1 commented 10 months ago

> Hi @liangan1 👋
>
> We are close to introducing a new cache abstraction (#26681). I believe that, after this PR is merged, paged attention could be added directly on top of it :)
>
> Would you be interested in adding it to transformers?

Sure. We'd be pleased to contribute more kv_cache-related optimizations.

gante commented 10 months ago

Awesome, I will let you know when the cache abstraction is ready!

liangan1 commented 10 months ago

Thanks.

gante commented 9 months ago

@liangan1 the cache abstraction will be merged today, so you can start working on top of it. Happy to provide pointers and suggestions! 🙌
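As a starting point, here is a rough, untested sketch of what a paged cache on top of the new abstraction could look like. `PagedCache` and its block bookkeeping are hypothetical; only the `update()` / `get_seq_length()` signatures mirror the `Cache` interface from #26681, and a real paged attention kernel would consume the block layout directly instead of re-materializing the full tensors:

```python
# Rough, untested sketch. PagedCache and its block bookkeeping are hypothetical;
# only update() / get_seq_length() mirror the Cache interface from #26681.
from typing import Optional, Tuple

import torch
from transformers.cache_utils import Cache


class PagedCache(Cache):
    """Stores KV states per layer as fixed-size blocks instead of one contiguous tensor."""

    def __init__(self, block_size: int = 16):
        super().__init__()
        self.block_size = block_size
        # layer_idx -> list of (key_block, value_block) tensors
        self.blocks = {}

    def update(
        self,
        key_states: torch.Tensor,      # (batch, num_heads, seq_len, head_dim)
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[dict] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        layer_blocks = self.blocks.setdefault(layer_idx, [])
        # Append the new tokens in block_size chunks (a real implementation would
        # also fill the last partially-used block instead of opening a new one).
        for start in range(0, key_states.shape[-2], self.block_size):
            end = start + self.block_size
            layer_blocks.append(
                (key_states[..., start:end, :], value_states[..., start:end, :])
            )
        # A paged attention kernel would consume the block layout directly; here we
        # just materialize full keys/values so standard attention keeps working.
        keys = torch.cat([k for k, _ in layer_blocks], dim=-2)
        values = torch.cat([v for _, v in layer_blocks], dim=-2)
        return keys, values

    def get_seq_length(self, layer_idx: int = 0) -> int:
        return sum(k.shape[-2] for k, _ in self.blocks.get(layer_idx, []))

    def get_max_length(self) -> Optional[int]:
        return None  # no fixed maximum, like DynamicCache
```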

github-actions[bot] commented 8 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

NicolasMejiaPetit commented 2 months ago

As of the latest release if flash attention v2.5 paged kv cache is now supported. https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#25-paged-kv-cache. This being implemented into transformers would be pretty awesome, specially when it can stack with quantized kv cache, allowing for more than 100,00k tokens on consumer gpu’s, if you have 64gb of shared memory then like 500,000 tokens of context, on a 7b 4bit model.