Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

Support quantized KV cache in LLM inference? #797

Open zhaoyang-star opened 9 months ago

zhaoyang-star commented 9 months ago

Quantizing the KV cache in LLM inference is a common way to boost performance. I noticed that FA already supports a paged KV cache. Should we also support an fp8 or int8 KV cache?
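To illustrate the idea, here is a minimal sketch of int8 KV-cache quantization with a per-head scale, written in plain PyTorch. It is not FA's internal layout or API; the cache shape `[num_blocks, block_size, num_kv_heads, head_dim]` and the function names are hypothetical, and the dequantized cache would still be fed to the attention kernel in bf16/fp16.

```python
import torch

def quantize_kv_int8(kv: torch.Tensor):
    """Quantize a KV-cache tensor (hypothetical layout
    [num_blocks, block_size, num_kv_heads, head_dim]) to int8 with
    one scale per KV head. Returns (int8 cache, fp32 scales)."""
    # max magnitude over everything except the head dimension (dim=2)
    amax = kv.abs().amax(dim=(0, 1, 3), keepdim=True).clamp(min=1e-6)
    scale = amax / 127.0
    q = torch.clamp(torch.round(kv / scale), -127, 127).to(torch.int8)
    return q, scale.to(torch.float32)

def dequantize_kv_int8(q: torch.Tensor, scale: torch.Tensor, dtype=torch.bfloat16):
    """Dequantize back to bf16/fp16 before running the attention kernel."""
    return (q.to(torch.float32) * scale).to(dtype)

# Example: 128 blocks of 16 tokens, 8 KV heads, head_dim 128 (all assumed values)
k_cache = torch.randn(128, 16, 8, 128, dtype=torch.bfloat16)
k_q, k_scale = quantize_kv_int8(k_cache)
k_deq = dequantize_kv_int8(k_q, k_scale)
print((k_cache - k_deq).abs().max())  # per-head int8 quantization error
```

The int8 path needs the extra scale tensor; the fp8 path discussed later in the thread avoids even that.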

tridao commented 9 months ago

Makes sense, if there's time.

zhaoyang-star commented 9 months ago

@tridao Thanks for your quick response. I noticed that FA's paged KV cache kernel has lower latency than vLLM's hand-written paged attention kernel, so I want to use FA in vLLM. The main obstacle is that FA only supports half and bfloat16 KV cache data types, so supporting an fp8 cache data type is needed.

The following benchmark was run on an A100-40GB with num_query_heads=56, num_kv_heads=8, and dtype=bfloat16 (a rough timing sketch follows the table):

| Batch size | vLLM paged attention kernel (us) | FA (us) | Speedup |
|---:|---:|---:|---:|
| 1 | 61.142 | 67.126 | 0.91 |
| 4 | 256.677 | 75.528 | 3.40 |
| 16 | 552.368 | 264.750 | 2.09 |
| 64 | 2504.497 | 818.838 | 3.06 |
| 256 | 7878.213 | 2550.177 | 3.09 |
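A rough sketch of how such a decode-step microbenchmark could be timed is below, using CUDA events and `flash_attn_with_kvcache` from the flash-attn Python package. The head counts match the thread; the head dimension (128), cache length (4096), contiguous (non-paged) cache layout, and iteration counts are assumptions, since the exact setup is not given here.

```python
import torch
from flash_attn import flash_attn_with_kvcache  # assumes flash-attn 2.x is installed

def bench_decode(batch_size, seqlen_cache=4096, nheads_q=56, nheads_kv=8,
                 headdim=128, dtype=torch.bfloat16, iters=100):
    # Single-token decode: q has seqlen 1, KV cache holds the context
    q = torch.randn(batch_size, 1, nheads_q, headdim, device="cuda", dtype=dtype)
    k_cache = torch.randn(batch_size, seqlen_cache, nheads_kv, headdim,
                          device="cuda", dtype=dtype)
    v_cache = torch.randn_like(k_cache)
    cache_seqlens = torch.full((batch_size,), seqlen_cache,
                               device="cuda", dtype=torch.int32)

    for _ in range(10):  # warmup
        flash_attn_with_kvcache(q, k_cache, v_cache,
                                cache_seqlens=cache_seqlens, causal=True)
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        flash_attn_with_kvcache(q, k_cache, v_cache,
                                cache_seqlens=cache_seqlens, causal=True)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters * 1000  # microseconds per call

for bs in (1, 4, 16, 64, 256):
    print(bs, f"{bench_decode(bs):.1f} us")
```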

I am not familiar with FA's code, but I think only a limited number of lines would need to change to support an fp8 cache without a scale factor, since there are no additional quantization parameters. See https://github.com/vllm-project/vllm/blob/main/csrc/attention/attention_kernels.cu#L218 for details.
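Conceptually, the scale-free fp8 path is just a dtype cast on load, which is why the change should be small. A minimal Python sketch of the idea (not the vLLM or FA kernel code, which does this per-element in CUDA) using `torch.float8_e4m3fn`, available in PyTorch 2.1+; the cache shape is again a hypothetical paged layout:

```python
import torch

# Store the paged KV cache in fp8 (e4m3) and upcast to bf16 just before attention.
# With no scale factor this is a plain dtype cast, i.e. "fp8 without quantization params".
k_cache_bf16 = torch.randn(128, 16, 8, 128, dtype=torch.bfloat16)  # hypothetical paged layout
k_cache_fp8 = k_cache_bf16.to(torch.float8_e4m3fn)                 # half the memory of bf16

# What a kernel would do per element when loading the cache:
k_loaded = k_cache_fp8.to(torch.bfloat16)
print((k_cache_bf16 - k_loaded).abs().max())  # fp8 rounding error, no scaling applied
```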

tridao commented 9 months ago

Great, thanks for the benchmark. The number for FA at batch size 16 seems wrong? I don't have much bandwidth these days. An 8-bit KV cache will be there at some point, I just can't say when.

zhaoyang-star commented 9 months ago

@tridao The data for FA at batch size 16 has been updated.

I know @beginlner is the co-author of the paged KV cache feature. Are you interested in working on this?