Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

Inference benchmarks of Flash Decoding (+confounding package changes) #866

Open ad8e opened 5 months ago

ad8e commented 5 months ago

We tried to compare Flash Decoding (FD) against non-FD, but got lazy and just did a large container switch from FA 2.0.2 to 2.4.2, so newer Python and PyTorch versions are also in the mix. The speedup from FD plus these miscellaneous improvements was not big.

Old container (non-FD, FA 2.0.2). Prefill throughput includes tokenization time; all numbers were measured after warmup and repeated twice:

prompt length 2
tokens generated: 190
generated tokens/s: 22.767
prefill tokens/s: 45.565

prompt length 2060
tokens generated: 190
generated tokens/s: 20.726
prefill tokens/s: 3873.045

prompt length 8000
tokens generated: 190
generated tokens/s: 16.661
prefill tokens/s: 3814.175

prompt length 16000
tokens generated: 190
generated tokens/s: 13.244
prefill tokens/s: 3383.049
generated tokens/s: 13.152
prefill tokens/s: 3384.187

New container (FD, FA 2.4.2). Prefill throughput includes tokenization time; all numbers were measured after warmup and repeated twice:

prompt length 2
tokens generated: 190
generated tokens/s: 22.961
prefill tokens/s: 45.574

prompt length 8000
tokens generated: 190
generated tokens/s: 18.654
prefill tokens/s: 3745.218

prompt length 16000
tokens generated: 190
generated tokens/s: 15.766
prefill tokens/s: 3306.542
generated tokens/s: 15.781
prefill tokens/s: 3313.040
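To put "not big" in numbers, here is a quick sketch that divides the new generation throughput by the old one at each prompt length run in both containers (2060 was only run in the old one). It is plain arithmetic on the figures above, and it still measures FD together with the other package changes. The names `old`, `new`, and `speedups` are just illustrative.

```python
# Generation throughput (tokens/s) copied from the runs above; the two 16k runs are averaged.
old = {2: 22.767, 8000: 16.661, 16000: (13.244 + 13.152) / 2}
new = {2: 22.961, 8000: 18.654, 16000: (15.766 + 15.781) / 2}


def speedups(before: dict, after: dict) -> dict:
    """Return after/before throughput ratio for each prompt length present in both runs."""
    return {n: after[n] / before[n] for n in sorted(before.keys() & after.keys())}


for n, s in speedups(old, new).items():
    print(f"prompt {n:>5}: {s:.3f}x")
```

The gain grows with prompt length (about 1.01x at 2 tokens, 1.12x at 8000, 1.20x at 16000), which is consistent with the change mostly affecting attention over the KV cache.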

The attention call is identical in both containers: attn_output = flash_attn_kvpacked_func(...)

Old container: http://ghcr.io/coreweave/ml-containers/torch-extras:es-flash-attn-2-5fbc6bb-base-cuda12.1.1-torch2.0.1-vision0.15.2-audio2.0.2-flash_attn2.0.2
New container: http://ghcr.io/coreweave/ml-containers/torch-extras:73a87a6-nccl-cuda12.2.2-ubuntu22.04-nccl2.19.3-1-torch2.2.0-vision0.17.0-audio2.2.0-flash_attn2.4.2
Hardware/model: A40, 13B model.
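The benchmark script itself is not shown in the issue. For anyone reproducing the "warmup, then repeat" protocol, here is a minimal sketch that assumes decode throughput is timed end to end around the generation loop. `generate` is a hypothetical callable standing in for the model's decode loop. It is not part of flash-attn, and the sketch does not model the tokenization that the prefill numbers include.

```python
import time
from typing import Callable, List


def tokens_per_second(generate: Callable[[int], None], n_tokens: int,
                      warmup: int = 1, repeats: int = 2) -> List[float]:
    """Time `generate(n_tokens)` after `warmup` untimed calls; return tokens/s per repeat.

    `generate` is a hypothetical stand-in for a decode loop. On a GPU it must block
    until the work is finished (e.g. call torch.cuda.synchronize() before returning),
    otherwise asynchronous kernel launches make the measured time too short.
    """
    for _ in range(warmup):
        generate(n_tokens)  # untimed: lets kernel selection, allocators and caches settle
    rates = []
    for _ in range(repeats):
        start = time.perf_counter()
        generate(n_tokens)
        rates.append(n_tokens / (time.perf_counter() - start))
    return rates


if __name__ == "__main__":
    # Dummy workload standing in for 190 decode steps of ~1 ms each.
    print(tokens_per_second(lambda n: time.sleep(0.001 * n), n_tokens=190))
```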

tridao commented 5 months ago

Cool, thanks for the benchmarking. Flash Decoding speeds up the attn part during decoding, but attn may or may not take a large fraction of the time (depending on model size and seqlen).
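One way to see why the attention share depends on model size and sequence length: at batch size 1, a decode step is typically memory-bound. Each new token reads every weight once and the whole KV cache once. The back-of-envelope sketch below works under that assumption. It uses a hypothetical LLaMA-13B-like shape (40 layers, 40 heads of dim 128, fp16), since the issue does not state the architecture, and ignores activations and kernel efficiency.

```python
def decode_bytes_per_token(n_params: float, n_layers: int, n_kv_heads: int,
                           head_dim: int, ctx_len: int, bytes_per_elem: int = 2):
    """Rough bytes read per decode step: all weights once, plus the whole KV cache once."""
    weight_bytes = n_params * bytes_per_elem
    # K and V, for every layer, every KV head, and every cached position.
    kv_bytes = 2 * n_layers * n_kv_heads * head_dim * ctx_len * bytes_per_elem
    return weight_bytes, kv_bytes


# Assumed LLaMA-13B-like configuration (hypothetical; the issue only says "13B model").
for ctx in (2, 8000, 16000):
    w, kv = decode_bytes_per_token(13e9, 40, 40, 128, ctx)
    print(f"ctx {ctx:>5}: KV cache is {kv / (w + kv):.1%} of bytes read per token")
```

Under these assumptions the KV cache grows linearly with context. Its share of traffic rises from negligible at 2 tokens to about 20% at 8000 and about a third at 16000, so a larger model (more weight bytes) or a shorter context shrinks the part Flash Decoding can help with.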

tridao commented 5 months ago

One way to get a sense of what fraction of time attn is taking is to comment out the line that does the attn during decoding (keep everything else the same, e.g., rotary) and see what generation speed you get. If attn only takes 20-30% of the time, then any optimization you do on attn will get you at most a roughly 1.25-1.4x speedup.
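The bound in that comment is Amdahl's law (the name is my framing; the thread does not use it). If attention is a fraction f of decode time and is made s times faster, the whole step speeds up by 1/((1 - f) + f/s). As s grows, this approaches 1/(1 - f). A small sketch:

```python
def max_speedup(attn_fraction: float, attn_speedup: float = float("inf")) -> float:
    """Amdahl's law: overall speedup when only the attention share of the time gets faster."""
    rest = 1.0 - attn_fraction
    # With the default attn_speedup=inf, the attention term vanishes (attention becomes free).
    return 1.0 / (rest + attn_fraction / attn_speedup)


for f in (0.2, 0.3):
    print(f"attn = {f:.0%}: at most {max_speedup(f):.2f}x even if attention became free")
```

Even a 2x faster attention kernel at f = 0.3 only yields about 1.18x overall, which is why the end-to-end gain can look small.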

ad8e commented 5 months ago

Doing some naive math and assuming attn is free at ctx_len=2: under non-FD at 16k context, attn was 22.767/13.2 - 1 = 72% as expensive as the rest of the model (13.2 being the average of the two 16k runs), and under FD at 16k it was 22.961/15.766 - 1 = 46% as expensive. So the new attn takes about 64% (46/72) of the time of the old attn.
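The same estimate as a short script, under the comment's assumption that attention costs nothing at prompt length 2. The variable and function names are illustrative. The script reports the ratio of the two relative costs and also an absolute-time version, which measures milliseconds of attention per generated token instead of a cost relative to the rest of the model.

```python
# Decode throughput (tokens/s) quoted in the comment above.
base_old, long_old = 22.767, 13.2    # non-FD: prompt 2, prompt 16k (average of two runs)
base_new, long_new = 22.961, 15.766  # FD: prompt 2, prompt 16k


def attn_over_rest(base_tps: float, long_tps: float) -> float:
    """Attention cost relative to the rest of the model, assuming attention is free at prompt 2."""
    return base_tps / long_tps - 1.0


def attn_ms_per_token(base_tps: float, long_tps: float) -> float:
    """Same assumption, expressed as absolute milliseconds of attention per generated token."""
    return 1000.0 * (1.0 / long_tps - 1.0 / base_tps)


ratio = attn_over_rest(base_new, long_new) / attn_over_rest(base_old, long_old)
abs_ratio = attn_ms_per_token(base_new, long_new) / attn_ms_per_token(base_old, long_old)
print(f"relative-cost ratio: {ratio:.2f}, absolute-time ratio: {abs_ratio:.2f}")
```

At full precision the relative-cost ratio comes out near 0.63 rather than the 0.64 obtained from the rounded 46% and 72%. The absolute-time version gives about 0.62 (roughly 31.8 ms down to 19.9 ms of attention per token), because the short-prompt baseline also differs slightly between the two containers.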

I can try commenting out attention to isolate its changes from rotary.

ad8e commented 5 months ago

Attention commented out, rotary kept:

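# Attention disabled for this measurement (rotary and everything else still run):
# attn_output = flash_attn_kvpacked_func(q=query, kv=key_value, causal=(query.shape[1] > 1))
# Stand-in output: query has the same shape as the attention output, so the rest of the layer still gets correctly shaped input.
attn_output = query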

Results (with warmup):

prompt length 2
tokens generated: 190
generated tokens/s: 23.244
prefill tokens/s: 46.039

prompt length 16000
tokens generated: 190
generated tokens/s: 23.051
prefill tokens/s: 4285.551
generated tokens/s: 23.030
prefill tokens/s: 4214.007

So FD attention in 2.4.2 is still a big chunk of the time spent: at 16k context, generation goes from 15.7 tok/s with attention to 23 tok/s without it.
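Turning this ablation into a fraction, as tridao suggested: if dropping the attention call raises 16k throughput from 15.766 to 23.051 tok/s, attention accounts for 1 - 15.766/23.051 of each decode step. The sketch below assumes the ablation leaves everything else in the step (including rotary) unchanged. The function name is illustrative.

```python
def attn_fraction(tps_with_attn: float, tps_without_attn: float) -> float:
    """Share of per-token decode time spent in attention, from an ablation that skips the call."""
    return 1.0 - tps_with_attn / tps_without_attn


# 16k-prompt numbers from this thread: FD (FA 2.4.2) with attention vs. attention commented out.
f = attn_fraction(15.766, 23.051)
print(f"attention ~ {f:.0%} of decode time; removing it entirely caps the speedup at {1 / (1 - f):.2f}x")
```

That is roughly 32% of decode time at 16k, so any further attention optimization is bounded by about 1.46x end to end at this context length.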