This PR adds Flash Attention for ARM_NEON. The Zen4/AVX2 implementation is reused with a few platform-specific additions for ARM_NEON. As with AVX2, it is just for the fp16 kv-cache for now.
On ARM_NEON, fp16 arithmetic is used to compute K*Q (unlike Zen4/AVX2, which use fp32). Initially I was also using fp16 to operate on the K*Q product (the soft_max related stuff), and that worked fine for the models I was using for testing (Gemma2-2b, TriLM-4B). But fp16 fails for LLaMA-3.1-8B, so I had to change to fp32¹.
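To make the split concrete, here is a minimal sketch (not the actual PR code) of what this looks like with NEON intrinsics: the K*Q dot product stays entirely in fp16, and the result is promoted to fp32 before the soft_max-related math. The helper names (`kq_dot_f16`, `softmax_row_f32`), the assumption that the head size `D` is a multiple of 8, and the compile flag are all illustrative; it assumes ARMv8.2-A FP16 support (e.g. `-march=armv8.2-a+fp16`).

```cpp
#include <arm_neon.h>
#include <math.h>

// Hypothetical sketch: dot product of one K row with Q, done in fp16.
// D is assumed to be a multiple of 8.
static inline float kq_dot_f16(const float16_t * k, const float16_t * q, int D) {
    float16x8_t acc = vdupq_n_f16(0);
    for (int i = 0; i < D; i += 8) {
        // fused multiply-add entirely in fp16 arithmetic
        acc = vfmaq_f16(acc, vld1q_f16(k + i), vld1q_f16(q + i));
    }
    // promote to fp32 only for the horizontal reduction
    float32x4_t lo = vcvt_f32_f16(vget_low_f16 (acc));
    float32x4_t hi = vcvt_f32_f16(vget_high_f16(acc));
    return vaddvq_f32(vaddq_f32(lo, hi));
}

// Hypothetical sketch: the soft_max row update kept in fp32, which is
// roughly what the fp16 -> fp32 switch described above amounts to.
// 'max' is assumed to be the running maximum of the scaled scores.
static inline float softmax_row_f32(float * s, int n, float scale, float max) {
    float sum = 0.0f;
    for (int j = 0; j < n; ++j) {
        s[j] = expf(scale*s[j] - max);   // expf is always evaluated in fp32
        sum += s[j];
    }
    return sum;
}
```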
Performance gains are not as good as on Zen4/AVX2. My guess is that, due to the significantly higher memory bandwidth of the M2 Max used to test the ARM_NEON implementation (compared to the Zen4/AVX2 systems I have available), the penalty for not having intermediate results in the cache when computing KQV is smaller. Nevertheless, for LLaMA-3.1-8B at a context of 2k tokens, using FA is about 4% faster than not using FA on the M2 Max. In contrast, the mainline llama.cpp FA implementation is ~17% slower than no-FA.
¹ I must admit I don't really understand why, because expf (and tanh when soft-capping is involved) are computed in fp32 even when K*Q is fp16. So possibly it was a bug in the fp32 <-> fp16 conversions that I was not able to find, rather than a loss of precision.