flashinfer-ai / flashinfer

FlashInfer: Kernel Library for LLM Serving
https://flashinfer.ai
Apache License 2.0

How are prefill and decode kernels different? #323

Closed AgrawalAmey closed 1 week ago

AgrawalAmey commented 2 weeks ago

With the latest change to support kv-parallelism in the prefill kernel, is there still a need for separate prefill and decode kernels? I have been running some tests, and it looks like the prefill kernel is almost always faster at decode than the decode kernel.

# %%
import torch
import flashinfer

# %%
# profiling params
num_warmup_iters = 5
num_active_iters = 10

# model params (64 QO heads / 8 KV heads divided across 8 tensor-parallel workers)
num_qo_heads = 64 // 8
num_kv_heads = 8 // 8
head_dim = 128

# flashinfer params
page_size = 16
workspace_size = 128 * 1024 * 1024

# input params
seq_length = 1024
batch_size = 128

# %%
decode_workspace_buffer = torch.empty(workspace_size, dtype=torch.uint8, device="cuda")
decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
    decode_workspace_buffer, "NHD"
)

prefill_workspace_buffer = torch.empty(workspace_size, dtype=torch.uint8, device="cuda")
prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
    prefill_workspace_buffer, "NHD"
)

# %%
num_pages_per_seq = (seq_length + page_size - 1) // page_size
max_num_pages = num_pages_per_seq * batch_size
q = torch.randn(batch_size, num_qo_heads, head_dim).half().to("cuda")
kv_cache = torch.randn(
    max_num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.float16, device="cuda"
)

# %%
def run_decode_kernel():
    kv_page_indices = torch.arange(max_num_pages).int().to("cuda")
    kv_page_indptr = torch.arange(0, max_num_pages + 1, num_pages_per_seq).int().to("cuda")
    # 1 <= kv_last_page_len <= page_size (seq_length % page_size would be 0 here since the last page is full)
    last_page_len = (seq_length - 1) % page_size + 1
    kv_last_page_len = torch.ones(batch_size, dtype=torch.int32, device="cuda") * last_page_len

    # create auxiliary data structures for batch decode attention
    decode_wrapper.begin_forward(
        kv_page_indptr,
        kv_page_indices,
        kv_last_page_len,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        page_size,
        data_type=torch.float16
    )

    for _ in range(num_warmup_iters):
        decode_wrapper.forward(q, kv_cache)

    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()
    for _ in range(num_active_iters):
        out = decode_wrapper.forward(q, kv_cache)
    end_event.record()

    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)

    print(f"Elapsed time: {elapsed_time_ms / num_active_iters} ms")

    # clear auxiliary data structures
    decode_wrapper.end_forward()

    return out

# %%
decode_output = run_decode_kernel()

# %%
def run_prefill_kernel():
    qo_indptr = torch.arange(0, batch_size + 1, 1, dtype=torch.int32, device="cuda")
    kv_page_indices = torch.arange(max_num_pages).int().to("cuda")
    kv_page_indptr = torch.arange(0, max_num_pages + 1, num_pages_per_seq).int().to("cuda")
    # 1 <= kv_last_page_len <= page_size (seq_length % page_size would be 0 here since the last page is full)
    last_page_len = (seq_length - 1) % page_size + 1
    kv_last_page_len = torch.ones(batch_size, dtype=torch.int32, device="cuda") * last_page_len

    # create auxiliary data structures for batch prefill attention
    prefill_wrapper.begin_forward(
        qo_indptr,
        kv_page_indptr,
        kv_page_indices,
        kv_last_page_len,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        page_size,
    )

    for _ in range(num_warmup_iters):
        prefill_wrapper.forward(q, kv_cache)

    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()
    for _ in range(num_active_iters):
        out = prefill_wrapper.forward(q, kv_cache)
    end_event.record()

    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)

    print(f"Elapsed time: {elapsed_time_ms / num_active_iters} ms")

    # clear auxiliary data structures
    prefill_wrapper.end_forward()

    return out

# %%
prefill_output = run_prefill_kernel()

# %%
assert torch.allclose(decode_output, prefill_output, atol=1e-3, rtol=1e-3)
yzh119 commented 1 week ago

Prefill kernels use tensor cores while decode kernels use CUDA cores; that's the only difference.

The prefill kernels use more registers and shared memory than the decode kernels, so the number of pipeline stages is smaller than for the decode kernels, and there is some extra per-iteration overhead of loading the query from shared memory into registers (pinning the query in registers for small query lengths is an optimization I should do but unfortunately haven't done yet).

On the other hand, tensor cores have higher throughput, and GQA has higher operational intensity than MHA, so using tensor cores can be beneficial in some cases despite the overheads mentioned above. So it's case by case.
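
For intuition, here is a rough back-of-the-envelope sketch of the operational-intensity point (the formula and numbers are my own simplification, not taken from the library): in decode, every K/V element read from memory is reused by all query heads in its group, so FLOPs per byte grow roughly with the group size.

# Rough arithmetic-intensity estimate for single-token (decode) attention.
# Illustrative sketch only: ignores softmax, masking, and cache-line effects.
def decode_arithmetic_intensity(num_qo_heads, num_kv_heads, head_dim, kv_len, bytes_per_elem=2):
    # FLOPs: QK^T and PV, each ~2 * head_dim FLOPs per (query head, kv position)
    flops = 2 * 2 * num_qo_heads * kv_len * head_dim
    # Bytes: dominated by reading K and V once per kv head (fp16 -> 2 bytes/element)
    bytes_moved = 2 * num_kv_heads * kv_len * head_dim * bytes_per_elem
    return flops / bytes_moved

print(decode_arithmetic_intensity(64, 64, 128, 1024))  # MHA: ~1 FLOP/byte
print(decode_arithmetic_intensity(8, 1, 128, 1024))    # GQA, group size 8 as in the benchmark above: ~8 FLOP/byte

With the group size of 8 used in the benchmark, decode attention does roughly 8x the work per byte of KV cache read compared with MHA, which is the regime where routing decode through the tensor-core prefill kernel can pay off.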

AgrawalAmey commented 1 week ago

I see, thanks a lot for the detailed comment.

yzh119 commented 1 week ago

A gentle reminder: v0.0.5 has some silly bugs in split-k and may result in unstable performance measurements; please use v0.0.6 instead: https://github.com/flashinfer-ai/flashinfer/releases/tag/v0.0.6
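
One quick way to confirm which version is installed (assuming the package was installed via pip under the distribution name flashinfer; adjust the name if yours differs):

from importlib.metadata import version
print(version("flashinfer"))  # expect 0.0.6 or later; distribution name is an assumption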