flashinfer-ai / flashinfer

FlashInfer: Kernel Library for LLM Serving
https://flashinfer.ai
Apache License 2.0

Can we integrate Flashinfer into gpt-fast? #391

Closed jianc99 closed 1 week ago

jianc99 commented 1 month ago

Hi, in previous issues you wrote that you planned to integrate flashinfer into some inference backends like gpt-fast. That would be very interesting! Can we integrate flashinfer into gpt-fast now? Thanks!

yzh119 commented 1 month ago

Hi, I don't see any difficulty in integrating flashinfer into gpt-fast. But we'd prefer a minimal example (e.g., within 1k LOC) of continuous batching (gpt-fast doesn't support batching, AFAIK); we are working with the @sgl-project team on that.

jianc99 commented 1 month ago

Wow, cool! Is there any example of using torch.compile with flashinfer's BatchPrefillWithPagedKVCacheWrapper or the decode wrapper? Thanks!

yzh119 commented 1 month ago

You can check how vLLM and sglang integrate flashinfer (they are good examples of how to use those wrapper functions); both of them use PyTorch's CUDA graph capturing.

I haven't tried to use torch.compile together with flashinfer (please let me know if you have related experience). The only possible issue is whether the torch compiler can deal with custom operators; we should follow PyTorch's custom operator manual for better compatibility in the future.

jianc99 commented 1 month ago

Thanks for your information!

I have checked how to define custom operators, and I can successfully register single_prefill_with_kv_cache as one, like below:

import torch
import flashinfer

# Declare the op schema: single prefill attention over contiguous K/V caches.
torch.library.define(
    "mylib::custom_func",
    "(Tensor q, Tensor k_cache, Tensor v_cache) -> Tensor",
)

# CUDA implementation: dispatch to flashinfer.
@torch.library.impl("mylib::custom_func", "cuda")
def custom_func(q, k_cache, v_cache):
    return flashinfer.single_prefill_with_kv_cache(
        q, k_cache, v_cache
    )

# Fake (meta) implementation so torch.compile can trace output shapes.
@torch.library.register_fake("mylib::custom_func")
def custom_func_abstract(q, k_cache, v_cache):
    return torch.empty_like(q)

with torch.device("cuda"):
    q = torch.randn((2, 2, 128), dtype=torch.bfloat16)
    k_cache = torch.randn((5, 2, 128), dtype=torch.bfloat16)
    v_cache = torch.randn((5, 2, 128), dtype=torch.bfloat16)

torch.compile(torch.ops.mylib.custom_func, fullgraph=True)(
    q, k_cache, v_cache
)

The problem is that when using BatchPrefillWithPagedKVCacheWrapper, we have to initialize the wrapper first and then perform the forward pass, which makes this kind of registration difficult for the forward function.

yzh119 commented 1 month ago

Thank you for the info. I think we can resolve this by making those wrappers pure Python objects; I'll refactor the codebase this weekend.

jianc99 commented 1 month ago

Thanks! I think I solved this problem by creating the wrapper before defining the custom operator and then reusing that wrapper (see the sketch below). But making the wrappers pure Python objects would be great and would make things easier!
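
A minimal sketch of that workaround: the wrapper is created once and the custom op closes over it, so only tensors cross the op boundary. The op name mylib::batch_prefill, the workspace size, and the exact wrapper method names are illustrative and may differ across flashinfer versions.

import torch
import flashinfer

# Create the wrapper once, up front, and let the custom op close over it.
# Workspace size and the "NHD" layout are illustrative defaults.
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace, "NHD")

torch.library.define(
    "mylib::batch_prefill",
    "(Tensor q, Tensor kv_cache) -> Tensor",
)

@torch.library.impl("mylib::batch_prefill", "cuda")
def batch_prefill(q, kv_cache):
    # Assumes prefill_wrapper.begin_forward(...) has already been called
    # with the current batch metadata, outside the compiled region.
    return prefill_wrapper.forward(q, kv_cache)

@torch.library.register_fake("mylib::batch_prefill")
def batch_prefill_abstract(q, kv_cache):
    return torch.empty_like(q)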

Btw, the Llama 3.1 family has a different positional encoding (RoPE scaling with two factors) compared with Llama 2 and Llama 3. Can flashinfer support Llama 3.1 in the next release? Thanks!
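
For context, the "two factors" refer to the piecewise RoPE frequency scaling used by Llama 3.1. A minimal sketch of that scaling (not flashinfer's API; the parameter defaults below are the reference values published with Llama 3.1):

import math
import torch

def llama31_scale_inv_freq(
    inv_freq,
    factor=8.0,
    low_freq_factor=1.0,
    high_freq_factor=4.0,
    original_max_position_embeddings=8192,
):
    # Wavelength of each rotary frequency.
    wavelen = 2 * math.pi / inv_freq
    low_freq_wavelen = original_max_position_embeddings / low_freq_factor
    high_freq_wavelen = original_max_position_embeddings / high_freq_factor

    # Long wavelengths (low frequencies) are divided by `factor`,
    # short wavelengths are kept as-is, and the band in between is
    # smoothly interpolated between the two regimes.
    smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )
    scaled = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
    smoothed = (1 - smooth) * inv_freq / factor + smooth * inv_freq
    is_medium = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)
    return torch.where(is_medium, smoothed, scaled)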

yzh119 commented 1 month ago

by creating the wrapper before defining custom operator

Ideally, the begin_forward and forward functions should be registered as custom operators as well, so it's preferable to make the wrappers pure Python objects so that we can ensure all of their arguments are torch tensors.

different positional encoding

Yes it's easy to support, just stay tuned.

leng-yue commented 1 month ago

Cool!

Ying1123 commented 1 month ago

@jianc99 flashinfer + torch.compile is supported in sglang and it is very fast.

You can try

python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B --enable-torch-compile

It is faster than the original gpt-fast and TensorRT-LLM, and much faster than vLLM. It also supports all online serving features, such as dynamic batching and prefix caching.