flashinfer-ai / flashinfer

FlashInfer: Kernel Library for LLM Serving
https://flashinfer.ai
Apache License 2.0

apply_rope_inplace will cause graphbreak due to mutated inputs #403

Open jianc99 opened 1 month ago

jianc99 commented 1 month ago

import torch
import flashinfer

rope = flashinfer.apply_rope_inplace

# Register the in-place rope as a custom op; (a!) marks q and k as mutated.
torch.library.define(
    "mylib::target_rope",
    "(Tensor(a!) q, Tensor(a!) k, Tensor indptr, Tensor offsets) -> None",
)

@torch.library.impl("mylib::target_rope", "cuda")
def target_rope(q, k, indptr, offsets):
    rope(q, k, indptr, offsets, interleave=True)

@torch.library.register_fake("mylib::target_rope")
def target_rope_abstract(q, k, indptr, offsets):
    return None

# Ragged layout: 4 sequences of length 1, 4 query heads / 1 kv head, head_dim 128.
q = torch.randn(4, 4, 128, dtype=torch.bfloat16).to(0)
k = torch.randn(4, 1, 128, dtype=torch.bfloat16).to(0)
indptr = torch.arange(5, dtype=torch.int32).to(0)
offsets = torch.full((4,), 1, dtype=torch.int32).to(0)

torch.compile(torch.ops.mylib.target_rope, mode="reduce-overhead", fullgraph=True)(q, k, indptr, offsets)
Running this prints:

skipping cudagraphs due to mutated inputs (2 instances)
yzh119 commented 1 month ago

Mutated arguments have to be annotated: https://pytorch.org/tutorials/advanced/cpp_custom_ops.html#creating-mutable-operators
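
For reference, the repro's Tensor(a!) schema already carries that annotation; below is a minimal sketch of the equivalent Python-side registration, assuming PyTorch >= 2.4 and using placeholder names (mylib::rope_inplace is not FlashInfer's API):

import torch
import flashinfer

# mutates_args tells torch.compile which inputs the op writes in place.
@torch.library.custom_op("mylib::rope_inplace", mutates_args=("q", "k"), device_types="cuda")
def rope_inplace(q: torch.Tensor, k: torch.Tensor,
                 indptr: torch.Tensor, offsets: torch.Tensor) -> None:
    flashinfer.apply_rope_inplace(q, k, indptr, offsets, interleave=True)

@rope_inplace.register_fake
def _(q, k, indptr, offsets) -> None:
    # In-place op with no outputs: nothing to allocate for tracing.
    return None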

yzh119 commented 1 month ago

I noticed that you already annotated the mutated inputs. I think it's okay to expose another set of apply_rope and apply_llama31_rope which are not in-place operations, for use with torch.compile.
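
Until such variants exist, one caller-side workaround is to wrap the in-place kernel in a functional custom op that clones q and k first, so the compiled graph sees no mutated inputs. A minimal sketch with illustrative names (mylib::rope is not FlashInfer's API):

import torch
import flashinfer

# Functional (out-of-place) wrapper: outputs are fresh tensors, so
# torch.compile + cudagraphs see no mutated inputs.
torch.library.define(
    "mylib::rope",
    "(Tensor q, Tensor k, Tensor indptr, Tensor offsets) -> (Tensor, Tensor)",
)

@torch.library.impl("mylib::rope", "cuda")
def rope(q, k, indptr, offsets):
    q_rope, k_rope = q.clone(), k.clone()
    flashinfer.apply_rope_inplace(q_rope, k_rope, indptr, offsets, interleave=True)
    return q_rope, k_rope

@torch.library.register_fake("mylib::rope")
def rope_abstract(q, k, indptr, offsets):
    # Outputs match the input shapes and dtypes.
    return torch.empty_like(q), torch.empty_like(k)

The extra clones cost some memory bandwidth, but they keep the graph purely functional, which is what cudagraph capture under reduce-overhead mode wants.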

jianc99 commented 1 month ago

Yeah I have annotated that but it still doesn't work. Exposing a non-in-place rope would be very helpful, thanks!

yzh119 commented 1 month ago

Done in #405.
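
With the functional variants from #405, the repro can avoid input mutation entirely; a sketch, assuming flashinfer.apply_rope mirrors the in-place signature and returns new (q_rope, k_rope) tensors instead of mutating its arguments:

import torch
import flashinfer

q = torch.randn(4, 4, 128, dtype=torch.bfloat16, device="cuda")
k = torch.randn(4, 1, 128, dtype=torch.bfloat16, device="cuda")
indptr = torch.arange(5, dtype=torch.int32, device="cuda")
offsets = torch.full((4,), 1, dtype=torch.int32, device="cuda")

# Assumption: apply_rope is the out-of-place counterpart of apply_rope_inplace.
q_rope, k_rope = flashinfer.apply_rope(q, k, indptr, offsets, interleave=True)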