apple / axlearn


Optimize MQA computation. #837

Closed: ds-hwang closed this 1 week ago

ds-hwang commented 1 week ago

The advantage of multi-query attention (MQA) lies in both reducing the size of the KV cache and making self-attention computation more efficient. The current implementation only saves on KV cache size.

This PR improves on that by reducing the computation cost as well as the per-layer KV cache memory.
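Concretely, the compute saving comes from letting each group of query heads read its shared KV head directly, instead of first repeating K and V up to the full number of query heads. The following is a minimal JAX sketch of that idea; gqa_attention, its shape layout, and the head-grouping convention are illustrative assumptions, not the axlearn code, and masking is omitted:

    import jax
    import jax.numpy as jnp

    def gqa_attention(q, k, v):
        # q:    [batch, q_len,  num_q_heads,  dim_per_head]
        # k, v: [batch, kv_len, num_kv_heads, dim_per_head]
        # num_q_heads must be an integer multiple of num_kv_heads.
        b, t, nq, d = q.shape
        _, s, nk, _ = k.shape
        g = nq // nk  # query heads per shared KV head (assumed consecutive)
        # Fold the query heads into (kv_head, group) so each group contracts
        # against its shared K/V head directly; K and V are never tiled up
        # to num_q_heads.
        q = q.reshape(b, t, nk, g, d)
        logits = jnp.einsum("btkgd,bskd->bktgs", q, k) / jnp.sqrt(d)
        probs = jax.nn.softmax(logits, axis=-1)
        out = jnp.einsum("bktgs,bskd->btkgd", probs, v)
        return out.reshape(b, t, nq, d)

    # Example: 32 query heads sharing 4 KV heads, dim_per_head = 128.
    q = jnp.ones((1, 8, 32, 128))
    k = jnp.ones((1, 8, 4, 128))
    v = jnp.ones((1, 8, 4, 128))
    print(gqa_attention(q, k, v).shape)  # (1, 8, 32, 128)

Because both einsums contract against the 4-head K/V tensors as stored, no [batch, kv_len, 32, 128] intermediates are ever materialized.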

This becomes especially critical when dealing with very long contexts. For instance, if an LLM is processing a context length of 1 million tokens using the Character.ai architecture [1], there might be around 4 unique KV cache layers. Let’s assume there are 4 KV heads and 32 total attention heads, with a dim_per_head of 128. In the current implementation, each layer consumes significant memory for self-attention KV caching (using bfloat16):
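A back-of-the-envelope estimate from these parameters (mine, not a figure from the PR): caching K and V for only the 4 KV heads costs

    2 (K and V) x 1,000,000 tokens x 4 heads x 128 dim_per_head x 2 bytes (bf16) ~ 2.0 GB per layer,

whereas a cache expanded to all 32 query heads would hold about 16.4 GB per layer, a factor of num_heads / num_kv_heads = 8.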

[1] https://research.character.ai/optimizing-inference/

AS-IS (before this PR)

-----------------------------------------------------------------------------------------
Benchmark                           Time             CPU   Iterations   HBM (over 95.74G)
-----------------------------------------------------------------------------------------
MQABenchmark/2048/16/2/1024       1.42 ms        0.247 ms         2347           291.16M
MQABenchmark/4096/16/2/1024       3.60 ms        0.277 ms         1257           322.95M
MQABenchmark/4096/16/2/4096       47.3 ms        0.818 ms          139             4.25G
MQABenchmark/4096/16/2/8192        869 ms        0.932 ms          140            48.00G

This PR

-----------------------------------------------------------------------------------------
Benchmark                           Time             CPU   Iterations   HBM (over 95.74G)
-----------------------------------------------------------------------------------------
MQABenchmark/2048/16/2/1024       1.16 ms        0.256 ms         2535           262.35M
MQABenchmark/4096/16/2/1024       3.46 ms        0.294 ms         1114           266.88M
MQABenchmark/4096/16/2/4096       24.8 ms        0.769 ms          137             4.04G
MQABenchmark/4096/16/2/8192        860 ms         1.19 ms          136            48.00G
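Reading the two tables together: wall-clock time drops in every configuration, most visibly in the 4096/16/2/4096 case (47.3 ms to 24.8 ms, roughly 1.9x), and HBM usage is lower or equal throughout (e.g. 4.25G to 4.04G in that same case, unchanged at 48.00G for 8192).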