flashinfer-ai / flashinfer

FlashInfer: Kernel Library for LLM Serving
https://flashinfer.ai
Apache License 2.0

perf: optimize warp layout for prefill operator for small query length #185

Closed yzh119 closed 2 months ago

yzh119 commented 5 months ago

Currently, the warp layout in our prefill operators is fixed to 4x1. However, this schedule is sub-optimal when the query length per request is less than 64:

  1. When query length <= 16, it's better to use warp layout 1x4.
  2. When 16 < query length <= 32, it's better to use warp layout 2x2.
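
The selection rule above could be sketched on the host side roughly as follows. The names `WarpLayout`, `select_warp_layout`, and `kRowsPerWarp` are hypothetical and not taken from the FlashInfer source, and the sketch assumes that each warp covers 16 query rows and that a threadblock always holds 4 warps.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical type (not from the FlashInfer source): how the 4 warps of a
// threadblock are split between the query dimension and the KV dimension.
struct WarpLayout {
  uint32_t num_warps_x;  // warps tiling the query rows
  uint32_t num_warps_z;  // warps tiling the KV dimension
};

// Assumption: one warp handles a tile of 16 query rows.
constexpr uint32_t kRowsPerWarp = 16;

// Pick the layout from the per-request query length, following the
// thresholds listed in the PR description (<= 16 -> 1x4, <= 32 -> 2x2).
inline WarpLayout select_warp_layout(uint32_t qo_len) {
  if (qo_len <= 1 * kRowsPerWarp) return {1, 4};
  if (qo_len <= 2 * kRowsPerWarp) return {2, 2};
  return {4, 1};  // the previous fixed layout
}
```

For example, `select_warp_layout(20)` yields a 2x2 layout under these assumptions, so two warps share the query rows and two share the KV dimension.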

A threadblock-level synchronization function is required to merge states across warps for both cases.
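
To illustrate what merging states involves, here is a minimal host-side sketch of the standard online-softmax merge. It assumes each warp holds a partial attention state for the same query row over its own slice of the KV dimension: an unnormalized output `o`, a running maximum `m`, and a denominator `d`. The `PartialState` type and `merge_states` function are hypothetical names for illustration. In a kernel, this exchange would presumably go through shared memory guarded by `__syncthreads()`, which is not shown here.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical partial attention state for one query row (illustration only).
struct PartialState {
  std::vector<float> o;  // unnormalized output: sum_j exp(s_j - m) * v_j
  float m;               // running maximum of the logits s_j seen so far
  float d;               // denominator: sum_j exp(s_j - m)
};

// Merge two partial states computed over disjoint KV slices.
// Both inputs are rescaled to the shared maximum before being added.
inline PartialState merge_states(const PartialState& a, const PartialState& b) {
  const float m = std::max(a.m, b.m);
  // If neither side has seen a key yet (m == -inf), there is nothing to rescale.
  if (std::isinf(m) && m < 0.0f) return a;
  const float scale_a = std::exp(a.m - m);
  const float scale_b = std::exp(b.m - m);
  PartialState out{std::vector<float>(a.o.size()), m, a.d * scale_a + b.d * scale_b};
  for (std::size_t i = 0; i < a.o.size(); ++i) {
    out.o[i] = a.o[i] * scale_a + b.o[i] * scale_b;
  }
  return out;
}
```

The final attention output for the row is `o[i] / d` after all partial states have been merged. Because the merge is associative, the states from the four warps of a 1x4 layout can be combined in any order.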

Other changes:

  1. Previously we used threadIdx.y to denote the warp_idx; now we use both threadIdx.y and threadIdx.z to accommodate different warp layouts (blockDim.y, blockDim.z).
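
The two-dimensional warp indexing can be sketched on the host as below. `Dim3` is a hypothetical stand-in for CUDA's built-in `threadIdx` and `blockDim` values, and `WarpCoord` and `warp_coord` are illustrative names. The sketch assumes `blockDim.x == 32`, so that each (y, z) pair identifies exactly one warp. Which of the two axes maps to the query dimension is not stated in the description, so the sketch leaves them as plain y and z.

```cpp
#include <cassert>
#include <cstdint>

// Host-side stand-in for CUDA's threadIdx / blockDim (illustration only).
struct Dim3 {
  uint32_t x, y, z;
};

// Hypothetical result type: a warp's position in the (blockDim.y, blockDim.z) grid.
struct WarpCoord {
  uint32_t y;       // position along blockDim.y
  uint32_t z;       // position along blockDim.z
  uint32_t linear;  // flattened warp index within the threadblock
};

// Assumes block_dim.x == 32, i.e. one warp per (y, z) pair. The flattened
// index follows CUDA's thread ordering, where y varies faster than z.
inline WarpCoord warp_coord(Dim3 thread_idx, Dim3 block_dim) {
  return {thread_idx.y, thread_idx.z, thread_idx.z * block_dim.y + thread_idx.y};
}
```

With a 4x1 block (`blockDim = (32, 4, 1)`), the flattened index reduces to `threadIdx.y`, which matches the previous single-axis scheme.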

Qubitium commented 3 months ago

@yzh119 Lol. Your commit titles are hilarious. But might want to censor the output a little. Would hate to have this awesome project be filtered out due to "safety" guidelines by some major corps.

yzh119 commented 3 months ago

@Qubitium sorry about it, I'll squash all intermediate commits.

These titles are meaningless and I don't want people to be bothered by them. Thank you again for the reminder; as an open-source project, I should be more professional. I apologize if anyone got notified by any of these commits.

yzh119 commented 3 months ago

I did a rebase and squash, so the random commit messages are now cleared. Thank you again for your reminder.

I'll wrap up this PR soon, and GQA kernels could be further optimized with these changes.