flashinfer-ai / flashinfer

FlashInfer: Kernel Library for LLM Serving
https://flashinfer.ai
Apache License 2.0

Improve parallelism in RoPE with pos_ids #609

Closed: nandor closed this 2 weeks ago

nandor commented 2 weeks ago

The previous kernel was not sufficiently parallelised at low batch sizes. As in the regular rotary kernel, all qo/kv heads are now split across separate thread blocks.

In decode mode, the pos_ids kernel is now faster.
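For context, a minimal PyTorch sketch of what this op computes (assuming the non-interleaved GPT-NeoX layout, a `[nnz, num_heads, head_dim]` input, and the default theta; the real kernel supports more layout options):

```python
import torch

def ref_rope_pos_ids(x: torch.Tensor, pos_ids: torch.Tensor, theta: float = 1e4) -> torch.Tensor:
    """Apply rotary embeddings to x[nnz, num_heads, head_dim] at positions pos_ids[nnz].

    Non-interleaved (GPT-NeoX style) layout: the first half of head_dim is rotated
    against the second half. Each (token, head) pair is independent, which is what
    allows splitting heads across separate thread blocks.
    """
    nnz, num_heads, head_dim = x.shape
    half = head_dim // 2
    # Frequencies depend only on the rotary index, not on the head index, so
    # cos/sin of shape [nnz, half] can be shared across all heads of a token.
    freqs = theta ** (-torch.arange(half, dtype=torch.float32, device=x.device) / half)
    angles = pos_ids.to(torch.float32)[:, None] * freqs[None, :]     # [nnz, half]
    cos, sin = angles.cos()[:, None, :], angles.sin()[:, None, :]    # broadcast over heads
    x1, x2 = x[..., :half].float(), x[..., half:].float()
    out = torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
    return out.to(x.dtype)
```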

yzh119 commented 2 weeks ago

Hi @nandor, we use this parallelism mainly to save sin/cos computation time (the same sin/cos values can be reused across multiple heads). I expect that using a different thread block for each head will be faster at small batch sizes.

Would you mind running https://github.com/flashinfer-ai/flashinfer/blob/32d9510d67187f1f3a379cce81302cdd15a557d2/benchmarks/bench_rope.py ?
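For a quick point of comparison outside that script, something like the following can time the op (illustrative shapes for a small-batch decode-like workload; the exact `apply_rope_pos_ids` signature and layout are assumed here, and the real bench_rope.py sweeps its own configurations):

```python
import torch
import flashinfer
from triton.testing import do_bench

nnz, num_qo_heads, num_kv_heads, head_dim = 8, 32, 8, 128
q = torch.randn(nnz, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(nnz, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
pos_ids = torch.randint(0, 4096, (nnz,), dtype=torch.int32, device="cuda")

# do_bench handles warmup and repetition and reports time in milliseconds.
ms = do_bench(lambda: flashinfer.apply_rope_pos_ids(q, k, pos_ids))
print(f"apply_rope_pos_ids: {ms:.3f} ms")
```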

nandor commented 2 weeks ago

You are right: reusing sin and cos across 2-8 heads does yield a small speedup. But the finer-grained scheme is already significantly faster on an H100.

Unfortunately, this sort of batching is more convoluted to implement in CUDA than in Triton, so internally we will rely on a Triton kernel instead.
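Roughly, the finer-grained scheme looks like this Triton sketch (just an illustration, not the internal kernel): one program per (token, head) pair, each recomputing its own sin/cos, with the non-interleaved layout and a contiguous `[nnz, num_heads, head_dim]` tensor assumed.

```python
import triton
import triton.language as tl

@triton.jit
def rope_pos_ids_kernel(
    x_ptr, out_ptr, pos_ids_ptr,
    num_heads, head_dim, log_theta,
    HALF_DIM: tl.constexpr,  # power of two >= head_dim // 2
):
    # One program per (token, head): grid = (nnz, num_heads).
    token = tl.program_id(0)
    head = tl.program_id(1)

    half = head_dim // 2
    offs = tl.arange(0, HALF_DIM)
    mask = offs < half

    # Angles depend only on (position, rotary index); here each program
    # recomputes its own sin/cos instead of sharing them across heads.
    pos = tl.load(pos_ids_ptr + token).to(tl.float32)
    inv_freq = tl.exp(offs.to(tl.float32) * (-log_theta / half))
    angle = pos * inv_freq
    cos = tl.cos(angle)
    sin = tl.sin(angle)

    base = (token * num_heads + head) * head_dim
    x1 = tl.load(x_ptr + base + offs, mask=mask).to(tl.float32)
    x2 = tl.load(x_ptr + base + half + offs, mask=mask).to(tl.float32)

    out_dtype = out_ptr.dtype.element_ty
    tl.store(out_ptr + base + offs, (x1 * cos - x2 * sin).to(out_dtype), mask=mask)
    tl.store(out_ptr + base + half + offs, (x2 * cos + x1 * sin).to(out_dtype), mask=mask)
```

It would be launched over an `(nnz, num_heads)` grid with `log_theta = math.log(10000.0)` and `HALF_DIM` set to the padded half dimension; interleaved layouts, rope scaling, and non-contiguous strides are omitted for brevity.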

james-p-xu commented 1 week ago

I'm seeing that this change causes a correctness issue with apply_rope_pos_ids.

Here's a sample comparison script that passes prior to this commit (32d9510d67187f1f3a379cce81302cdd15a557d2) but fails after the change: https://github.com/sgl-project/sglang/blob/dd0d2a3af4967880362e3bad9d95cd14572c89ea/scripts/playground/compare_flashinfer_vllm_rope.py
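The linked script compares against vLLM's rotary embedding. A self-contained check against the plain PyTorch reference sketched earlier in the thread would look roughly like this, assuming the same layout and the default non-interleaved mode and theta:

```python
import torch
import flashinfer
# ref_rope_pos_ids is the PyTorch reference sketched earlier in this thread.

torch.manual_seed(0)
nnz, num_qo_heads, num_kv_heads, head_dim = 16, 32, 8, 128
q = torch.randn(nnz, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(nnz, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
pos_ids = torch.randint(0, 4096, (nnz,), dtype=torch.int32, device="cuda")

q_out, k_out = flashinfer.apply_rope_pos_ids(q, k, pos_ids)

torch.testing.assert_close(q_out, ref_rope_pos_ids(q, pos_ids), rtol=1e-2, atol=1e-2)
torch.testing.assert_close(k_out, ref_rope_pos_ids(k, pos_ids), rtol=1e-2, atol=1e-2)
print("match")
```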

yzh119 commented 1 week ago

@james-p-xu I'll fix it, thank you!