flashinfer-ai / flashinfer

FlashInfer: Kernel Library for LLM Serving
https://flashinfer.ai
Apache License 2.0

Support vLLM-style rope #530

Open ByronHsu opened 2 weeks ago

ByronHsu commented 2 weeks ago

As part of SGLang Issue #1487, SGLang plans to move vLLM to optional dependencies and use flashinfer as the main dependency.

I am working on moving rope to flashinfer. My plan is to reuse most of the existing vLLM rope code but replace ops.rotary_embedding and ops.batch_rotary_embedding with flashinfer's kernels, which can be found here.

However, I've noticed some gaps between the vLLM and flashinfer implementations:

  1. cos_sin_cache: vLLM pre-computes the cos_sin_cache in the constructor, whereas flashinfer computes it on the fly (see the sketch below).
  2. Offsets and indptr instead of positions: It's tricky to convert positions back to offsets + indptr. Can we support positions directly?
  3. Partial rotate: vLLM supports a partial rotation where the rotary dimension is less than the head dimension.
  4. Batched rope for multi-Lora: For more context, see vLLM pull request #3095.

In general, we can prioritize the first three issues and consider the fourth as a stretch goal.
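
For reference, here is a minimal sketch of the kind of cos_sin_cache precomputation described in point 1; the helper name and defaults are illustrative, not vLLM's exact code:

import torch

def build_cos_sin_cache(rotary_dim: int, max_position: int, base: float = 10000.0,
                        dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # One inverse frequency per rotated pair of channels.
    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
    # Angle for every (position, frequency) pair.
    t = torch.arange(max_position, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)                        # [max_position, rotary_dim // 2]
    # Store the cos and sin halves side by side, one row per position.
    return torch.cat((freqs.cos(), freqs.sin()), dim=-1).to(dtype)  # [max_position, rotary_dim]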

yzh119 commented 2 weeks ago

Hi @ByronHsu, thanks for your suggestions. I think 1 & 3 are easy to support.

For 1, we can add sin_cache and cos_cache as optional fields to the rope APIs. For long context there might be numerical issues with an f16 sin/cos cache, so we should also support an f32 sin/cos cache (our current on-the-fly sin/cos computation uses f32).
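
A quick way to see the long-context concern (sizes here are only illustrative): build the angles in f32, then compare an f16 copy of the resulting cache against the f32 reference.

import torch

max_position, rotary_dim, base = 131072, 128, 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
angles = torch.arange(max_position, dtype=torch.float32)[:, None] * inv_freq[None, :]

cache_f32 = torch.cat((angles.cos(), angles.sin()), dim=-1)
cache_f16 = cache_f32.half()

# Rounding the cached values to f16 already introduces a measurable error;
# computing the angles themselves in reduced precision at large positions
# would lose even more accuracy.
print((cache_f32 - cache_f16.float()).abs().max())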

For 3, yes we can add another rope_dim field for partial rope.
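
For point 3, a reference sketch of what partial rotation means (NeoX-style layout assumed; this is a plain PyTorch illustration, not the kernel):

import torch

def apply_partial_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                       rotary_dim: int) -> torch.Tensor:
    # Rotate only the first rotary_dim channels of each head; the remaining
    # head_dim - rotary_dim channels pass through unchanged.
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    half = rotary_dim // 2
    x1, x2 = x_rot[..., :half], x_rot[..., half:]
    # Standard rotation: (x1, x2) -> (x1*cos - x2*sin, x2*cos + x1*sin).
    rotated = torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
    return torch.cat((rotated, x_pass), dim=-1)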

Can you give a concrete example of 2?

yzh119 commented 2 weeks ago

Okay, I think I understand 2 now. For example, if batch_size=3, indptr=[0, 1, 5, 10], and offsets=[4, 6, 3], then the equivalent positions would be [4, 6, 7, 8, 9, 3, 4, 5, 6, 7].

Is that true?
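
A small sketch of the conversion this example implies (plain PyTorch, just to pin down the semantics):

import torch

indptr = torch.tensor([0, 1, 5, 10])
offsets = torch.tensor([4, 6, 3])

# Sequence i occupies rows indptr[i]:indptr[i+1]; its positions start at offsets[i].
seq_lens = (indptr[1:] - indptr[:-1]).tolist()
positions = torch.cat([
    off + torch.arange(n) for off, n in zip(offsets.tolist(), seq_lens)
])
print(positions.tolist())  # [4, 6, 7, 8, 9, 3, 4, 5, 6, 7]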

ByronHsu commented 2 weeks ago

> Okay, I think I understand 2 now. For example, if batch_size=3, indptr=[0, 1, 5, 10], and offsets=[4, 6, 3], then the equivalent positions would be [4, 6, 7, 8, 9, 3, 4, 5, 6, 7].

Yes exactly! Thank you for the prompt response! All sounds good to me.

One comment: can we separate the new API from flashinfer's current four rope functions and provide exactly the same interface as vLLM? Several reasons:

  1. apply_rope_inplace only implements the formula from the original paper, but in practice there are many more variants.
  2. It makes it easy for all of vLLM's kernel users to migrate.

Maybe we can call this apply_rope_inplace_with_cache, which does not calculate sin/cos on the fly and supports my proposed features:

import torch

def apply_rope_inplace_with_cache(
    positions: torch.Tensor,      # [num_tokens], position id of each token
    query: torch.Tensor,          # [num_tokens, num_heads * head_size], rotated in place
    key: torch.Tensor,            # [num_tokens, num_kv_heads * head_size], rotated in place
    head_size: int,
    cos_sin_cache: torch.Tensor,  # [max_position, rotary_dim], precomputed cos/sin rows
    is_neox: bool,                # NeoX (half-rotate) vs GPT-J (interleaved) layout
) -> None:
    ...
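
A hypothetical call, mirroring how the vLLM-style interface is used (shapes and the cache construction here are only illustrative):

import torch

num_tokens, num_heads, head_size, max_position = 16, 8, 128, 4096
positions = torch.randint(0, max_position, (num_tokens,))
query = torch.randn(num_tokens, num_heads * head_size, dtype=torch.float16)
key = torch.randn(num_tokens, num_heads * head_size, dtype=torch.float16)

# f32 cos/sin cache with one row per position (cos half then sin half).
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_size, 2, dtype=torch.float32) / head_size))
angles = torch.arange(max_position, dtype=torch.float32)[:, None] * inv_freq[None, :]
cos_sin_cache = torch.cat((angles.cos(), angles.sin()), dim=-1)

# query and key would be rotated in place.
apply_rope_inplace_with_cache(positions, query, key, head_size, cos_sin_cache, is_neox=True)
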
ByronHsu commented 2 weeks ago

I did a global search and found that ops.batch_rotary_embedding is not used in SGLang (and it looks like it is not used in vLLM either), so we can safely skip the fourth feature. Thanks!