Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

perf: matmul accumulation does not need to default to float #543

Open jon-chuang opened 1 year ago

jon-chuang commented 1 year ago

As we can see, it is hardcoded to float here: https://github.com/Dao-AILab/flash-attention/blob/d0032700d1c7a1353a3a8f2fadfbc73b2dc6b5dc/csrc/flash_attn/src/kernel_traits.h#L26

However, it is known that fp16 accumulation can be faster for matmuls (see e.g. the PyTorch benchmarks).

tridao commented 1 year ago

Accumulation is done in fp32 for numerical precision and numerical stability. Fp16 might get inf if e.g. Q @ K^T has entries larger than 65k.
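A minimal numerical sketch of this failure mode (illustrative shapes and values, with the accumulator simulated in eager PyTorch rather than the actual tensor-core MMA):

```python
import torch

d = 128
q = torch.full((d,), 8.0)
k = torch.full((d,), 8.0)

def dot_with_acc(q, k, acc_dtype):
    # inputs are cast to fp16, but the running sum is held in `acc_dtype`
    acc = torch.zeros((), dtype=acc_dtype)
    for qi, ki in zip(q, k):
        acc = acc + (qi.half() * ki.half()).to(acc_dtype)
    return acc

print(dot_with_acc(q, k, torch.float32))   # 8192.0 -- exact
print(dot_with_acc(q, k, torch.float16))   # 8192.0 -- still fine, below 65504

# With larger entries the running sum passes fp16's max (~65504) and overflows to inf:
q2 = torch.full((d,), 32.0)
print(dot_with_acc(q2, q2, torch.float16)) # inf: 128 * 1024 = 131072 > 65504
print(dot_with_acc(q2, q2, torch.float32)) # 131072.0
```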

jon-chuang commented 1 year ago

Do you think it may be feasible to use fp16 accumulation for P @ V?

We know that:

  1. Delayed softmax reciprocal (FA2): exp(S - amax) <= 1.0
  2. In-loop reciprocal (FA1): P is row-normalized, i.e. sum_j P_ij = 1 for every row i (a quick check of the bound this gives is sketched below)
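
For case 2, a quick sanity check of the bound (hypothetical shapes and values, just illustrating the convex-combination argument, not the kernel): since each row of P is a probability distribution, O = P @ V is a row-wise convex combination of rows of V, so |O| <= max|V| and an fp16 accumulator cannot overflow as long as max|V| is comfortably below the fp16 max.

```python
import torch

torch.manual_seed(0)
seq_len, d = 512, 128
P = torch.softmax(torch.randn(seq_len, seq_len) * 10.0, dim=-1)  # rows sum to 1
V = torch.randn(seq_len, d) * 1000.0

O = P @ V
# Each O[i, j] = sum_k P[i, k] * V[k, j] with P[i, :] >= 0 summing to 1,
# so |O[i, j]| <= max_k |V[k, j]| <= max|V| -- well inside fp16 range here.
print(O.abs().max().item(), V.abs().max().item())
```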
jon-chuang commented 1 year ago

Notes:

Handling of the delayed softmax reciprocal case:

Enforce numerical stability for non-inf values of V, i.e. max(O_i) <= max(V_i) <= max(fp16).

If we choose a factor stability_factor = log(seq_len), then we know that prior to the reciprocal:

max(O_ij) = max(sum_{k <= j} exp(S - amax - stability_factor) @ V_ik) < 1/seq_len * (j + 1) * block_k * max(V_i0, ..., V_ij) <= max(V_i0, ..., V_ij) <= max(V)

since each entry of exp(S - amax - stability_factor) is <= 1/seq_len and (j + 1) * block_k <= seq_len.

So we simply need to exponentiate S - amax - stability_factor. Of course, we may get poor numerical accuracy by squashing towards 0 (see the sketch at the end of this comment).


Due to the uncertain numerical accuracy of the delayed-reciprocal case, it seems wise to apply fp16 accumulation only to the in-loop reciprocal case.

I believe that one should benchmark both cases (fp16 acc + in-loop, fp32 acc + delayed reciprocal).
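
A rough sketch of the log(seq_len) shift described above (the names S, amax, stability_factor follow the notation in this comment; shapes and magnitudes are made up, and the accumulation is simulated in eager PyTorch rather than in the kernel):

```python
import math
import torch

torch.manual_seed(0)
seq_len, d = 512, 64
S = torch.randn(seq_len, seq_len) * 20.0
V = (torch.randn(seq_len, d) * 1000.0).half()

amax = S.max(dim=-1, keepdim=True).values
stability_factor = math.log(seq_len)

# Shift by an extra log(seq_len): every entry of P_tilde is <= 1/seq_len.
P_tilde = torch.exp(S - amax - stability_factor).half()

# Accumulate P_tilde @ V in fp16. Partial sums obey
#   sum_k P_tilde[i, k] * |V[k, j]| <= seq_len * (1/seq_len) * max|V| = max|V|,
# so they stay finite as long as V itself is finite in fp16.
O_unnorm = torch.zeros(seq_len, d, dtype=torch.float16)
for k in range(seq_len):
    O_unnorm += P_tilde[:, k:k+1] * V[k:k+1, :]

row_sum = P_tilde.float().sum(dim=-1, keepdim=True)  # shifted by the same factor
O = O_unnorm.float() / row_sum                       # the shift cancels in the ratio

O_ref = torch.softmax(S, dim=-1) @ V.float()
print(O_unnorm.abs().max().item(), V.abs().max().item())  # bounded by max|V|, no inf
print((O - O_ref).abs().max().item())                     # cost of squashing toward 0
```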

tridao commented 1 year ago

Interesting idea! There are 2 related things here:

I'm not sure how much faster fp16 accumulate could be. Yeah we should benchmark. Maybe it's easier to do that with the Triton implementation?

jon-chuang commented 1 year ago

Benchmarks here.

Summary:

  1. FP16 helps a lot
  2. Delayed softmax reciprocal seems to produce no measured difference in this setting (I previously measured a 5% difference in the fp32-accumulation setting). Hypothesis: we are even more memory-bottlenecked than before. Possible next step: explore even more pipeline stages.

I have yet to explore numerical accuracy (numerical accuracy may already be lower for Triton/Pallas in the first place).
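
One way to quantify the accuracy gap without touching the Triton/Pallas kernels (purely a sketch; the accumulator dtype is simulated in eager PyTorch, and shapes are illustrative): compare fp16- vs fp32-accumulated P @ V against a float64 reference.

```python
import torch

torch.manual_seed(0)
seq_len, d = 512, 128
P = torch.softmax(torch.randn(seq_len, seq_len) * 10.0, dim=-1)
V = torch.randn(seq_len, d) * 100.0
ref = P.double() @ V.double()

def pv_with_acc(P, V, acc_dtype):
    # fp16 inputs (as the tensor cores would see them), reduction held in `acc_dtype`
    out = torch.zeros(P.shape[0], V.shape[1], dtype=acc_dtype)
    for k in range(P.shape[1]):
        out += (P[:, k:k+1].half() * V[k:k+1, :].half()).to(acc_dtype)
    return out.double()

for acc_dtype in (torch.float16, torch.float32):
    err = (pv_with_acc(P, V, acc_dtype) - ref).abs()
    print(acc_dtype, "max err:", err.max().item(), "mean err:", err.mean().item())
```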

tridao commented 1 year ago

For consumer cards (e.g., 4090) matmul with fp16 accum has twice the throughput of matmul with fp32 accum. Maybe that explains some of the speed difference you're seeing. However, for data center cards (e.g., A100) I think the throughput is the same. The benefit there is probably fewer registers needed to hold the accumulators.
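
Back-of-the-envelope register accounting for the accumulator fragments (assumed tile and thread counts purely for illustration; the real values live in kernel_traits.h and may differ):

```python
# Hypothetical per-threadblock output-accumulator tile and thread count.
tile_m, tile_n = 128, 128
threads = 128  # e.g. 4 warps

elems_per_thread = tile_m * tile_n // threads   # 128 accumulator elements per thread

regs_fp32 = elems_per_thread                    # one 32-bit register per fp32 element
regs_fp16 = elems_per_thread // 2               # two fp16 elements pack into one register

print(regs_fp32, regs_fp16)  # 128 vs 64 registers per thread, out of a 255-register budget
```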

jon-chuang commented 1 year ago

Yes, interestingly enough and rather sadly:

Ada Lovelace whitepaper, page 30: [screenshot omitted]

A100/H100 whitepaper, page 39: [screenshot omitted]

In instruction-level microbenchmarks (no data movement), real-world throughput is only about 5% higher with fp16 accumulation on A100.

It would be nice to benchmark the full kernel on A100/H100.

Given that I measured higher register pressure and register spilling in the flash_attn kernel compared with the Triton kernel, the flash_attn kernel may stand to gain more.

The fact that matmul on V100 with FP16 accumulation is competitive or faster (for larger K) is at least an encouraging sign.


RE e2e benchmarking on H100:

Do you think Lambda Labs' H100 on-demand cloud would be a good idea? Replicating my local install environment into a container to deploy quickly would be a pain, though.

jon-chuang commented 1 year ago

After some further thought, I think that for H100 it is worth investigating the use of more FP16 arithmetic.

That's because FP32 (non-tensor-core) arithmetic on H100 is remarkably expensive compared with FP16 Tensor Core throughput (a 16.7x gap).

H100 FP32 TFLOPs are actually lower than the RTX 4090's (60 vs. 82.6), in contrast with FP16 TFLOPs (120 vs. 82.6).

I don't think the Tensor Cores are the bottleneck for H100 perf.

While memory bandwidth might be the bottleneck, I believe that investigating the use of FP16 arithmetic - or otherwise further reducing non-matmul arithmetic - in the main loop will be extremely productive.


jon-chuang commented 1 year ago

An interesting datapoint from Adept.ai's Persimmon 8B (released Sept 7, 2023):

"We add layernorm to the Q and K embeddings before they enter the attention calculation."

This should help bound the magnitude of Q @ K^T in float16. So I believe that even fp16 accumulation for Q @ K^T can be numerically stable with the right model choices (rough sketch below).
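
A rough sketch of why QK-layernorm bounds the logits (assuming a unit layernorm gain and the usual 1/sqrt(head_dim) scaling; this is not Persimmon's actual code, just an illustration):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, d = 1024, 128
q = torch.randn(seq_len, d) * 50.0   # arbitrarily large pre-norm activations
k = torch.randn(seq_len, d) * 50.0

qn = F.layer_norm(q, (d,))           # per-row mean 0, variance 1  =>  ||row|| ~= sqrt(d)
kn = F.layer_norm(k, (d,))

S = (qn @ kn.T) / d**0.5
# Cauchy-Schwarz: |qn_i . kn_j| <= ||qn_i|| * ||kn_j|| ~= d, so |S| <= ~sqrt(d),
# far below the fp16 max of ~65504 -- the 65k+ entries tridao mentioned cannot occur.
print(S.abs().max().item(), d**0.5)
```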

Qubitium commented 12 months ago

Regarding the "A100/H100 whitepaper, page 39" screenshot:

There is a new v1.02 at https://www.hpctech.co.jp/catalog/gtc22-whitepaper-hopper_v1.02.pdf. Not sure if anything significant changed. But it is sad that the Nvidia whitepaper compares the A100 40GB (1.5 TB/s bandwidth) against the H100 80GB, instead of the apples-to-apples A100 80GB (2 TB/s bandwidth).