flashinfer-ai / flashinfer

FlashInfer: Kernel Library for LLM Serving
https://flashinfer.ai
Apache License 2.0

[Question] very small performance gain for cascade append on GQA #595

Closed. hewr1993 closed this issue 2 weeks ago.

hewr1993 commented 2 weeks ago

For Llama3-70B with TP8, each GPU has 8 query heads and 1 KV head.
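
The per-GPU head counts follow from Llama3-70B's 64 query heads and 8 KV heads split evenly over 8 tensor-parallel ranks. A minimal sketch of that arithmetic, assuming even sharding with no KV-head replication (the helper name per_gpu_heads is hypothetical):

```python
def per_gpu_heads(num_q_heads: int, num_kv_heads: int, tp: int) -> tuple:
    """Split attention heads evenly across tensor-parallel ranks (no replication)."""
    if num_q_heads % tp or num_kv_heads % tp:
        raise ValueError("head counts must be divisible by the TP degree")
    return num_q_heads // tp, num_kv_heads // tp


# Llama3-70B: 64 query heads, 8 KV heads; TP8 as in this issue.
q_local, kv_local = per_gpu_heads(64, 8, 8)
print(q_local, kv_local)    # 8 1
print(q_local // kv_local)  # GQA group size: 8 query heads share one KV head
```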

With 4000 shared prefix tokens and batch size 8, cascade decoding is much slower than the baseline (26 us vs. 19 us). But if we set the number of KV heads to 8 (i.e., MHA), cascade decoding performs really well (26 us vs. 55 us).
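
A rough back-of-envelope estimate helps explain the gap: cascade decoding saves memory traffic by reading the shared prefix's KV cache once for the whole batch instead of once per request. The sketch below counts only the prefix K and V bytes per layer, assuming an fp16 KV cache and head_dim 128 (Llama3-70B's 8192 hidden size / 64 query heads). It ignores the unique suffix tokens and query/output traffic, and the helper kv_bytes is hypothetical. These are estimates, not measurements from this thread.

```python
def kv_bytes(num_tokens: int, num_kv_heads: int, head_dim: int, dtype_bytes: int = 2) -> int:
    """Bytes of K plus V for num_tokens tokens in one layer (fp16 by default)."""
    return num_tokens * num_kv_heads * head_dim * 2 * dtype_bytes


PREFIX_TOKENS = 4000
BATCH = 8
HEAD_DIM = 128  # Llama3-70B: 8192 hidden size / 64 query heads

for kv_heads in (1, 8):
    prefix = kv_bytes(PREFIX_TOKENS, kv_heads, HEAD_DIM)
    baseline = BATCH * prefix  # each request re-reads the shared prefix
    cascade = prefix           # shared prefix read once for the whole batch
    print(f"kv_heads={kv_heads}: baseline {baseline / 1e6:.1f} MB, cascade {cascade / 1e6:.1f} MB")
# kv_heads=1: baseline 16.4 MB, cascade 2.0 MB
# kv_heads=8: baseline 131.1 MB, cascade 16.4 MB
```

With a single KV head the baseline already moves relatively little data, so the absolute savings from cascading the prefix are small.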

Please let me know whether 2-level cascade attention only helps with MHA, or whether I'm doing something wrong.

yzh119 commented 2 weeks ago

Hi @hewr1993, which API are you using for cascade decoding?

hewr1993 commented 2 weeks ago

I'm using flashinfer.MultiLevelCascadeAttentionWrapper.

@yzh119 Please let me know if you have any trouble reproducing the performance issue.

yzh119 commented 2 weeks ago

Hi @hewr1993, I think I can reproduce the issue.

If you change the number of KV heads to 8 and the number of query heads to 64 (still a GQA setting), I think you will observe different performance: cascade decoding is expected to run faster. The main problem is that when there is only 1 KV head, each kernel's duration is very short. Cascade decoding launches 3 kernels while vanilla decoding launches only one, so kernel launch overhead becomes non-trivial for cascade inference. Using CUDA graphs might alleviate the issue.
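
The thread doesn't spell out what the 3 kernels are. Our assumption is that cascade decoding attends over the shared prefix and over each request's unique suffix separately, then runs a third step that merges the two partial results using their log-sum-exp values. The pure-Python sketch that follows (hypothetical helpers attend and merge, single query, no batching) shows why such a merge reproduces full attention exactly; it illustrates the math only and is not FlashInfer's kernel code.

```python
import math


def attend(q, keys, values):
    """Single-query softmax attention over one KV segment.
    Returns (output vector, log-sum-exp of the scaled scores)."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    out = [sum(w * v[d] for w, v in zip(weights, values)) / total
           for d in range(len(values[0]))]
    return out, m + math.log(total)


def merge(o_a, lse_a, o_b, lse_b):
    """Combine partial attention results computed over disjoint KV segments."""
    m = max(lse_a, lse_b)
    w_a, w_b = math.exp(lse_a - m), math.exp(lse_b - m)  # proportional to each softmax denominator
    out = [(w_a * a + w_b * b) / (w_a + w_b) for a, b in zip(o_a, o_b)]
    return out, m + math.log(w_a + w_b)


q = [0.5, -1.0]
prefix_k, prefix_v = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
suffix_k, suffix_v = [[-1.0, 0.5]], [[7.0, 8.0]]

o_p, lse_p = attend(q, prefix_k, prefix_v)  # "kernel 1": shared prefix
o_s, lse_s = attend(q, suffix_k, suffix_v)  # "kernel 2": unique suffix
merged, merged_lse = merge(o_p, lse_p, o_s, lse_s)  # "kernel 3": merge
full, full_lse = attend(q, prefix_k + suffix_k, prefix_v + suffix_v)
print(all(abs(a - b) < 1e-9 for a, b in zip(merged, full)))  # True
```

Each of these steps is cheap when there is only 1 KV head, which is why fixed per-launch costs can outweigh the savings from reading the prefix once.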