ikawrakow / ik_llama.cpp

llama.cpp fork with additional SOTA quants and improved performance
MIT License

Zen4 Flash Attention #32

Closed: ikawrakow closed this 2 months ago

ikawrakow commented 2 months ago

TL;DR

This PR adds a flash attention (FA) implementation optimized for the Zen4 architecture as part of the quest to improve CPU inference for long contexts (#25, #26).

Limitations

Performance comparisons

The following graph compares the prompt processing (PP) performance of mainline llama.cpp (build: a47667cf - 3650) without (green symbols) and with (blue symbols) FA to the PP performance of this repository, for Q4_K_S-quantized LLaMA-3.1-8B running on a Ryzen-7950X CPU.

[Figure fa: PP performance vs. context length for LLaMA-3.1-8B]

We observe that the original FA implementation results in a significant performance degradation, both in mainline llama.cpp and here, with the effect being much stronger for the version in this repository. This is because the K*Q and V*softmax(K*Q) matrix multiplications are much faster in this repository thanks to iqk_mul_mat, so the performance hit is larger when they are replaced with the original llama.cpp FA CPU kernel. The new FA implementation improves performance, and the improvement grows with context length, reaching about 24% at 32k tokens.
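To illustrate what the FA kernel replaces, here is a minimal scalar sketch of single-query flash attention with an online softmax. This is not the Zen4 kernel from this PR (which tiles over keys/values and uses AVX-512), and the function and parameter names are hypothetical; it only shows how the K*Q dot products, the softmax, and the accumulation against V are fused into one pass over the KV cache so the full attention matrix is never materialized.

```cpp
#include <cmath>
#include <vector>

// Single query vector q against n_kv cached keys/values.
// q: [head_dim], K and V: [n_kv][head_dim] row-major, out: [head_dim].
void flash_attn_row(const float * q, const float * K, const float * V,
                    float * out, int n_kv, int head_dim, float scale) {
    float M = -INFINITY;                    // running maximum of the logits
    float S = 0.0f;                         // running sum of exp(logit - M)
    std::vector<float> acc(head_dim, 0.0f); // running exp-weighted sum of V rows

    for (int j = 0; j < n_kv; ++j) {
        // K*Q: dot product with one key row
        float logit = 0.0f;
        for (int d = 0; d < head_dim; ++d) logit += K[j*head_dim + d]*q[d];
        logit *= scale;

        // Online softmax update: if a new maximum appears, rescale what has
        // been accumulated so far instead of recomputing it.
        const float M_new = logit > M ? logit : M;
        const float corr  = std::exp(M - M_new);
        const float p     = std::exp(logit - M_new);
        S = S*corr + p;
        for (int d = 0; d < head_dim; ++d) acc[d] = acc[d]*corr + p*V[j*head_dim + d];
        M = M_new;
    }
    // out = V*softmax(K*Q) for this query row
    for (int d = 0; d < head_dim; ++d) out[d] = acc[d]/S;
}
```

When the separate K*Q and V*softmax(K*Q) GEMMs are already very fast, as they are here with iqk_mul_mat, a fused kernel of this shape only pays off if its inner loops are competitive, which is what the Zen4-optimized implementation in this PR addresses.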

The next graph shows results for Q4_K_S-quantized Gemma-2-2b. Symbol colors are the same as above.

[Figure fa_gemma2b: PP performance vs. context length for Gemma-2-2b]

In this case the original FA kernel improves performance in mainline llama.cpp. The difference in behavior compared to LLaMA-3.1-8B is easily explained by the fact that the Gemma-2 series of models uses "soft-capping" in their attention layers, where softcap(x) = c * tanh(x/c) (c is a model-defined constant). This is implemented as 3 separate operations in llama.cpp. When FA is enabled, these 3 operations, along with softmax, are fused into a single kernel, which improves mainline llama.cpp performance even for short contexts. But when the original FA kernel is used in our version, where "soft-capping" is already handled by a dedicated fused operation, we get a massive drop in performance, just like in the LLaMA-3.1-8B case above. The new implementation in this PR is much better and performance improves again, reaching 11% at 8k tokens, which is the maximum training context length of Gemma-2-2b.
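As an illustration of the kind of fusion involved, here is a minimal sketch of soft-capping applied in the same pass as the softmax over one row of attention logits, instead of running the scale, tanh, and scale steps as separate tensor operations. This is not the actual fused kernel from llama.cpp or this repository; the function name and scalar layout are assumptions for illustration only.

```cpp
#include <cmath>

// One row of attention logits: apply softcap(x) = c*tanh(x/c), then softmax,
// in a single fused routine instead of three separate tensor operations.
void softcap_softmax_row(float * logits, int n, float c) {
    float max_val = -INFINITY;
    for (int i = 0; i < n; ++i) {
        logits[i] = c*std::tanh(logits[i]/c);  // soft-capping
        if (logits[i] > max_val) max_val = logits[i];
    }
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) {
        logits[i] = std::exp(logits[i] - max_val);
        sum += logits[i];
    }
    const float inv_sum = 1.0f/sum;
    for (int i = 0; i < n; ++i) logits[i] *= inv_sum;
}
```

Capping the logits while they are already being traversed for the softmax avoids writing and re-reading the full logit tensor for each of the three soft-capping steps, which is where the gain for Gemma-2 comes from.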