Hi @hewr1993, which API are you using for cascade decoding?
flashinfer.MultiLevelCascadeAttentionWrapper
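For reference, here is a minimal sketch of a 2-level shared-prefix decode setup with this wrapper. The shapes, page layout, and the `plan()`/`run()` argument order below are my assumptions based on the flashinfer docs and may differ across versions; treat it as an illustration of the workload, not an exact repro script.

```python
import torch
import flashinfer

# Illustrative decode setup (assumed, not copied from the issue):
# Llama3-70B at TP8 -> 8 query heads, 1 kv head per rank, 4000 shared prefix tokens, batch 8.
num_qo_heads, num_kv_heads, head_dim = 8, 1, 128
page_size, batch_size = 16, 8
shared_len, unique_len = 4000, 16              # shared prefix / per-request suffix tokens
shared_pages = shared_len // page_size         # 250 pages, last page exactly full
total_pages = shared_pages + batch_size        # one suffix page per request

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.MultiLevelCascadeAttentionWrapper(2, workspace, "NHD")

# Level 0 covers the shared prefix as a single entry; level 1 covers per-request suffixes.
qo_indptr = [
    torch.tensor([0, batch_size], dtype=torch.int32, device="cuda"),
    torch.arange(batch_size + 1, dtype=torch.int32, device="cuda"),
]
kv_page_indptr = [
    torch.tensor([0, shared_pages], dtype=torch.int32, device="cuda"),
    torch.arange(batch_size + 1, dtype=torch.int32, device="cuda"),
]
kv_page_indices = [
    torch.arange(shared_pages, dtype=torch.int32, device="cuda"),
    torch.arange(shared_pages, total_pages, dtype=torch.int32, device="cuda"),
]
kv_last_page_len = [
    torch.tensor([page_size], dtype=torch.int32, device="cuda"),
    torch.full((batch_size,), unique_len, dtype=torch.int32, device="cuda"),
]

# Argument order here follows the documented two-level example; check your installed version.
wrapper.plan(
    qo_indptr, kv_page_indptr, kv_page_indices, kv_last_page_len,
    num_qo_heads, num_kv_heads, head_dim, page_size,
)

q = torch.randn(batch_size, num_qo_heads, head_dim, device="cuda", dtype=torch.float16)
kv_cache = torch.randn(
    total_pages, 2, page_size, num_kv_heads, head_dim, device="cuda", dtype=torch.float16
)
o = wrapper.run(q, kv_cache)                   # (batch_size, num_qo_heads, head_dim)
```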
@yzh119 Please let me know if there's any trouble reproducing the performance issue.
Hi @hewr1993 I think I can reproduce the issue.
If you change k-heads to 8 and q-heads to 64 (which is still a GQA setting), I think you will observe different behavior: cascade decoding is expected to run faster. The main problem is that when k-heads is 1, the kernel duration is very small. Cascade decoding launches 3 kernels while vanilla decoding launches only one, so kernel launch time becomes non-trivial for cascade inference. (Using CUDA graphs might alleviate the issue; see the sketch below.)
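For reference, a minimal sketch of amortizing that launch overhead with `torch.cuda.CUDAGraph`. The `decode_step` function is a hypothetical stand-in for the cascade decode call (e.g. `wrapper.run`), and all tensor shapes must stay fixed across replays.

```python
import torch

def decode_step(q, kv_cache, out):
    # Hypothetical stand-in for the cascade decode call; writes into `out`.
    out.copy_(q)  # placeholder so the sketch runs stand-alone

q = torch.randn(8, 8, 128, device="cuda", dtype=torch.float16)
kv_cache = torch.randn(258, 2, 16, 1, 128, device="cuda", dtype=torch.float16)
out = torch.empty_like(q)

# Warm up on a side stream before capture (required by CUDA graph capture).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    decode_step(q, kv_cache, out)
torch.cuda.current_stream().wait_stream(s)

# Capture once, then replay: all kernels in the decode step are launched as one
# graph, so per-kernel launch overhead is paid only at capture time.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    decode_step(q, kv_cache, out)

for _ in range(100):
    # Update q / kv_cache in place, then replay the captured launches.
    g.replay()
```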
For Llama3-70B TP8, we have 8 q-heads and 1 k-head.
With 4000 shared prefix tokens and batch size 8, cascade decoding is much slower than the baseline (26us vs 19us). But if we set k-heads to 8 (i.e. MHA), cascade performance is much better than the baseline (26us vs 55us).
Please enlighten me if 2-level cascade attention only works well with MHA, or if I'm doing something wrong.