jon-chuang opened 1 year ago
Accumulation is done in fp32 for numerical precision and numerical stability. Fp16 might get inf if e.g. Q @ K^T has entries larger than 65k.
Do you think it may be feasible to use fp16 accumulation for P @ V?
We know that:
- exp(S_k - amax) <= 1.0 for every logit S_k
- P is normalized to 1; in particular, sum(P_j, axis=1) = 1 for all j
Notes:
- Enforce numerical stability for non-inf values of V: max(O_i) <= max(V_i) <= max(fp16).
- If we choose a factor stability_factor = log(seq_len), then we know that prior to the reciprocal:
  max(O_ij) = max(sum_{k<=j} exp(S - amax - stability_factor) @ V_ik) < 1/seq_len * (j + 1) * block_k * max(V_i0, ..., V_ij) <= max(V_i0, ..., V_ij) <= max(V)
- So we simply need to exponentiate S - amax - stability_factor. Of course, we may get poor numerical accuracy by squashing towards 0 (a small numeric sketch of this bound follows below).
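As a sanity check on the bound, here is a minimal numeric sketch (PyTorch, with illustrative sizes and a made-up block_k, not the kernel's): it shifts the logits by amax + log(seq_len), then accumulates P @ V block by block in a simulated fp16 accumulator and confirms the partial sums never exceed max(V).

```python
import math
import torch

torch.manual_seed(0)
seq_len, d, block_k = 4096, 64, 128          # illustrative sizes, not the kernel's
S = torch.randn(seq_len) * 20                # one query row of Q @ K^T (logits)
V = torch.randn(seq_len, d) * 1000           # large-ish values, still representable in fp16

amax = S.max()
stability_factor = math.log(seq_len)
P = torch.exp(S - amax - stability_factor)   # every entry <= 1/seq_len

# Simulated fp16 accumulation of P @ V: each block product is cast to fp16 and
# added into an fp16 accumulator, mimicking fp16 accumulator registers.
acc = torch.zeros(d, dtype=torch.float16)
for k in range(0, seq_len, block_k):
    acc = acc + (P[k:k+block_k] @ V[k:k+block_k]).half()

assert not torch.isinf(acc).any()            # bounded by max(V) << max(fp16)
print(acc.abs().max().item(), "<=", V.abs().max().item())
```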
Due to uncertain numerical accuracy for delayed reciprocal, it seems wise to apply the fp16 accumulation only to the in-loop reciprocal case.
I believe that one should benchmark both cases (fp16 acc + in-loop reciprocal, fp32 acc + delayed reciprocal).
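For concreteness, a rough sketch of the two variants (single query row, hypothetical block sizes, and the global max used instead of a running max for brevity): (a) fp16 accumulation with the reciprocal applied in-loop, so the output accumulator stays normalized, and (b) fp32 accumulation with the reciprocal delayed to the end.

```python
import torch

torch.manual_seed(0)
seq_len, d, block_k = 4096, 64, 128
S = torch.randn(seq_len) * 5
V = torch.randn(seq_len, d)
amax = S.max()                               # global max for simplicity

# (a) fp16 accumulator, renormalized to the running denominator every block
o_a = torch.zeros(d, dtype=torch.float16)
l = 0.0
for k in range(0, seq_len, block_k):
    p = torch.exp(S[k:k+block_k] - amax)
    l_new = l + p.sum().item()
    o_a = o_a * (l / l_new) + ((p @ V[k:k+block_k]) / l_new).half()
    l = l_new

# (b) fp32 accumulator, single (delayed) reciprocal at the end
o_b = torch.zeros(d)
l = 0.0
for k in range(0, seq_len, block_k):
    p = torch.exp(S[k:k+block_k] - amax)
    o_b = o_b + p @ V[k:k+block_k]
    l += p.sum().item()
o_b = o_b / l

print("max |a - b|:", (o_a.float() - o_b).abs().max().item())
```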
Interesting idea! There are 2 related things here:
I'm not sure how much faster fp16 accumulate could be. Yeah we should benchmark. Maybe it's easier to do that with the Triton implementation?
Benchmarks here.
Summary:
- Have yet to explore numerical accuracy (I think in the first place numerical accuracy may be lower for Triton/Pallas).
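The linked benchmarks only cover speed; as a sketch of what the accuracy check could look like (hypothetical sizes, simulated fp16 accumulation measured against an fp64 reference):

```python
import torch

torch.manual_seed(0)
seq_len, d, block_k = 4096, 64, 128
P = torch.softmax(torch.randn(seq_len, dtype=torch.float64) * 5, dim=0)
V = torch.randn(seq_len, d, dtype=torch.float64)

ref = P @ V                                   # fp64 reference

acc = torch.zeros(d, dtype=torch.float16)     # simulated fp16 accumulation
for k in range(0, seq_len, block_k):
    acc = acc + (P[k:k+block_k] @ V[k:k+block_k]).half()

rel_err = (acc.double() - ref).abs().max() / ref.abs().max()
print(f"max error relative to max |ref|: {rel_err.item():.3e}")
```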
For consumer cards (e.g., 4090) matmul with fp16 accum has twice the throughput of matmul with fp32 accum. Maybe that explains some of the speed difference you're seeing. However, for data center cards (e.g., A100) I think the throughput is the same. The benefit there is probably fewer registers needed to hold the accumulators.
Yes, interestingly enough and rather sadly (see the Ada Lovelace whitepaper, page 30, and the A100/H100 whitepaper, page 39).
In terms of microbenchmarks at the instruction level (no data movement), the real-world throughput is only 5% higher for fp16 accumulation on A100.
It would be nice to benchmark the full kernel on A100/H100.
Given that I measured high register pressure and register spilling in the flash_attn kernel compared with the Triton kernel, the flash_attn kernel may stand to gain more.
The fact that matmul on V100 with FP16 accumulation is competitive or faster (for larger K) is at least an encouraging sign.
Do you think that Lambda Labs' H100 on-demand cloud would be a good idea? Replicating my local install environment into a container for quick deployment would be a pain, though.
After some further thought, I think for H100, it is worth investigating using more FP16 arithmetic.
That's because FP32 arithmetic on H100 is unbelievably expensive compared with FP16 Tensor Core TFLOPs (16.7x).
H100's FP32 TFLOPS are lower than the RTX 4090's (60 vs. 82.6), in contrast with FP16 TFLOPS (120 vs. 82.6).
I don't think tensor core is the bottleneck for H100 perf.
While memory bandwidth might be the bottleneck, I believe that investigating the use of FP16 arithmetic (or otherwise further reducing non-matmul arithmetic) in the main loop will be extremely productive.
An interesting datapoint from Adept.ai Persimmon 8B (released Sept 7 2023):
> We add layernorm to the Q and K embeddings before they enter the attention calculation.
This should help bound the Q @ K^T float16 magnitude. So I believe that even stable fp16 accumulation for Q @ K^T can be feasible with the right model choice.
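A small sketch of the QK-layernorm idea (hypothetical shapes; not Persimmon's actual code): per-head layernorm on Q and K bounds each entry of Q @ K^T by roughly head_dim (times any learned layernorm gains), regardless of how large the raw embeddings are, which keeps the logits comfortably inside fp16 range.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
heads, seq_len, head_dim = 8, 1024, 64
q = torch.randn(heads, seq_len, head_dim) * 50   # deliberately large activations
k = torch.randn(heads, seq_len, head_dim) * 50

def qk_logits(q, k, qk_layernorm):
    if qk_layernorm:
        # normalize each head_dim-sized vector; |q_i . k_j| <= head_dim afterwards
        q = F.layer_norm(q, (head_dim,))
        k = F.layer_norm(k, (head_dim,))
    return (q @ k.transpose(-2, -1)) / head_dim ** 0.5

for flag in (False, True):
    s = qk_logits(q, k, flag)
    print(f"qk_layernorm={flag}: max |logit| = {s.abs().max().item():.1f}")
```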
> A100/H100 whitepaper, page 39:
There is a new 1.02 version at https://www.hpctech.co.jp/catalog/gtc22-whitepaper-hopper_v1.02.pdf. Not sure if anything significant changed. But it is sad that the Nvidia whitepaper compares A100 40GB (1.5 TB/s bandwidth) against H100 80GB instead of the apples-to-apples A100 80GB (2 TB/s bandwidth).
As we can see, the accumulator type is hardcoded to float here: https://github.com/Dao-AILab/flash-attention/blob/d0032700d1c7a1353a3a8f2fadfbc73b2dc6b5dc/csrc/flash_attn/src/kernel_traits.h#L26
However, it is known that fp16 accumulation can be faster for matmuls (see e.g. PyTorch Benchmarks).
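For a quick first-order check on the speed side, a rough timing sketch (CUDA device assumed). Note that PyTorch's `torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction` flag only permits reduced-precision reductions (e.g. in split-K) rather than forcing a true fp16 accumulator, so this is a loose proxy for the comparison above, not a reproduction of the linked benchmarks.

```python
import torch

def time_matmul(n=8192, iters=20):
    a = torch.randn(n, n, device="cuda", dtype=torch.float16)
    b = torch.randn(n, n, device="cuda", dtype=torch.float16)
    for _ in range(3):                        # warm-up
        a @ b
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        a @ b
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters    # milliseconds per matmul

for flag in (False, True):
    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = flag
    print(f"allow_fp16_reduced_precision_reduction={flag}: {time_matmul():.2f} ms")
```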