flashinfer-ai / flashinfer

FlashInfer: Kernel Library for LLM Serving
https://flashinfer.ai
Apache License 2.0

failed to execute torch.compile #482

Closed bianxuxuxu closed 1 week ago

bianxuxuxu commented 2 weeks ago

I wrapped the kernel in a custom op following the PyTorch tutorial: https://pytorch.org/tutorials/advanced/python_custom_ops.html#python-custom-ops-tutorial. But it doesn't work. Here is my reproduction code:

import torch
import torch.nn as nn
import flashinfer

@torch.library.custom_op("mylib::flashinferattn", mutates_args=())
def flashinferattn(q_tensor: torch.Tensor, k_tensor: torch.Tensor, v_tensor: torch.Tensor, scale: float) -> torch.Tensor:
    return flashinfer.single_prefill_with_kv_cache(q_tensor, k_tensor, v_tensor, sm_scale=scale)

@flashinferattn.register_fake
def _(q_tensor, k_tensor, v_tensor, scale):
    return torch.empty_like(q_tensor)

class Attention(nn.Module):
    def __init__(
        self, inner_dim, head_num,
    ):
        super(Attention, self).__init__()

        self.num_heads = head_num
        self.head_dim = inner_dim//self.num_heads

        self.scale = self.head_dim**-0.5

        self.to_q = nn.Linear(inner_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(inner_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(inner_dim, inner_dim, bias=False)

        self.to_out = nn.Linear(inner_dim, inner_dim)

    @torch.no_grad()
    def forward(self, hidden_states):

        q = self.to_q(hidden_states)
        k = self.to_k(hidden_states)
        v = self.to_v(hidden_states)

        batch, seq_len, hidden_size = q.size()

        attn_output = torch.cat(
                                [
                                    flashinferattn(q[i,:,:].view(-1, self.num_heads, self.head_dim),
                                                    k[i,:,:].view(-1, self.num_heads, self.head_dim), 
                                                    v[i,:,:].view(-1, self.num_heads, self.head_dim), self.scale) for i in torch.arange(0,batch)
                                ], 0).view(batch, seq_len, hidden_size)

        attn_output = self.to_out(attn_output)

        return attn_output

torch.compiler.reset()
attn = Attention(inner_dim=640, head_num=10).cuda().half()
attn_compiled = torch.compile(attn, mode="reduce-overhead", fullgraph=True)
#batch=2, seq-len=128, hidden-size=640
inputs = torch.rand((2,128,640),device="cuda:0", dtype=torch.float16)
attn_out = attn_compiled(inputs)
print(attn_out)

Partial log:

RuntimeError: Failed running call_function <built-in function getitem>(*(FakeTensor(..., device='cuda:0', size=(2, 128, 640), dtype=torch.float16), (FakeTensor(..., size=(), dtype=torch.int64), slice(None, None, None), slice(None, None, None))), **{}):
aten._local_scalar_dense.default

During handling of the above exception, another exception occurred:

    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: data dependent operator: aten._local_scalar_dense.default; to enable, set torch._dynamo.config.capture_scalar_outputs = True

from user code:
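The failing call in the log is getitem with a 0-d FakeTensor index. Iterating `for i in torch.arange(0, batch)` yields 0-d tensors, and indexing `q[i, :, :]` with one needs its concrete value (`aten._local_scalar_dense`), which Dynamo treats as data-dependent. A minimal sketch of the same loop over Python ints, assuming PyTorch 2.x on CPU: `ref_attn` is a hypothetical pure-PyTorch stand-in for `flashinferattn`, and `backend="eager"` runs Dynamo only. Whether the real custom op then compiles end to end is not shown here.

```python
import torch

def ref_attn(q, k, v, scale):
    # Hypothetical stand-in for flashinferattn so this runs without FlashInfer.
    # Shapes: [seq_len, num_heads, head_dim] for q, k and v.
    scores = torch.einsum("qhd,khd->hqk", q, k) * scale
    return torch.einsum("hqk,khd->qhd", scores.softmax(dim=-1), v)

def looped_attn(q, k, v, num_heads: int, scale: float):
    batch, seq_len, hidden = q.shape
    head_dim = hidden // num_heads
    outs = []
    # Loop over Python ints: indexing with 0-d tensors from torch.arange
    # would force a read of the tensor's value (aten._local_scalar_dense).
    for i in range(batch):
        outs.append(ref_attn(q[i].view(-1, num_heads, head_dim),
                             k[i].view(-1, num_heads, head_dim),
                             v[i].view(-1, num_heads, head_dim), scale))
    return torch.cat(outs, 0).view(batch, seq_len, hidden)

q, k, v = (torch.randn(2, 128, 640) for _ in range(3))
compiled = torch.compile(looped_attn, backend="eager", fullgraph=True)
out = compiled(q, k, v, 10, 64 ** -0.5)
```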
merrymercy commented 2 weeks ago

SGLang uses torch.compile and FlashInfer together without any issues. You can find some end-to-end examples here: https://github.com/sgl-project/sglang/issues/1008

bianxuxuxu commented 2 weeks ago

@merrymercy It seems that SGLang doesn't enable fullgraph=True in torch.compile: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/model_executor/cuda_graph_runner.py#67. If I use mode="max-autotune-no-cudagraphs" in my code above, as SGLang does, it also succeeds. But I don't know how these settings affect performance.

yzh119 commented 2 weeks ago

Hi @bianxuxuxu, looking at your code, I think the problem lies in expressions like q[i,:,:], and you are trying to run batch prefill with single_prefill_with_kv_cache:

        attn_output = torch.cat(
                                [
                                    flashinferattn(q[i,:,:].view(-1, self.num_heads, self.head_dim),
                                                    k[i,:,:].view(-1, self.num_heads, self.head_dim), 
                                                    v[i,:,:].view(-1, self.num_heads, self.head_dim), self.scale) for i in torch.arange(0,batch)
                                ], 0).view(batch, seq_len, hidden_size)

I don't recommend doing this (single_prefill_with_kv_cache is not designed for batch prefill and is not CUDAGraph compatible), but you can do it anyway. You don't actually need a for loop to batch: there is an easier way, fusing the batch dimension B into the head dimension H, so that you only have to call one kernel:

# (batch_size * num_heads, seq_len, head_dim)
q_hnd = q.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2).contiguous().view(-1, seq_len, head_dim)
# (batch_size * num_heads, seq_len, head_dim)
k_hnd = k.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2).contiguous().view(-1, seq_len, head_dim)
# (batch_size * num_heads, seq_len, head_dim)
v_hnd = v.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2).contiguous().view(-1, seq_len, head_dim)
attn_output = flashinfer.single_prefill_with_kv_cache(q_hnd, k_hnd, v_hnd, kv_layout="HND", sm_scale=scale).view(batch_size, num_heads, seq_len, head_dim).transpose(1, 2).contiguous().view(batch_size, seq_len, num_heads * head_dim)

For batch prefill, you are encouraged to use BatchPrefillWithRaggedKVCacheWrapper instead (https://docs.flashinfer.ai/api/python/prefill.html#flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper); it supports variable-length sequences and is designed to be CUDAGraph compatible.

bianxuxuxu commented 2 weeks ago

Hi @yzh119, thanks for your advice. But the q_hnd layout is incorrect: according to the docstring, q must have shape [qo_len, num_qo_heads, head_dim]: https://docs.flashinfer.ai/generated/flashinfer.prefill.single_prefill_with_kv_cache.html#flashinfer.prefill.single_prefill_with_kv_cache
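A quick shape check makes the mismatch concrete. kv_layout only changes how k and v are read, so q_hnd is still interpreted as [qo_len, num_qo_heads, head_dim]. The sizes below mirror the repro (batch 2, seq_len 128, 10 heads of dimension 64); this is a plain PyTorch sketch and does not call FlashInfer.

```python
import torch

batch_size, seq_len, num_heads, head_dim = 2, 128, 10, 64
q = torch.randn(batch_size, seq_len, num_heads * head_dim)

# First suggestion: heads-first, shape (batch_size * num_heads, seq_len, head_dim)
q_hnd = (q.view(batch_size, seq_len, num_heads, head_dim)
          .transpose(1, 2).contiguous().view(-1, seq_len, head_dim))

# Read the way the docstring describes q: [qo_len, num_qo_heads, head_dim]
qo_len, num_qo_heads, _ = q_hnd.shape
print(qo_len, num_qo_heads)  # 20 128: batch*heads taken as positions, seq_len as heads
```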

yzh119 commented 2 weeks ago

I see, sorry for the confusion. Then I think the following should work:

q = q.view(batch_size, seq_len, num_heads, head_dim).transpose(0, 1).contiguous().view(seq_len, -1, head_dim)

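The general idea is to fuse batch_size into num_heads (with k and v reshaped the same way): when every sequence has the same seq_len, each (batch, head) pair is an independent attention problem of the same shape, so the batch dimension can be treated as extra heads.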
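The equivalence can be checked without a GPU. A sketch assuming a hypothetical pure-PyTorch `ref_prefill` (no mask, NHD shapes) standing in for `single_prefill_with_kv_cache`: it compares one call over batch_size * num_heads fused heads against a per-batch loop, and also undoes the reshape back to (batch, seq_len, hidden).

```python
import torch

def ref_prefill(q, k, v, sm_scale):
    # Hypothetical stand-in for single_prefill_with_kv_cache (no mask), NHD shapes:
    # q [qo_len, num_heads, head_dim], k/v [kv_len, num_heads, head_dim].
    scores = torch.einsum("qhd,khd->hqk", q, k) * sm_scale
    return torch.einsum("hqk,khd->qhd", scores.softmax(dim=-1), v)

def fuse_batch_into_heads(x, num_heads):
    # (batch, seq_len, num_heads * head_dim) -> (seq_len, batch * num_heads, head_dim)
    b, s, hidden = x.shape
    hd = hidden // num_heads
    return x.view(b, s, num_heads, hd).transpose(0, 1).contiguous().view(s, -1, hd)

def unfuse_heads(out, batch, num_heads):
    # (seq_len, batch * num_heads, head_dim) -> (batch, seq_len, num_heads * head_dim)
    s, _, hd = out.shape
    return (out.view(s, batch, num_heads, hd).transpose(0, 1)
               .contiguous().view(batch, s, num_heads * hd))

def fused_attention(q, k, v, num_heads, sm_scale):
    # One call over batch * num_heads "heads" instead of a loop over the batch
    out = ref_prefill(*(fuse_batch_into_heads(t, num_heads) for t in (q, k, v)), sm_scale)
    return unfuse_heads(out, q.shape[0], num_heads)

def looped_attention(q, k, v, num_heads, sm_scale):
    # Reference: one call per batch element, as in the original reproduction
    b, s, hidden = q.shape
    hd = hidden // num_heads
    outs = [ref_prefill(q[i].view(s, num_heads, hd), k[i].view(s, num_heads, hd),
                        v[i].view(s, num_heads, hd), sm_scale) for i in range(b)]
    return torch.stack(outs).view(b, s, hidden)
```

Usage: with q, k and v of shape (2, 16, 64) and 4 heads, `fused_attention` and `looped_attention` should agree to floating-point tolerance.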

bianxuxuxu commented 2 weeks ago

@yzh119 Yes, this is correct and runs correctly. Is there any performance difference between this implementation and BatchPrefillWithRaggedKVCacheWrapper?

yzh119 commented 2 weeks ago

BatchPrefillWithRaggedKVCacheWrapper was designed for variable-length input (using ragged tensors, which don't need padding) and has a load-balancing scheduler. It is also designed to be compatible with CUDAGraph: PyTorch's CUDAGraph requires a fixed kernel configuration, whereas single_prefill_with_kv_cache chooses its kernel configuration dynamically.

Performance-wise, if the input lengths are not variable, the two should be similar.
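A ragged tensor packs all sequences back to back and records their boundaries in an indptr array of length batch_size + 1; FlashInfer's ragged wrappers take such arrays as qo_indptr and kv_indptr. The planning method name differs between releases (begin_forward in older versions, plan in newer ones), so check the docs for your version. A pure-PyTorch sketch of building the arrays; `pack_ragged` is a hypothetical helper.

```python
import torch

def pack_ragged(seqs):
    """Concatenate variable-length sequences without padding and build indptr.

    seqs: list of tensors shaped [len_i, num_heads, head_dim].
    Sequence i occupies rows indptr[i]:indptr[i + 1] of the returned data.
    """
    lens = torch.tensor([s.shape[0] for s in seqs], dtype=torch.int32)
    indptr = torch.zeros(len(seqs) + 1, dtype=torch.int32)
    indptr[1:] = torch.cumsum(lens, dim=0)  # running total of sequence lengths
    return torch.cat(seqs, dim=0), indptr

seqs = [torch.randn(n, 4, 8) for n in (3, 5, 2)]
data, indptr = pack_ragged(seqs)
print(indptr.tolist())  # [0, 3, 8, 10]
```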

merrymercy commented 2 weeks ago

@bianxuxuxu In SGLang, we use our own code to capture the CUDA graph at the outermost level. We compared our performance with gpt-fast and found no performance degradation.

The latest sglang is faster than gpt-fast on H100 for any sequence length.
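Capturing "at the outermost level" means recording an entire forward pass into one CUDA graph and replaying it with fixed shapes and static input/output buffers. A minimal plain-PyTorch sketch of that pattern, not SGLang's actual code; `capture_and_replay` is a hypothetical helper, and it returns None when no GPU is available.

```python
import torch

def capture_and_replay(model, example, new_input):
    """Capture one forward pass as a CUDA graph, then replay it on new data."""
    if not torch.cuda.is_available():
        return None  # CUDA graphs need a GPU
    with torch.no_grad():
        static_in = example.clone()
        # Warm up on a side stream before capture, as the PyTorch CUDA graph docs recommend
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            for _ in range(3):
                model(static_in)
        torch.cuda.current_stream().wait_stream(s)
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            static_out = model(static_in)  # kernels and shapes are fixed from here on
        # Replay: write new data into the captured input buffer and rerun the kernels
        static_in.copy_(new_input)
        g.replay()
        return static_out.clone()
```

Because the captured kernels are replayed verbatim, every kernel configuration must stay fixed between replays, which is the constraint behind the CUDAGraph compatibility point raised earlier in this thread.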

yzh119 commented 1 week ago

Closing this issue for now; feel free to reopen it if there are still other concerns.