Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

Any plan for support paged attention? #660

Closed · donglinz closed this 9 months ago

donglinz commented 12 months ago

First of all, thank you for the great work!

Is there any plan to support paged kv cache in non-contiguous memory? For instance, in flash_attn_with_kvcache?

tridao commented 12 months ago

It's not the highest priority at the moment. Does the implementation from vLLM not work well?

donglinz commented 12 months ago

> It's not the highest priority at the moment. Does the implementation from vLLM not work well?

Their scope is slightly different from flash_attn_with_kvcache: vLLM's paged attention kernel only supports decoding (one token per batch), I suppose. In many scenarios, like speculative decoding, flash_attn_with_kvcache is preferable because it can compute multiple tokens per batch in parallel.
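For context, here is a minimal sketch of the multi-token use case with flash_attn_with_kvcache (shapes and argument names follow my reading of the docstring in 2.x; check your installed version):

```python
# Sketch: scoring several query tokens per sequence against an in-place KV cache,
# as in speculative-decoding verification. Shapes/arguments are assumptions.
import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim = 4, 32, 128
max_seqlen, num_new = 4096, 8  # e.g. 8 draft tokens verified per sequence

q = torch.randn(batch, num_new, nheads, headdim, device="cuda", dtype=torch.float16)
k_new = torch.randn_like(q)
v_new = torch.randn_like(q)

# Pre-allocated contiguous cache of shape (batch, max_seqlen, nheads, headdim).
k_cache = torch.zeros(batch, max_seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
v_cache = torch.zeros_like(k_cache)
cache_seqlens = torch.full((batch,), 1024, dtype=torch.int32, device="cuda")  # tokens already cached

# The kernel appends k_new/v_new at cache_seqlens and attends over
# cache + new tokens for all num_new query positions in one call.
out = flash_attn_with_kvcache(q, k_cache, v_cache, k=k_new, v=v_new,
                              cache_seqlens=cache_seqlens, causal=True)
```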

tridao commented 12 months ago

Is vLLM planning to implement a version that can support more than 1 token?

tridao commented 12 months ago

Does it make sense to have paged KV cache as a standalone function without all the cache management kernels (in vLLM)? How would one use paged KV cache without a cache manager to copy / update the pages?

donglinz commented 12 months ago

> Is vLLM planning to implement a version that can support more than 1 token?

I have no information on that, but I can ask them in the vLLM repo. See https://github.com/vllm-project/vllm/issues/1598 for reference.

donglinz commented 12 months ago

> Does it make sense to have paged KV cache as a standalone function without all the cache management kernels (in vLLM)? How would one use paged KV cache without a cache manager to copy / update the pages?

Yes, it is cache-manager dependent, since FlashAttention and vLLM use different KV cache formats ([B, L, H, D] vs. [n_blocks, H, D//x, block_size, x]).

But I think that should be fine; the biggest obstacle on my side is that I cannot find a set of kernels that supports both paged prefill and paged decode. The cache manager is not a big issue for me, because it can be implemented in ~100 lines of Python code. As a user, as long as I have the kernels, I would gladly implement a cache manager myself that fits the kernel's format.
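As a rough illustration of what such a small cache manager could look like (all class/method names and the page layout here are hypothetical, not part of any library's API):

```python
# Sketch of a ~40-line paged KV cache manager: allocates fixed-size pages,
# appends new tokens, and builds a padded block table for the attention kernel.
import torch

class PagedKVCacheManager:
    def __init__(self, num_blocks, block_size, nheads, headdim,
                 device="cuda", dtype=torch.float16):
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))
        self.block_tables = {}   # seq_id -> list of physical block ids, in order
        self.seq_lens = {}       # seq_id -> number of cached tokens
        # Physical pages: (num_blocks, block_size, nheads, headdim)
        self.k_cache = torch.zeros(num_blocks, block_size, nheads, headdim,
                                   device=device, dtype=dtype)
        self.v_cache = torch.zeros_like(self.k_cache)

    def append(self, seq_id, k, v):
        """Append k/v of shape (num_new, nheads, headdim) for one sequence."""
        blocks = self.block_tables.setdefault(seq_id, [])
        pos = self.seq_lens.get(seq_id, 0)
        for i in range(k.shape[0]):
            if pos % self.block_size == 0:           # current page full: grab a new one
                blocks.append(self.free_blocks.pop())
            blk, off = blocks[pos // self.block_size], pos % self.block_size
            self.k_cache[blk, off] = k[i]
            self.v_cache[blk, off] = v[i]
            pos += 1
        self.seq_lens[seq_id] = pos

    def block_table(self, seq_ids):
        """Pad per-sequence block lists into a (batch, max_blocks) int32 tensor."""
        max_blocks = max(len(self.block_tables[s]) for s in seq_ids)
        table = torch.zeros(len(seq_ids), max_blocks, dtype=torch.int32)
        for row, s in enumerate(seq_ids):
            ids = self.block_tables[s]
            table[row, :len(ids)] = torch.tensor(ids, dtype=torch.int32)
        return table

    def free(self, seq_id):
        """Return a finished sequence's pages to the free pool."""
        self.free_blocks.extend(self.block_tables.pop(seq_id, []))
        self.seq_lens.pop(seq_id, None)
```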

tridao commented 12 months ago

I'm not sure I understand what paged prefill means; can you say more? During prefill, the KV cache is calculated as the output of the (nn.Linear) K_proj and V_proj. This is a contiguous memory blob, which can then be used for attention during prefill as usual (e.g. by calling flash_attn). I assume vLLM would then copy this contiguous memory blob to different blocks in preparation for decoding?
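A sketch of that flow, under the assumption of a simple (num_pages, block_size, nheads, headdim) page layout (the layout and shapes are illustrative, not a library requirement):

```python
# Prefill on the contiguous K/V from K_proj/V_proj, then scatter that blob
# into fixed-size pages in preparation for paged decoding.
import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, headdim, block_size = 2, 1024, 16, 128, 256

q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)   # output of K_proj, contiguous
v = torch.randn_like(q)   # output of V_proj, contiguous

# Prefill attention over the contiguous blob, as usual.
out = flash_attn_func(q, k, v, causal=True)

# Copy the contiguous K/V into pages of shape (block_size, nheads, headdim).
num_blocks_per_seq = seqlen // block_size
k_pages = k.reshape(batch * num_blocks_per_seq, block_size, nheads, headdim).clone()
v_pages = v.reshape(batch * num_blocks_per_seq, block_size, nheads, headdim).clone()
```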

tdene commented 11 months ago

@donglinz did you ever find a solution? I noticed that you closed vllm-project/vllm#1598.

zhaoyang-star commented 10 months ago

For the prefill, no cache will be used. I just replaced xformers with FA, since xformers does not support MQA/GQA, and found that **attention calculation (softmax(Q @ K^T * softmax_scale) @ V) latency is reduced more than 2x.** More details can be found at https://github.com/vllm-project/vllm/issues/1880
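For reference, the unfused version of the computation being timed (single head, no masking) is just:

```python
# Reference (unfused) attention: softmax(Q @ K^T * softmax_scale) @ V.
# FlashAttention fuses these steps into one kernel instead of materializing `scores`.
import math
import torch

def attention_reference(q, k, v):
    # q, k, v: (seqlen, headdim)
    softmax_scale = 1.0 / math.sqrt(q.shape[-1])
    scores = (q @ k.transpose(-2, -1)) * softmax_scale
    return torch.softmax(scores, dim=-1) @ v
```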

For the decode stage, we should either rewrite the paged attention kernel in vLLM or modify the FlashAttention kernel to support paged KV cache. I have not evaluated the workload yet.

tridao commented 9 months ago

flash-attn now supports paged KV cache as of v2.5.0
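A hedged sketch of the paged path added in v2.5.0, based on my reading of the flash_attn_with_kvcache docstring (verify argument names and shape constraints against your installed version; for instance, early releases may require the page size to be a multiple of 256):

```python
# Paged KV cache with flash_attn_with_kvcache: physical pages shared by all
# sequences, plus a per-sequence block table mapping logical to physical pages.
import torch
from flash_attn import flash_attn_with_kvcache

num_blocks, block_size, nheads, headdim, batch = 128, 256, 16, 128, 2

# Physical pages: (num_blocks, block_size, nheads, headdim).
k_cache = torch.randn(num_blocks, block_size, nheads, headdim, device="cuda", dtype=torch.float16)
v_cache = torch.randn_like(k_cache)

# Row b lists the pages of sequence b, in order (padded with unused entries).
block_table = torch.tensor([[3, 7, 0, 0],
                            [5, 1, 9, 2]], dtype=torch.int32, device="cuda")
cache_seqlens = torch.tensor([300, 900], dtype=torch.int32, device="cuda")

# Single-token decode step for each sequence, reading K/V through the block table.
q = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.float16)
out = flash_attn_with_kvcache(q, k_cache, v_cache,
                              cache_seqlens=cache_seqlens,
                              block_table=block_table, causal=True)
```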

curry-zzl commented 7 months ago

> flash-attn now supports paged KV cache as of v2.5.0

@tridao I still wonder: how would one use the paged KV cache without a cache manager to copy / update the pages?

tridao commented 7 months ago

You'd need to implement your own cache manager

tspeterkim commented 4 months ago

For those who are interested, here's a simple cache manager: https://github.com/tspeterkim/paged-attention-minimal/