This PR adds a flash attention (FA) implementation optimized for the Zen4 architecture as part of the quest to improve CPU inference for long contexts (#25, #26).
Limitations
It is Zen4-only for now. Strictly speaking, the implementation requires only a much smaller subset of the AVX512 specification than what Zen4 provides (just AVX512F and AVX512DQ), but I didn't want to have too many variants, so I decided to enable it for Zen4 only (a compile-time check for this subset is sketched after this list).
It is not implemented for ALiBi or unmasked attention. It is trivial to add these but I didn't want to clutter the implementation with branches that are mostly irrelevant.
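As a rough illustration only (not the actual build gating in this PR, and the macro name is mine), the required AVX512 subset can be detected at compile time via the standard predefined macros:

```c++
// Hypothetical sketch: the kernel only needs AVX512F + AVX512DQ, but this PR
// enables it only when building for Zen4, which provides both (and more).
// FA_HAVE_REQUIRED_AVX512 is an illustrative name, not a macro from the PR.
#if defined(__AVX512F__) && defined(__AVX512DQ__)
#define FA_HAVE_REQUIRED_AVX512 1
#else
#define FA_HAVE_REQUIRED_AVX512 0
#endif
```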
Performance comparisons
The following graph compares the prompt processing (PP) performance of mainline llama.cpp (build: a47667cf - 3650) without (green symbols) and with (blue symbols) FA to PP performance in this repository for Q4_K_S-quantized LLaMA-3.1-8B running on a Ryzen-7950X CPU where
Black symbols are without FA
Brown symbols are with FA inherited from llama.cpp
Magenta symbols are with the new FA implementation in this PR
We observe that the original FA implementation results in a significant performance degradation in mainline llama.cpp and also here, with the effect being much stronger for the version here. This is due to the K*Q and V*softmax(K*Q) matrix multiplications being much faster in this repository thanks to iqk_mul_mat, so the performance hit is larger when they are replaced with the original llama.cpp FA CPU kernel. The new FA implementation improves performance. The improvement increases with context length, reaching about 24% at 32k tokens.
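To make the fused computation concrete, here is a minimal scalar sketch of the streaming ("online softmax") formulation that a fused FA kernel is based on, i.e. what replaces the separate K*Q and V*softmax(K*Q) matrix multiplications. The function name and layout assumptions (row-major float K/V, a single query row) are mine; the actual kernel in this PR is blocked and AVX512-vectorized, so this is only an illustration of the math.

```c++
#include <algorithm>
#include <cmath>
#include <vector>

// Scalar sketch of fused attention for one query row:
//   out = V^T * softmax(scale * K * q)
// computed in a single streaming pass over the keys/values, so the full
// K*Q row is never materialized and softmax needs no separate pass.
void attention_one_query(const float * q,   // [head_dim]
                         const float * K,   // [n_kv, head_dim]
                         const float * V,   // [n_kv, head_dim]
                         float * out,       // [head_dim]
                         int n_kv, int head_dim, float scale) {
    std::vector<double> acc(head_dim, 0.0); // running weighted sum of V rows
    double M = -INFINITY;                   // running maximum of the logits
    double S = 0.0;                         // running sum of exp(logit - M)

    for (int j = 0; j < n_kv; ++j) {
        // one element of K*Q: scale * dot(q, K[j])
        double logit = 0.0;
        for (int d = 0; d < head_dim; ++d) logit += q[d] * K[j*head_dim + d];
        logit *= scale;

        // online softmax update: rescale previous accumulators if a new max appears
        double M_new = std::max(M, logit);
        double corr  = std::exp(M - M_new); // exp(-inf) = 0 on the first step
        double p     = std::exp(logit - M_new);
        S = S*corr + p;
        for (int d = 0; d < head_dim; ++d)
            acc[d] = acc[d]*corr + p * V[j*head_dim + d];
        M = M_new;
    }
    for (int d = 0; d < head_dim; ++d) out[d] = (float)(acc[d] / S);
}
```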
The next graph shows results for Q4_K_S-quantized Gemma-2-2b. Symbol colors are the same as above.
In this case the original FA kernel improves performance in mainline llama.cpp. The difference in behavior compared to LLaMA-3.1-8B is easily explained by the fact that the Gemma-2 series of models uses "soft-capping" in its attention layers, where softcap(x) = c * tanh(x/c) (c is a model-defined constant). This is implemented as 3 different operations in llama.cpp. When FA is enabled, these 3 operations, along with softmax, are fused into a single kernel, and this results in an improvement of mainline llama.cpp performance even for short contexts. But when the original FA kernel is used in our version, where soft-capping is already handled by a dedicated fused operation, we get a massive drop in performance, just like in the LLaMA-3.1-8B case above. The new implementation in this PR is much better and performance improves again, reaching 11% at 8k tokens, which is the maximum training context length of Gemma-2-2b.
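For illustration, here is what the soft-capping step looks like when applied inline; the helper name and the example cap value are mine, not taken from the PR.

```c++
#include <cmath>

// Gemma-2 style soft-capping of attention logits before softmax.
// In llama.cpp's compute graph this is 3 separate ops (scale, tanh, scale);
// a fused FA kernel can apply it while computing each logit.
// The cap value is illustrative (Gemma-2 defines its own constant).
inline float softcap(float x, float c) {
    return c * std::tanh(x / c);
}

// In the streaming sketch above, one would simply compute
//   logit = softcap(scale * dot(q, K[j]), c);
// before the online-softmax update, so no extra passes over K*Q are needed.
```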