flashinfer-ai / flashinfer

FlashInfer: Kernel Library for LLM Serving
https://flashinfer.ai
Apache License 2.0

feat: support cached cos/sin in rope APIs #585

Closed yzh119 closed 2 weeks ago

yzh119 commented 2 weeks ago

As requested in #530, this PR implements RoPE with cached cos/sin embeddings, which is more flexible in some use cases.

In our previous RoPE implementations, cos/sin values are computed on the fly inside the kernels in float32 rather than read from cached values.

While working on this PR, we found that with an f16 cos/sin cache the RoPE result has a large discrepancy compared to our original implementation, flashinfer.apply_rope (which keeps cos/sin in fp32). We therefore require cos_cache and sin_cache to use the fp32 data type.
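To give a feel for why the cache dtype matters, here is a minimal pure-Python sketch that isolates the effect of storing cos/sin at fp16 versus fp32 precision. It is not the FlashInfer kernel and does not call any FlashInfer API: the helper names (`round_to`, `rope_pair`, `max_cache_error`), the base `theta=10000.0`, and the sizes are all assumptions chosen for illustration. It measures only the rounding error of the cached values, so it may not reproduce the full discrepancy observed in this PR.

```python
import math
import struct


def round_to(value, fmt):
    """Round a Python float to fp16 ('e') or fp32 ('f') storage precision."""
    return struct.unpack(fmt, struct.pack(fmt, value))[0]


def rope_pair(x1, x2, angle, fmt=None):
    """Rotate one (x1, x2) pair by `angle`.

    If `fmt` is given, cos/sin are first rounded as a cache of that dtype would
    store them; the rotation arithmetic itself stays in double precision.
    """
    c, s = math.cos(angle), math.sin(angle)
    if fmt is not None:
        c, s = round_to(c, fmt), round_to(s, fmt)
    return x1 * c - x2 * s, x1 * s + x2 * c


def max_cache_error(fmt, max_pos=2048, rotary_dim=64, theta=10000.0):
    """Largest absolute deviation from the unrounded rotation.

    Scans every position and every frequency index (hypothetical sizes).
    """
    worst = 0.0
    for pos in range(max_pos):
        for i in range(rotary_dim // 2):
            angle = pos * theta ** (-2.0 * i / rotary_dim)
            ref = rope_pair(1.0, 1.0, angle)
            out = rope_pair(1.0, 1.0, angle, fmt)
            worst = max(worst, abs(out[0] - ref[0]), abs(out[1] - ref[1]))
    return worst


err_fp32 = max_cache_error("f")
err_fp16 = max_cache_error("e")
print(f"fp32 cache max error: {err_fp32:.3e}")
print(f"fp16 cache max error: {err_fp16:.3e}")
```

In this sketch the fp16 cache error is several orders of magnitude larger than the fp32 one, which is consistent with the decision above to require fp32 for cos_cache and sin_cache.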

cc @dreaming-panda @ByronHsu