SGLang used torch.compile and flashinfer together without any issues. You can check out some end-to-end examples here https://github.com/sgl-project/sglang/issues/1008
@merrymercy It seems that SGLang doesn't enable fullgraph=True in torch.compile: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/model_executor/cuda_graph_runner.py#67. If I use mode="max-autotune-no-cudagraphs" in my code above, like sglang does, it also succeeds. But I don't know how these settings affect performance.
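For reference, a minimal sketch of the two torch.compile configurations being contrasted here (attn_fn is an illustrative placeholder, not the original module; the exact mode used in the failing run isn't shown above):

```python
import torch

def attn_fn(q, k, v):
    ...  # placeholder for the module that wraps the flashinfer attention call

# Strict capture: the whole function must compile into a single graph,
# so any graph break inside the custom op surfaces as a hard error.
compiled_strict = torch.compile(attn_fn, fullgraph=True)

# sglang-like configuration: graph breaks are tolerated and inductor does not
# capture CUDA graphs itself (sglang captures CUDA graphs separately, outside compile).
compiled_relaxed = torch.compile(attn_fn, mode="max-autotune-no-cudagraphs")
```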
Hi @bianxuxuxu, looking at your code I think the problem lies in expressions like q[i,:,:], and in trying to run batch prefill with single_prefill_with_kv_cache:
```python
attn_output = torch.cat(
    [
        flashinferattn(q[i, :, :].view(-1, self.num_heads, self.head_dim),
                       k[i, :, :].view(-1, self.num_heads, self.head_dim),
                       v[i, :, :].view(-1, self.num_heads, self.head_dim),
                       self.scale)
        for i in torch.arange(0, batch)
    ], 0).view(batch, seq_len, hidden_size)
```
I don't recommend doing this (single_prefill_with_kv_cache is not designed for batch prefill and is not CUDAGraph-compatible), but if you do it anyway, you don't have to batch with a for loop: there is an easier way (fusing B into the H dimension) where you only have to call one kernel:
```python
# (batch_size * num_heads, seq_len, head_dim)
q_hnd = q.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2).contiguous().view(-1, seq_len, head_dim)
# (batch_size * num_heads, seq_len, head_dim)
k_hnd = k.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2).contiguous().view(-1, seq_len, head_dim)
# (batch_size * num_heads, seq_len, head_dim)
v_hnd = v.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2).contiguous().view(-1, seq_len, head_dim)

attn_output = single_prefill_with_kv_cache(q_hnd, k_hnd, v_hnd, kv_layout="HND", sm_scale=scale) \
    .view(batch_size, num_heads, seq_len, head_dim).transpose(1, 2).contiguous().view(batch_size, seq_len, num_heads * head_dim)
```
You are encouraged to use BatchPrefillWithRaggedKVCacheWrapper (https://docs.flashinfer.ai/api/python/prefill.html#flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper) instead for batch prefill; it supports variable-length inputs and is designed to be CUDAGraph-compatible.
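A rough sketch of what using the wrapper could look like for this fixed-length case (variable names are illustrative; recent flashinfer releases expose plan()/run(), older ones use begin_forward()/forward(); options such as causal masking are omitted):

```python
import torch
import flashinfer

# 128 MB workspace buffer for the wrapper's scheduler, as in the flashinfer docs.
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(workspace, kv_layout="NHD")

# Ragged layout: requests are concatenated along dim 0 and indptr marks request
# boundaries, so variable-length batches need no padding. Here all lengths equal seq_len.
qo_indptr = torch.arange(0, (batch_size + 1) * seq_len, seq_len, dtype=torch.int32, device="cuda")
kv_indptr = qo_indptr.clone()

wrapper.plan(qo_indptr, kv_indptr, num_heads, num_heads, head_dim)

# (batch_size * seq_len, num_heads, head_dim)
q_r = q.reshape(-1, num_heads, head_dim)
k_r = k.reshape(-1, num_heads, head_dim)
v_r = v.reshape(-1, num_heads, head_dim)

out = wrapper.run(q_r, k_r, v_r)                    # (batch_size * seq_len, num_heads, head_dim)
attn_output = out.reshape(batch_size, seq_len, -1)  # back to (batch_size, seq_len, hidden_size)
```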
Hi @yzh119. Thanks for your advice, but the q_hnd layout is incorrect: q must be of shape [qo_len, num_qo_heads, head_dim] according to the documentation: https://docs.flashinfer.ai/generated/flashinfer.prefill.single_prefill_with_kv_cache.html#flashinfer.prefill.single_prefill_with_kv_cache
I see, sorry about the confusion. Then I think the following should work:
```python
q = q.view(batch_size, seq_len, num_heads, head_dim).transpose(0, 1).contiguous().view(seq_len, -1, head_dim)
```
The general idea is to fuse batch_size with num_heads; they are equivalent when seq_len is a constant.
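For completeness, a sketch of the full call under that idea (applying the same reshape to k and v, then undoing the fusion on the output; the names q_f/k_f/v_f are illustrative, and the keyword sm_scale is per the flashinfer docs):

```python
# Fuse batch into the head dimension: (seq_len, batch_size * num_heads, head_dim),
# which matches the expected [qo_len, num_qo_heads, head_dim] (NHD) layout.
q_f = q.view(batch_size, seq_len, num_heads, head_dim).transpose(0, 1).contiguous().view(seq_len, -1, head_dim)
k_f = k.view(batch_size, seq_len, num_heads, head_dim).transpose(0, 1).contiguous().view(seq_len, -1, head_dim)
v_f = v.view(batch_size, seq_len, num_heads, head_dim).transpose(0, 1).contiguous().view(seq_len, -1, head_dim)

# (seq_len, batch_size * num_heads, head_dim)
out = single_prefill_with_kv_cache(q_f, k_f, v_f, sm_scale=scale)

# Undo the fusion: back to (batch_size, seq_len, num_heads * head_dim)
attn_output = (out.view(seq_len, batch_size, num_heads, head_dim)
                  .transpose(0, 1).contiguous()
                  .view(batch_size, seq_len, num_heads * head_dim))
```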
@yzh119 Yes, this is correct and runs correctly. Is there any performance difference between this implementation and using BatchPrefillWithRaggedKVCacheWrapper?
BatchPrefillWithRaggedKVCacheWrapper was designed for variable-length inputs (using Ragged Tensors, which don't need padding), has a load-balancing scheduler, and is designed to be compatible with CUDAGraph (PyTorch's CUDAGraph requires a fixed kernel configuration, while single_prefill_with_kv_cache uses a dynamic kernel configuration).
Performance-wise, if the input length is not variable, they should be similar.
@bianxuxuxu In sglang, we use our own code to capture the cuda graph at the outermost level. We compare our perf with gpt-fast and find there is no performance degradation.
The latest sglang is faster than gpt-fast on H100 for any sequence length.
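For context, a rough sketch of what capturing a CUDA graph at the outermost level looks like (this is the generic torch.cuda.CUDAGraph pattern from the PyTorch docs, not sglang's actual code; model is a placeholder):

```python
import torch

# Static input/output buffers: CUDA graph replay reuses fixed memory addresses.
static_input = torch.zeros(batch_size, seq_len, hidden_size, device="cuda")

# Warm up on a side stream before capture (required before graph capture).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_output = model(static_input)
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = model(static_input)

# Replay: copy new data into the static buffer, then launch the whole graph at once.
static_input.copy_(new_input)
graph.replay()
result = static_output.clone()
```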
Closing this issue for now; feel free to re-open it if there are still other concerns.
I have made the custom op following the PyTorch tutorial: https://pytorch.org/tutorials/advanced/python_custom_ops.html#python-custom-ops-tutorial. But it doesn't work. Here is my reproduction code:
partial log: