The advantage of multi-query attention (MQA) lies in both reducing the size of the KV cache and making self-attention computation more efficient, yet the current implementation only realizes the KV cache size saving.
This PR improves on that by reducing the computation cost while further shrinking the per-layer KV cache memory.
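For illustration, here is a minimal JAX sketch of the idea (not the actual implementation in this PR; all names are illustrative): queries are grouped per KV head and the einsums broadcast, so keys and values stay at num_kv_heads in both the cache and the attention matmuls instead of being repeated to the full head count.

```python
# Minimal grouped/multi-query attention sketch (illustrative only, not this repo's code).
# K/V are kept at num_kv_heads; queries are split into [num_kv_heads, group] so the
# einsums broadcast K/V across each group without materializing repeated copies.
import jax
import jax.numpy as jnp


def grouped_attention(q, k, v):
    """q: [batch, tgt, num_heads, dim]; k, v: [batch, src, num_kv_heads, dim]."""
    batch, tgt, num_heads, dim = q.shape
    num_kv_heads = k.shape[2]
    group = num_heads // num_kv_heads
    # Split query heads into (kv_head, group) so K/V broadcast instead of being tiled.
    q = q.reshape(batch, tgt, num_kv_heads, group, dim)
    logits = jnp.einsum("btkgd,bskd->bkgts", q, k) / jnp.sqrt(dim)
    probs = jax.nn.softmax(logits, axis=-1)
    out = jnp.einsum("bkgts,bskd->btkgd", probs, v)
    return out.reshape(batch, tgt, num_heads, dim)


# Example: 32 query heads share 4 KV heads; K/V stay at 4 heads throughout.
q = jnp.ones((1, 16, 32, 128), jnp.bfloat16)
k = jnp.ones((1, 16, 4, 128), jnp.bfloat16)
v = jnp.ones((1, 16, 4, 128), jnp.bfloat16)
print(grouped_attention(q, k, v).shape)  # (1, 16, 32, 128)
```

The key point is that k and v are never tiled to the full 32 heads, so both the attention einsums and the cached tensors scale with num_kv_heads rather than num_heads.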
The memory saving becomes especially critical for very long contexts. For instance, if an LLM processes a context length of 1 million tokens with the Character.ai architecture [1], there might be around 4 unique KV cache layers. Assume 4 KV heads, 32 total attention heads, and a dim_per_head of 128. In the current implementation, each layer then consumes on the order of gigabytes for self-attention KV caching in bfloat16; a rough estimate follows.
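As a back-of-the-envelope estimate for a single 1M-token sequence (which head count is cached before vs. after this change is an assumption here, not taken from the implementation):

```python
# Per-layer KV cache sizing for the example above (keys + values, bfloat16, one sequence).
CONTEXT_LEN = 1_000_000   # 1M tokens
DIM_PER_HEAD = 128
BYTES_PER_ELEM = 2        # bfloat16
NUM_HEADS = 32            # total attention (query) heads
NUM_KV_HEADS = 4


def kv_cache_bytes(num_cached_heads: int) -> int:
    # Factor of 2 for keys and values.
    return 2 * CONTEXT_LEN * num_cached_heads * DIM_PER_HEAD * BYTES_PER_ELEM


print(f"per layer, caching all {NUM_HEADS} heads: {kv_cache_bytes(NUM_HEADS) / 1e9:.1f} GB")
# per layer, caching all 32 heads: 16.4 GB
print(f"per layer, caching {NUM_KV_HEADS} KV heads: {kv_cache_bytes(NUM_KV_HEADS) / 1e9:.1f} GB")
# per layer, caching 4 KV heads: 2.0 GB
```

With roughly 4 unique KV cache layers in the Character.ai-style architecture, the totals for one 1M-token sequence are about 4x these per-layer numbers.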
Benchmark results (tools/attention_benchmark.py on TPUv5p) show that the change saves both memory and computation.
As-is (before this PR)
-----------------------------------------------------------------------------------------
Benchmark Time CPU Iterations HBM (over 95.74G)
-----------------------------------------------------------------------------------------
MQABenchmark/2048/16/2/1024 1.42 ms 0.247 ms 2347 291.16M
MQABenchmark/4096/16/2/1024 3.60 ms 0.277 ms 1257 322.95M
MQABenchmark/4096/16/2/4096 47.3 ms 0.818 ms 139 4.25G
MQABenchmark/4096/16/2/8192 869 ms 0.932 ms 140 48.00G
With this PR
-----------------------------------------------------------------------------------------
Benchmark Time CPU Iterations HBM (over 95.74G)
-----------------------------------------------------------------------------------------
MQABenchmark/2048/16/2/1024 1.16 ms 0.256 ms 2535 262.35M
MQABenchmark/4096/16/2/1024 3.46 ms 0.294 ms 1114 266.88M
MQABenchmark/4096/16/2/4096 24.8 ms 0.769 ms 137 4.04G
MQABenchmark/4096/16/2/8192 860 ms 1.19 ms 136 48.00G
[1] https://research.character.ai/optimizing-inference/