ad8e opened 9 months ago
Cool, thanks for the benchmarking. Flash Decoding speeds up the attn part during decoding, but attn may or may not take a large fraction of the time (depending on model size and seqlen).
One way to get a sense of what fraction of time attn takes is to comment out the line that does the attn during decoding (keep everything else the same, e.g., rotary) and see what generation speed you get. If attn only takes 20-30% of the time, then any optimization you do on attn will get you at most a 1.3-1.4x speedup.
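The 1.3-1.4x cap follows from Amdahl's law; a minimal sketch of the bound:

```python
# Amdahl's law: if attn is fraction f of decode time and gets s times faster,
# the end-to-end speedup is 1 / ((1 - f) + f / s).
def overall_speedup(f: float, s: float) -> float:
    return 1.0 / ((1.0 - f) + f / s)

# Even infinitely fast attn (s -> inf) caps the speedup at 1 / (1 - f):
print(overall_speedup(0.2, float("inf")))  # 1.25x if attn is 20% of the time
print(overall_speedup(0.3, float("inf")))  # ~1.43x if attn is 30% of the time
```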
Doing some naive math, and assuming attn is free at ctx_len=2: attn was 22.767/13.2 - 1 = 72% as expensive as the rest of the model under non-FD at 16k context length, and 22.961/15.766 - 1 = 46% as expensive as the rest of the model under FD at 16k context length (generation speeds from the container comparison below). This means the new attn takes ~64% of the time of the old attn.
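Restating that arithmetic explicitly (a sketch; the tokens/s figures are from the container benchmarks below):

```python
# Assume attn is ~free at ctx_len=2, so 1 / tok_s(ctx=2) approximates the
# per-token cost of everything except attn. Then at 16k:
#   attn cost relative to the rest = tok_s(ctx=2) / tok_s(ctx=16k) - 1
non_fd = 22.767 / 13.2 - 1      # ~0.72 (FA 2.0.2, no Flash Decoding)
fd     = 22.961 / 15.766 - 1    # ~0.46 (FA 2.4.2, Flash Decoding)
print(non_fd, fd, fd / non_fd)  # fd / non_fd ~ 0.63, i.e. the ~64% quoted above
```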
I can try commenting out attention to isolate its cost from rotary etc.
Attention commented out, rotary kept:
```python
# attn_output = flash_attn_kvpacked_func(q=query, kv=key_value, causal=(query.shape[1] > 1))
attn_output = query  # identity in place of attention; rotary and everything else unchanged
```
Results (with warmup; the 16000-token prompt was run twice):
```
prompt length 2
tokens generated: 190
generated tokens/s: 23.244
prefill tokens/s: 46.039

prompt length 16000
tokens generated: 190
generated tokens/s: 23.051
prefill tokens/s: 4285.551
generated tokens/s: 23.030
prefill tokens/s: 4214.007
```
So FD attention in 2.4.2 is still a big chunk of the time spent: removing it takes 16k-context decoding from 15.7 tok/s to 23 tok/s.
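As a sanity check on the naive math, the no-attention run can be compared against the FD run directly (a small sketch using the figures above):

```python
# With attention replaced by the identity, decode runs at ~23.05 tok/s at 16k;
# with FD attention it runs at ~15.77 tok/s. So FD attention accounts for
# 1 - 15.766 / 23.051 ~ 32% of total decode time, i.e. it is
# 23.051 / 15.766 - 1 ~ 46% as expensive as the rest of the model,
# matching the "assume attn is free at ctx_len=2" estimate above.
no_attn, fd = 23.051, 15.766
print(1 - fd / no_attn)   # ~0.316
print(no_attn / fd - 1)   # ~0.462
```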
We tried to test FD vs non-FD, but got lazy and just did a wholesale container switch between 2.0.2 and 2.4.2, so there's also a newer Python and PyTorch in there. The speedup from (FD + misc improvements) was not big.
Old container (non-FD, FA 2.0.2), prefill includes tokenization time, all numbers with warmup and repeated twice: 22.767 generated tokens/s at prompt length 2, 13.2 at prompt length 16000.
New container (FD, FA 2.4.2), prefill includes tokenization time, all numbers with warmup and repeated twice: 22.961 generated tokens/s at prompt length 2, 15.766 at prompt length 16000.
Code in both containers is:
```python
attn_output = flash_attn_kvpacked_func(...)
```
Old container: http://ghcr.io/coreweave/ml-containers/torch-extras:es-flash-attn-2-5fbc6bb-base-cuda12.1.1-torch2.0.1-vision0.15.2-audio2.0.2-flash_attn2.0.2
New container: http://ghcr.io/coreweave/ml-containers/torch-extras:73a87a6-nccl-cuda12.2.2-ubuntu22.04-nccl2.19.3-1-torch2.2.0-vision0.17.0-audio2.2.0-flash_attn2.4.2
Hardware: A40, 13B model.
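For context, a minimal sketch of the call being benchmarked during a single decode step (shapes per the flash-attn kvpacked API; the dimensions and tensor names here are illustrative, not the exact benchmark code):

```python
import torch
from flash_attn import flash_attn_kvpacked_func

batch, nheads, headdim = 1, 40, 128  # illustrative 13B-ish dimensions (5120 hidden)
# One new query token attending to a 16k-token packed K/V cache:
query     = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.float16)
key_value = torch.randn(batch, 16000, 2, nheads, headdim, device="cuda", dtype=torch.float16)
# causal=(query.shape[1] > 1): masking only matters during prefill; with a
# single decode token the causal mask is a no-op.
attn_output = flash_attn_kvpacked_func(q=query, kv=key_value, causal=(query.shape[1] > 1))
print(attn_output.shape)  # (1, 1, 40, 128)
```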