flashinfer-ai / flashinfer

FlashInfer: Kernel Library for LLM Serving
https://flashinfer.ai
Apache License 2.0

[Feature Request] Add an argument to control the number of CTAs used in attention APIs #591

Open yzh119 opened 2 weeks ago

yzh119 commented 2 weeks ago

Nanoflow overlaps decode/prefill/communication by limiting the number of SMs each kernel uses (in practice this is controlled by the grid size). The current nanoflow implementation modifies flashinfer kernels so that they can be launched with a user-specified grid size.

As flashinfer is moving all kernel implementations to persistent kernels, we can support specifying the number of SMs on the flashinfer side. More specifically, we can add a `num_ctas` argument to our plan functions to specify the grid size, so that users can control it directly from Python.
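
A rough sketch of what the user-facing API could look like, using `BatchDecodeWithPagedKVCacheWrapper` as an example. The `num_ctas` keyword is the proposed argument and does not exist in current releases; the surrounding setup is only illustrative:

```python
import torch
import flashinfer

num_qo_heads, num_kv_heads, head_dim, page_size = 32, 8, 128, 16
batch_size, pages_per_seq = 4, 8

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace, "NHD")

# Illustrative paged-KV metadata: each request owns `pages_per_seq` full pages.
kv_indptr = torch.arange(batch_size + 1, device="cuda", dtype=torch.int32) * pages_per_seq
kv_indices = torch.arange(batch_size * pages_per_seq, device="cuda", dtype=torch.int32)
kv_last_page_len = torch.full((batch_size,), page_size, device="cuda", dtype=torch.int32)

# Cap the decode kernel to a quarter of the SMs so that prefill/communication
# kernels launched on other streams can run concurrently on the rest.
num_sms = torch.cuda.get_device_properties(0).multi_processor_count
decode_ctas = max(1, num_sms // 4)

wrapper.plan(
    kv_indptr,
    kv_indices,
    kv_last_page_len,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    page_size,
    num_ctas=decode_ctas,  # proposed argument, not available yet
)
```

Since the kernels are persistent, the plan would then launch at most `num_ctas` thread blocks, leaving the remaining SMs free for the overlapped prefill/communication work.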

The benefits of this feature include:

  1. Keeping nanoflow's development in pace with the latest flashinfer features (JIT/FA3/customization/etc.).
  2. Making it possible to port nanoflow to PyTorch. It may sacrifice some performance, but I think overall it's good for nanoflow's adoption.
  3. Making it possible to use nanoflow-style parallelism in other LLM serving frameworks such as vLLM/SGLang/MLC-LLM/etc.

We also need to support such an argument in the GEMM APIs by wrapping the CUTLASS GEMM implementations; we leave that for future work.

cc @serendipity-zk @happierpig