flashinfer-ai / flashinfer

FlashInfer: Kernel Library for LLM Serving
https://flashinfer.ai
Apache License 2.0

Runtime error with single_prefill_with_kv_cache during compilation #541

Open · YudiZh opened this issue 1 month ago

YudiZh commented 1 month ago

I tried to compile single_prefill_with_kv_cache using torch.compile.

import torch
from flashinfer import single_prefill_with_kv_cache

data_type = torch.bfloat16

QH=64
KH=8
S=1024
D=128

def generate_data():
    q = torch.randn(S, QH, D, device='cuda', dtype=data_type)
    k = torch.randn(S, KH, D, device='cuda', dtype=data_type)
    v = torch.randn(S, KH, D, device='cuda', dtype=data_type)
    return q, k, v

def timed(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000

torch.library.define(
    "mylib::custom_func_flashinfer",
    "(Tensor q, Tensor k, Tensor v, Tensor custom_mask) -> Tensor",
)

@torch.library.impl("mylib::custom_func_flashinfer", "cuda")
def custom_func_flashinfer(q, k, v, custom_mask):
    return single_prefill_with_kv_cache(
        q, k, v, custom_mask=custom_mask
    )

@torch.library.impl_abstract("mylib::custom_func_flashinfer")
def custom_func_flashinfer_abstract(q, k, v, custom_mask):
    return torch.empty_like(q)

def attn(q, k, v, custom_mask=None):
    return torch.ops.mylib.custom_func_flashinfer(q, k, v, custom_mask=custom_mask)
attn = torch.compile(attn, mode="reduce-overhead", fullgraph=True)

for i in range(10):
    q, k, v = generate_data()
    mask = torch.tril(
        torch.full((S, S), True, device="cuda:0"),
    )
    o, run_time = timed(lambda: attn(q, k, v, custom_mask=mask))
    print(run_time)

This causes the following runtime error:

/data/home/ydzhang/project/code_test/flashinfer_test/compilation.py:37: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
  @torch.library.impl_abstract("mylib::custom_func_flashinfer")
Traceback (most recent call last):
  File "/data/home/ydzhang/project/code_test/flashinfer_test/compilation.py", line 52, in <module>
    o, run_time = timed(lambda: attn(q, k, v, custom_mask=mask))
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/home/ydzhang/project/code_test/flashinfer_test/compilation.py", line 21, in timed
    result = fn()
             ^^^^
  File "/data/home/ydzhang/project/code_test/flashinfer_test/compilation.py", line 52, in <lambda>
    o, run_time = timed(lambda: attn(q, k, v, custom_mask=mask))
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 433, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/data/home/ydzhang/project/code_test/flashinfer_test/compilation.py", line 42, in attn
    def attn(q, k, v, custom_mask=None):
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 987, in forward
    return compiled_fn(full_args)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 217, in runtime_wrapper
    all_outs = call_func_at_runtime_with_args(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/utils.py", line 120, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
                            ^^^^^^^
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 451, in wrapper
    return compiled_fn(runtime_args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 1131, in __call__
    return self.current_callable(inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 993, in run
    return compiled_fn(new_inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py", line 373, in deferred_cudagraphify
    fn, out = cudagraphify(model, inputs, new_static_input_idxs, *args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py", line 403, in cudagraphify
    return manager.add_function(
           ^^^^^^^^^^^^^^^^^^^^^
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py", line 2089, in add_function
    return fn, fn(inputs)
               ^^^^^^^^^^
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py", line 1841, in run
    out = self._run(new_inputs, function_id)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py", line 1932, in _run
    return self.run_eager(new_inputs, function_id)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py", line 2055, in run_eager
    return node.run(new_inputs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py", line 646, in run
    check_memory_pool(self.device_index, self.cuda_graphs_pool, out_refs)
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py", line 1699, in check_memory_pool
    raise RuntimeError(msg)
RuntimeError: These live storage data ptrs are in the cudagraph pool but not accounted for as an output of cudagraph trees: 

Data Pointer: 22959854977024, history: 

yzh119 commented 1 month ago

Hi @hnyls2002, have you ever encountered errors like this before in the sglang integration?

yzh119 commented 1 month ago

It seems the issue is with custom_mask, which internally calls flashinfer.packbits: https://github.com/flashinfer-ai/flashinfer/blob/78e26e47b95bea994ad2a47e1b1f42810363429c/python/flashinfer/prefill.py#L277-L281

Could you try specifying the packed_custom_mask argument instead of custom_mask? Alternatively, you could wrap flashinfer.packbits with the PyTorch custom ops API as well.
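Something along these lines (an untested sketch; the bitorder="little" flag and the packed_custom_mask argument are taken from the prefill.py snippet linked above and may differ across versions):

import torch
import flashinfer

S = 1024

# Pack the boolean mask once, ahead of time, so the compiled region only
# sees the attention kernel and never calls packbits internally.
mask = torch.tril(torch.full((S, S), True, device="cuda"))
packed_mask = flashinfer.packbits(mask.contiguous().view(-1), bitorder="little")

torch.library.define(
    "mylib::custom_func_flashinfer_packed",
    "(Tensor q, Tensor k, Tensor v, Tensor packed_mask) -> Tensor",
)

@torch.library.impl("mylib::custom_func_flashinfer_packed", "cuda")
def custom_func_flashinfer_packed(q, k, v, packed_mask):
    # Pass the pre-packed mask instead of the boolean custom_mask.
    return flashinfer.single_prefill_with_kv_cache(
        q, k, v, packed_custom_mask=packed_mask
    )

# torch.library.impl_abstract is the older name of torch.library.register_fake.
@torch.library.impl_abstract("mylib::custom_func_flashinfer_packed")
def custom_func_flashinfer_packed_abstract(q, k, v, packed_mask):
    return torch.empty_like(q)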

YudiZh commented 1 month ago

Even when I only pass the q, k, and v arguments and omit the others, the error still occurs:

torch.library.define(
    "mylib::custom_func_flashinfer",
    "(Tensor q, Tensor k, Tensor v) -> Tensor",
)

@torch.library.impl("mylib::custom_func_flashinfer", "cuda")
def custom_func_flashinfer(q, k, v):
    return single_prefill_with_kv_cache(
        q, k, v
    )

@torch.library.impl_abstract("mylib::custom_func_flashinfer")
def custom_func_flashinfer_abstract(q, k, v):
    return torch.empty_like(q)

def attn(q, k, v):
    return torch.ops.mylib.custom_func_flashinfer(q, k, v)
attn = torch.compile(attn, mode="reduce-overhead", fullgraph=True)

abcdabcd987 commented 1 month ago

@YudiZh Can you try torch.compile(..., fullgraph=True, mode="max-autotune-no-cudagraphs")? CUDA Graphs provide little value when you are just capturing a single CUDA kernel.
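For example, keeping the custom-op wrapper from your snippet and changing only the mode:

# Same attn wrapper as before; only the mode changes, so Inductor skips its
# CUDA Graph capture while still compiling the surrounding graph.
attn = torch.compile(attn, mode="max-autotune-no-cudagraphs", fullgraph=True)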

BTW, we are adding torch library annotations in https://github.com/flashinfer-ai/flashinfer/pull/554

YudiZh commented 1 month ago

Thank you for your response. I tried torch.compile(..., fullgraph=True, mode="max-autotune-no-cudagraphs") and the code runs without errors. However, when I want to capture CUDA graphs over both flashinfer and the other PyTorch operations in the model's forward function, which mode should I use? Does max-autotune-no-cudagraphs mean those PyTorch operations lose the speedup they would otherwise get from CUDA graphs? Here is a sample of my code:

import torch
import flashinfer
import torch.nn as nn
from torch import Tensor
data_type = torch.bfloat16

def generate_data_x():
    x = torch.randn(1, 1024, 4096, device='cuda', dtype=data_type)
    return x

def timed(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000

class Attention(nn.Module):
    def __init__(self):
        super().__init__()
        total_head_dim = (32 + 2 * 8) * 128
        self.wqkv = nn.Linear(4096, total_head_dim).to(torch.bfloat16)
        self.attn = flashinfer.single_prefill_with_kv_cache

    def forward(self, x: Tensor) -> Tensor:
        bsz, seqlen, _ = x.shape
        q, k, v = self.wqkv(x).split([4096, 8*128, 8*128], dim=-1)

        q = q.view(bsz, seqlen, 32, 128)
        k = k.view(bsz, seqlen, 8, 128)
        v = v.view(bsz, seqlen, 8, 128)

        q = q.squeeze(0)
        k = k.squeeze(0)
        v = v.squeeze(0)

        y = self.attn(q, k, v)
        return y

self_attn = Attention().to("cuda")
attn = lambda model, x: model(x)
attn = torch.compile(attn, mode="reduce-overhead", fullgraph=True)

for i in range(10):
    x = generate_data_x()
    o, run_time = timed(lambda: attn(self_attn, x))
    print(run_time)

yzh119 commented 1 month ago

The APIs that start with single_ are not compatible with CUDA Graphs (I might spend some time making them compatible later). The BatchPrefill/BatchDecode wrappers have been designed to be compatible with CUDA Graphs.
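For reference, a rough sketch of the ragged batch-prefill wrapper for a single request (method names such as plan/run vs. begin_forward/forward and the exact plan arguments depend on the flashinfer version, so treat this as an outline rather than the exact API):

import torch
import flashinfer

QH, KH, S, D = 64, 8, 1024, 128
dtype = torch.bfloat16

q = torch.randn(S, QH, D, device="cuda", dtype=dtype)
k = torch.randn(S, KH, D, device="cuda", dtype=dtype)
v = torch.randn(S, KH, D, device="cuda", dtype=dtype)

# The wrapper plans scheduling metadata once on a workspace buffer and then
# replays the attention kernel, which is what makes it CUDA-Graph friendly.
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(workspace, "NHD")

# A "batch" of one request covering the whole sequence.
qo_indptr = torch.tensor([0, S], dtype=torch.int32, device="cuda")
kv_indptr = torch.tensor([0, S], dtype=torch.int32, device="cuda")

wrapper.plan(qo_indptr, kv_indptr, QH, KH, D, causal=True)  # begin_forward(...) in older releases
o = wrapper.run(q, k, v)                                    # forward(q, k, v, causal=True) in older releases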