As requested in #530, this PR implements RoPE with cached cos/sin embeddings, which is more flexible in some use cases.
In our previous RoPE implementations, cos/sin values are computed on the fly inside the kernels in float32, rather than read from cached values.
While working on this PR, we found that using an fp16 cos/sin cache makes the RoPE output diverge significantly from our original implementation `flashinfer.apply_rope` (which computes cos/sin in fp32). So we require `cos_cache` and `sin_cache` to use the fp32 data type.
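The precision gap can be reproduced outside the kernels. Below is a minimal pure-Python sketch (not the actual flashinfer code; all function names are hypothetical) that builds a cos/sin cache, applies the rotation to one head vector, and compares full-precision values against a cache rounded to fp16's 11-bit significand:

```python
import math

def quantize_fp16(v):
    # Hypothetical helper: round v to fp16's 11-bit significand.
    # Ignores overflow/subnormals, which is fine for cos/sin in [-1, 1].
    if v == 0.0:
        return 0.0
    m, e = math.frexp(v)           # v == m * 2**e, 0.5 <= |m| < 1
    return math.ldexp(round(m * 2048) / 2048, e)

def build_cos_sin_cache(max_len, head_dim, base=10000.0, fp16=False):
    # One (cos, sin) pair per rotated dimension pair, per position.
    cache = []
    for pos in range(max_len):
        row = []
        for i in range(0, head_dim, 2):
            angle = pos / base ** (i / head_dim)
            c, s = math.cos(angle), math.sin(angle)
            if fp16:
                c, s = quantize_fp16(c), quantize_fp16(s)
            row.append((c, s))
        cache.append(row)
    return cache

def apply_rope_with_cache(x, cache, pos):
    # Rotate each (x[2k], x[2k+1]) pair by the cached angle for `pos`.
    out = []
    for (x1, x2), (c, s) in zip(zip(x[0::2], x[1::2]), cache[pos]):
        out += [x1 * c - x2 * s, x1 * s + x2 * c]
    return out

x = [0.1 * (i + 1) for i in range(8)]        # one toy head vector
ref = apply_rope_with_cache(x, build_cos_sin_cache(128, 8), 100)
f16 = apply_rope_with_cache(x, build_cos_sin_cache(128, 8, fp16=True), 100)
err = max(abs(a - b) for a, b in zip(ref, f16))
print(f"max abs error from fp16 cache: {err:.2e}")
```

The fp16 cache carries only ~11 bits of significand, so the rotated outputs pick up absolute errors on the order of 1e-4, which is the discrepancy that motivated requiring fp32 caches.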
cc @dreaming-panda @ByronHsu