flashinfer-ai / flashinfer

FlashInfer: Kernel Library for LLM Serving
https://flashinfer.ai
Apache License 2.0

[Roadmap] FlashInfer v0.1.0 release checklist #19

Open yzh119 opened 7 months ago

yzh119 commented 7 months ago

Expected release date: Mar 15th, 2024

General

  1. [x] Support general page table layout (@yzh119)
  2. [ ] sm70/75 compatibility (@yzh119)
  3. [ ] Performance: use fp16 as the intermediate data type to accelerate decode attention on A100 (@yzh119)
  4. [x] Accelerate batch prefill & decode for the extreme case where page_size equals one. (@yzh119)
  5. [ ] Sliding Window Attention (@yzh119)
  6. [ ] Do not allocate CUDA memory inside FlashInfer APIs with native CUDA allocation functions, because this interferes with the memory planning of serving engines (if any). The preferred behavior is to let users allocate all buffers outside the FlashInfer APIs. (@yzh119)
  7. [x] Remove num_layers and layer_id from data structures (@yzh119 in fc0726c7fcd86b225d05eb03176c4136f024000c)
  8. [x] Further accelerate decode kernels (@yzh119 in 2a3d6d038d3ee57904b150c0ad107556568c2c1f, b83b40825b53da1c94039503f772150af9c2fccb)
  9. [ ] Prefill/append kernels accelerated by TMA and fp8 tensor cores on H100.
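Item 5 (Sliding Window Attention) restricts each query to the most recent keys. The following is a minimal sketch of the mask semantics only, assuming a causal window in which a query at absolute position `pos` sees keys `j` with `pos - window_size < j <= pos`, and assuming the usual append convention that the queries are the last `qo_len` of `kv_len` tokens. The helper name `sliding_window_mask` and this exact window convention are illustrative, not FlashInfer's API.

```python
def sliding_window_mask(qo_len, kv_len, window_size):
    """Return mask[i][j] == True if query i may attend to key j.

    Hypothetical helper. Queries are assumed to be the last `qo_len` tokens
    of a `kv_len`-token sequence, so query i sits at position kv_len - qo_len + i.
    """
    if not 0 < qo_len <= kv_len:
        raise ValueError("expected 0 < qo_len <= kv_len")
    if window_size <= 0:
        raise ValueError("window_size must be positive")
    offset = kv_len - qo_len
    mask = []
    for i in range(qo_len):
        pos = offset + i
        # Causal (j <= pos) and windowed (j > pos - window_size).
        mask.append([pos - window_size < j <= pos for j in range(kv_len)])
    return mask


if __name__ == "__main__":
    for row in sliding_window_mask(qo_len=3, kv_len=5, window_size=2):
        print("".join("x" if allowed else "." for allowed in row))
    # .xx..
    # ..xx.
    # ...xx
```

A fused kernel would typically not materialize this mask; it would skip KV tiles that fall entirely outside the window, which is where the speedup of sliding window attention comes from.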

MLC-Serving

  1. [x] Cascade attention and TVM wrappers
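Cascade attention splits the KV sequence into chunks (for example, a shared prefix and a per-request suffix), computes attention over each chunk separately, and then merges the partial results. A minimal single-query sketch of that merge, assuming each partial result is kept as an (output vector, log-sum-exp of scores) pair; the names `attention_state` and `merge_states` are hypothetical, and this is plain Python rather than the TVM wrappers or CUDA kernels.

```python
import math


def attention_state(q, keys, values):
    """Softmax attention of one query over one KV chunk.

    Returns (output vector, log-sum-exp of the scaled scores).
    """
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]
    return out, m + math.log(total)


def merge_states(a, b):
    """Merge two partial states computed over disjoint KV chunks."""
    (o_a, s_a), (o_b, s_b) = a, b
    m = max(s_a, s_b)
    s = m + math.log(math.exp(s_a - m) + math.exp(s_b - m))  # log(Z_a + Z_b)
    w_a, w_b = math.exp(s_a - s), math.exp(s_b - s)  # Z_a / Z and Z_b / Z
    return [w_a * x + w_b * y for x, y in zip(o_a, o_b)], s


if __name__ == "__main__":
    q = [0.5, -1.0]
    prefix_k, prefix_v = [[1.0, 0.0], [0.0, 1.0]], [[1.0, 2.0], [3.0, 4.0]]
    suffix_k, suffix_v = [[1.0, 1.0]], [[5.0, 6.0]]
    merged, _ = merge_states(attention_state(q, prefix_k, prefix_v),
                             attention_state(q, suffix_k, suffix_v))
    full, _ = attention_state(q, prefix_k + suffix_k, prefix_v + suffix_v)
    print(max(abs(x - y) for x, y in zip(merged, full)))  # ~0, up to rounding
```

Because the merge is exact, the shared prefix can be processed once for many requests and combined with each request's own suffix result afterwards.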

Atom

Required operators for the paper Atom: Low-bit Quantization for Efficient and Accurate LLM Serving:

  1. [ ] int4 KV-cache FlashInfer decode operator. (@happierpig and @yzh119)
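An int4 KV cache stores two 4-bit codes per byte plus a scale per group, and the decode operator has to unpack and rescale them on the fly. The following is a minimal sketch under stated assumptions: symmetric per-group quantization to the signed range [-8, 7] with low-nibble-first packing. This is a hypothetical scheme chosen for illustration, not necessarily the layout Atom or FlashInfer uses.

```python
def quantize_int4(values, group_size=4):
    """Hypothetical symmetric per-group int4 quantization.

    Returns (packed bytes, per-group scales); two codes per byte, low nibble first.
    """
    if group_size % 2 or len(values) % group_size:
        raise ValueError("length must be a multiple of an even group_size")
    codes, scales = [], []
    for g in range(0, len(values), group_size):
        group = values[g:g + group_size]
        amax = max(abs(v) for v in group)
        scale = amax / 7 if amax > 0 else 1.0
        scales.append(scale)
        codes.extend(max(-8, min(7, round(v / scale))) for v in group)
    packed = bytearray()
    for lo, hi in zip(codes[0::2], codes[1::2]):
        # Store each signed code as its 4-bit two's-complement nibble.
        packed.append((lo & 0xF) | ((hi & 0xF) << 4))
    return bytes(packed), scales


def dequantize_int4(packed, scales, group_size=4):
    """Unpack nibbles, sign-extend them, and multiply by the group scale."""
    codes = []
    for byte in packed:
        for nibble in (byte & 0xF, byte >> 4):
            codes.append(nibble - 16 if nibble >= 8 else nibble)
    return [c * scales[i // group_size] for i, c in enumerate(codes)]
```

With this scheme, each reconstructed value is within half a scale step of the original, and the cache uses a quarter of the bytes of fp16 (plus the scales).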

Punica

Required operators for the paper Punica: Multi-Tenant LoRA Serving:

  1. [ ] SGMV shrink & expand, more shapes and accelerations (@yzh119)
  2. [ ] 4-bit SGMV.
  3. [ ] Fuse backbone GEMM and LoRA computation (@yzh119)
  4. [ ] Dequant operators. (@yzh119)
  5. [ ] Optional: Deploy multiple llama-adapter models.
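SGMV (segmented gather matrix-vector multiplication) applies a different LoRA weight to each contiguous segment of rows in a batch. "Shrink" multiplies by the low-rank A matrix and "expand" by B. A plain-Python reference of those semantics follows, assuming CSR-style segment boundaries in `seg_indptr` and that `weights[i]` is the adapter already gathered for segment i. The function names are hypothetical, the LoRA scaling factor is omitted, and this is not Punica's kernel.

```python
def sgmv(y, x, weights, seg_indptr):
    """Reference SGMV: y[rows of segment i] += x[rows of segment i] @ weights[i].

    x has n rows of length d_in; y has n rows of length d_out and is updated
    in place. seg_indptr has len(weights) + 1 entries.
    """
    for i, w in enumerate(weights):
        d_in, d_out = len(w), len(w[0])
        for r in range(seg_indptr[i], seg_indptr[i + 1]):
            for c in range(d_out):
                y[r][c] += sum(x[r][k] * w[k][c] for k in range(d_in))
    return y


def lora_delta(x, lora_a, lora_b, seg_indptr):
    """Shrink with A (d_in x rank), then expand with B (rank x d_out).

    Assumes every adapter in the batch has the same rank.
    """
    rank = len(lora_a[0][0])
    d_out = len(lora_b[0][0])
    v = sgmv([[0.0] * rank for _ in x], x, lora_a, seg_indptr)
    return sgmv([[0.0] * d_out for _ in x], v, lora_b, seg_indptr)


if __name__ == "__main__":
    x = [[1, 2], [3, 4], [5, 6]]          # rows 0-1 use adapter 0, row 2 adapter 1
    a = [[[1], [1]], [[2], [0]]]          # two rank-1 A matrices (2 x 1)
    b = [[[1, 0]], [[0, 1]]]              # two B matrices (1 x 2)
    print(lora_delta(x, a, b, [0, 2, 3]))  # [[3.0, 0.0], [7.0, 0.0], [0.0, 10.0]]
```

In a serving engine, the result is added to the output of the shared backbone GEMM, which is what item 3 proposes to fuse.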

Quest

Required operators for Quest:

  1. [ ] Head-wise page indices.
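Head-wise page indices mean that each attention head can read a different subset of KV pages. The following sketch is loosely modeled on Quest-style query-aware page selection. It assumes that each page stores the elementwise min and max of its keys per head, and it ranks pages by an upper bound on `q . k`. The function name and data layout are hypothetical and show only how per-head index lists could be produced, not how Quest or FlashInfer implements them.

```python
def select_pages_per_head(queries, page_min, page_max, top_k):
    """Pick top_k page ids independently for every head.

    queries[h] is head h's query vector. page_min[h][p] and page_max[h][p]
    hold the elementwise min / max of head h's keys in page p.
    Returns one sorted list of page ids per head.
    """
    selected = []
    for q, mins, maxs in zip(queries, page_min, page_max):
        # Upper bound on q . k for any key k whose entries lie in [mn, mx].
        bounds = [
            sum(max(qd * lo, qd * hi) for qd, lo, hi in zip(q, mn, mx))
            for mn, mx in zip(mins, maxs)
        ]
        order = sorted(range(len(bounds)), key=lambda p: bounds[p], reverse=True)
        selected.append(sorted(order[:top_k]))
    return selected


if __name__ == "__main__":
    mins = [[[0.0], [2.0], [-1.0]]] * 2
    maxs = [[[1.0], [3.0], [0.0]]] * 2
    print(select_pages_per_head([[1.0], [-1.0]], mins, maxs, top_k=2))
    # [[0, 1], [0, 2]]
```

The decode kernel would then need one page-index list per (request, head) pair instead of one per request.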

Other hardware backends

masahi commented 7 months ago

Hi @yzh119, I'm wondering whether the FlashInfer kernels can be implemented on top of vLLM's paged KV cache. Does one of the items, "Support general page table layout", address this?

The vLLM paged cache is described in https://github.com/mlc-ai/mlc-llm/blob/main/mlc_llm/relax_model/llama_batched_vllm.py#L307-L322. It's much simpler than Relax's.

This would help compare the FlashInfer and vLLM decode attention kernels apples-to-apples. I'm also interested in batched prefill with KV cache support.
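A vLLM-style cache gives each sequence a padded row of physical block ids (a block table) plus a context length. A unified interface could accept the same information in a CSR-style page table. The following sketch converts one into the other, assuming a CSR layout of `indptr` / `indices` / `last_page_len` similar to the paged KV layout FlashInfer documents. The helper names are hypothetical, and the padding value in the example is arbitrary.

```python
def block_table_to_csr(block_tables, context_lens, page_size):
    """Convert a padded per-sequence block table into a CSR page table.

    Sequence b owns pages indices[indptr[b]:indptr[b + 1]], and its last
    page holds last_page_len[b] valid tokens.
    """
    indptr, indices, last_page_len = [0], [], []
    for row, length in zip(block_tables, context_lens):
        if length <= 0:
            raise ValueError("every sequence needs at least one token")
        num_pages = (length + page_size - 1) // page_size  # ceil division
        indices.extend(row[:num_pages])  # drop padding entries
        indptr.append(len(indices))
        last_page_len.append(length - (num_pages - 1) * page_size)
    return indptr, indices, last_page_len


def token_slot(indptr, indices, b, t, page_size):
    """(physical page id, offset within page) of token t of sequence b."""
    return indices[indptr[b] + t // page_size], t % page_size


if __name__ == "__main__":
    tables = [[7, 2, 9], [5, -1, -1]]  # -1 marks padding in this example
    csr = block_table_to_csr(tables, context_lens=[33, 16], page_size=16)
    print(csr)  # ([0, 3, 4], [7, 2, 9, 5], [1, 16])
```

With `page_size = 1`, every `last_page_len` entry is 1 and `indices` lists one slot per token, which is the extreme case targeted by item 4 of the checklist.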

yzh119 commented 7 months ago

Hi @masahi, thanks for bringing this up. Yes, we will have a unified interface that's compatible with both vLLM and the current page table design.

Batch prefill with paged kv is already supported: https://github.com/flashinfer-ai/flashinfer/blob/11364ca4c3ce651dd544efff3225906fe15c5b8a/include/flashinfer/prefill.cuh#L944-L1107

masahi commented 7 months ago

> yes we will have a unified interface that's compatible with both vllm and the current page table design.

Great! cc @vinx13 @sunggg

> Batch prefill with paged kv is already supported

Yes, I was aware of that, and it's what interests me most right now. I need an attention kernel that does both of the following, and FlashAttention and vLLM each do only one of them.

Is BatchPrefillWithPagedKVCacheKernel intended to be useful for speculative decoding?

I also have another use case for such a kernel: parallel sampling (generating multiple sequences for one prompt). More context in https://github.com/vllm-project/vllm/pull/12#issuecomment-1841952270

yzh119 commented 7 months ago

One use case would be speculative decoding (we may need an additional mask input). And yes, there are other interesting use cases; I'll showcase one of them in the next few days.

sunggg commented 3 months ago

Hey @yzh119, great work! We are interested in using FlashInfer for tree decoding. Do you have plans to support custom attention masks?
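For tree decoding (and tree-based speculative decoding), the custom mask follows from the token tree. Each candidate token attends to the whole committed prefix, to itself, and to its ancestors, but not to sibling branches. The following is a minimal sketch that builds that dense boolean mask from parent pointers. It assumes the tree tokens are appended after the prefix in the KV sequence and that parents are listed before their children. The helper is hypothetical, and the mask format a kernel would actually accept (dense, bit-packed, and so on) is not specified here.

```python
def tree_attention_mask(parents, prefix_len):
    """mask[i][j] is True if tree token i may attend to KV position j.

    parents[i] is the index of token i's parent within the tree (-1 for a
    root). KV positions are the prefix_len committed tokens followed by the
    tree tokens in order.
    """
    for i, p in enumerate(parents):
        if not -1 <= p < i:
            raise ValueError("each parent must precede its child")
    n = len(parents)
    mask = []
    for i in range(n):
        row = [True] * prefix_len + [False] * n  # full access to the prefix
        node = i
        while node != -1:
            row[prefix_len + node] = True  # self, then each ancestor
            node = parents[node]
        mask.append(row)
    return mask


if __name__ == "__main__":
    # Token 0 is the root; tokens 1 and 2 are its children; token 3 follows 1.
    for row in tree_attention_mask([-1, 0, 0, 1], prefix_len=2):
        print("".join("x" if allowed else "." for allowed in row))
    # xxx...
    # xxxx..
    # xxx.x.
    # xxxx.x
```

A chain-shaped tree (`parents = [-1, 0, 1, ...]`) reduces to an ordinary causal mask, which is why a plain causal prefill kernel already covers linear speculative decoding but not branching trees.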