YudiZh opened this issue 1 month ago
Hi @hnyls2002, have you ever met such errors before in the sglang integration?

The issue seems to be with `custom_mask`, which internally calls `flashinfer.packbits`:
https://github.com/flashinfer-ai/flashinfer/blob/78e26e47b95bea994ad2a47e1b1f42810363429c/python/flashinfer/prefill.py#L277-L281

Could you try specifying the `packed_custom_mask` argument instead of `custom_mask`? Or decorate `flashinfer.packbits` with the PyTorch custom CUDA ops API as well.
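For reference, a minimal sketch of the first suggestion: pack the mask up front and hand the result to `packed_custom_mask`, so the compiled region never calls `flashinfer.packbits` itself. The `bitorder="little"` argument mirrors what the linked prefill code uses internally and should be treated as an assumption about your flashinfer version:

```python
import torch
import flashinfer

qo_len, kv_len = 128, 128
q = torch.randn(qo_len, 32, 128, device="cuda", dtype=torch.float16)
k = torch.randn(kv_len, 8, 128, device="cuda", dtype=torch.float16)
v = torch.randn(kv_len, 8, 128, device="cuda", dtype=torch.float16)

# Boolean attention mask of shape (qo_len, kv_len).
mask = torch.tril(torch.ones(qo_len, kv_len, dtype=torch.bool, device="cuda"))

# Pack the mask ahead of time, outside any compiled/captured code path
# (assumes flashinfer.packbits with little-endian bit order, as in the linked code).
packed_mask = flashinfer.packbits(mask.flatten(), bitorder="little")

# Pass the pre-packed mask instead of custom_mask.
o = flashinfer.single_prefill_with_kv_cache(q, k, v, packed_custom_mask=packed_mask)
```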
Even when I only pass the `q`, `k`, `v` arguments and omit the others, the error still occurs:
```python
import torch
from flashinfer import single_prefill_with_kv_cache

# Register single_prefill_with_kv_cache as a custom op so torch.compile
# treats it as an opaque call instead of tracing into it.
torch.library.define(
    "mylib::custom_func_flashinfer",
    "(Tensor q, Tensor k, Tensor v) -> Tensor",
)

@torch.library.impl("mylib::custom_func_flashinfer", "cuda")
def custom_func_flashinfer(q, k, v):
    return single_prefill_with_kv_cache(q, k, v)

@torch.library.impl_abstract("mylib::custom_func_flashinfer")
def custom_func_flashinfer_abstract(q, k, v):
    # Fake/meta implementation: the output has the same shape as q.
    return torch.empty_like(q)

def attn(q, k, v):
    return torch.ops.mylib.custom_func_flashinfer(q, k, v)

attn = torch.compile(attn, mode="reduce-overhead", fullgraph=True)
```
@YudiZh Can you try `torch.compile(..., fullgraph=True, mode="max-autotune-no-cudagraphs")`? CUDA Graph provides little value when you are just capturing one CUDA kernel.

BTW, we are adding torch library annotations in https://github.com/flashinfer-ai/flashinfer/pull/554
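Applied to the snippet above, the only change would be the compile mode (a sketch, keeping the same custom-op registration):

```python
attn = torch.compile(attn, mode="max-autotune-no-cudagraphs", fullgraph=True)
```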
Thank you for your response. I have tried `torch.compile(..., fullgraph=True, mode="max-autotune-no-cudagraphs")`, and the code runs without errors. However, when I want to use CUDA graphs for flashinfer together with the other PyTorch operations in the model's forward function, which mode should I use for compilation? Does using `max-autotune-no-cudagraphs` mean those PyTorch operations miss the expected acceleration because CUDA graphs are not involved? Here is a sample of my code:
```python
import torch
import flashinfer
import torch.nn as nn
from torch import Tensor

data_type = torch.bfloat16

def generate_data_x():
    x = torch.randn(1, 1024, 4096, device='cuda', dtype=data_type)
    return x

def timed(fn):
    # Time a single call with CUDA events, returning (result, seconds).
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000

class Attention(nn.Module):
    def __init__(self):
        super().__init__()
        # 32 query heads + 2 * 8 KV heads, head_dim = 128 (GQA).
        total_head_dim = (32 + 2 * 8) * 128
        self.wqkv = nn.Linear(4096, total_head_dim).to(torch.bfloat16)
        self.attn = flashinfer.single_prefill_with_kv_cache

    def forward(self, x: Tensor) -> Tensor:
        bsz, seqlen, _ = x.shape
        q, k, v = self.wqkv(x).split([4096, 8 * 128, 8 * 128], dim=-1)
        q = q.view(bsz, seqlen, 32, 128)
        k = k.view(bsz, seqlen, 8, 128)
        v = v.view(bsz, seqlen, 8, 128)
        # single_prefill_with_kv_cache expects (seqlen, num_heads, head_dim).
        q = q.squeeze(0)
        k = k.squeeze(0)
        v = v.squeeze(0)
        y = self.attn(q, k, v)
        return y

self_attn = Attention().to("cuda")
attn = lambda model, x: model(x)
attn = torch.compile(attn, mode="reduce-overhead", fullgraph=True)

for i in range(10):
    x = generate_data_x()
    o, run_time = timed(lambda: attn(self_attn, x))
    print(run_time)
```
The APIs starting with `single_` are not compatible with CUDA graphs (I might spend some time making them compatible later). The `BatchPrefill`/`BatchDecode` wrappers have been designed to be compatible with CUDA graphs.
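For illustration, a rough sketch of what moving the attention call in the example above onto a batch wrapper could look like, here using `BatchPrefillWithRaggedKVCacheWrapper` with a single request; the `plan()`/`run()` method names and argument order follow recent flashinfer releases and should be checked against your installed version:

```python
import torch
import flashinfer

num_qo_heads, num_kv_heads, head_dim = 32, 8, 128
seqlen = 1024

# The batch wrappers require a user-provided workspace buffer.
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(workspace, "NHD")

# A "batch" of one request: indptr arrays delimit each request's tokens.
qo_indptr = torch.tensor([0, seqlen], dtype=torch.int32, device="cuda")
kv_indptr = torch.tensor([0, seqlen], dtype=torch.int32, device="cuda")

# Planning (metadata computation) stays outside the captured region; only
# run() needs to be replayed, which is what makes these wrappers
# CUDA-graph friendly.
wrapper.plan(qo_indptr, kv_indptr, num_qo_heads, num_kv_heads, head_dim)

q = torch.randn(seqlen, num_qo_heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn(seqlen, num_kv_heads, head_dim, device="cuda", dtype=torch.float16)
v = torch.randn(seqlen, num_kv_heads, head_dim, device="cuda", dtype=torch.float16)

o = wrapper.run(q, k, v)  # output has the same shape as q
```

On older flashinfer releases the same flow is exposed as `begin_forward()`/`forward()` instead of `plan()`/`run()`.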
I tried to compile `single_prefill_with_kv_cache` using `torch.compile`, which causes the following runtime error: