flashinfer-ai / flashinfer

FlashInfer: Kernel Library for LLM Serving
https://flashinfer.ai
Apache License 2.0

Make flashinfer kernels cuda graphs friendly #187

Closed AgrawalAmey closed 2 months ago

AgrawalAmey commented 5 months ago

Thanks for creating these awesome kernels! I am trying to get flashinfer kernels to work with CUDA graphs, but it appears that several parallelism decisions (block size, num_q_tiles, etc.) are made on the fly in the forward function based on the input data. This makes it difficult to capture flashinfer kernels in CUDA graphs in a generic manner. I think one solution would be to introduce a launcher kernel that factors in the input metadata and launches the actual CUDA kernel using dynamic parallelism. Towards that, these are the items I have identified:

1. BatchPrefillWithPagedKVCachePyTorchWrapper::Forward -- handle return lse?
2. BatchPrefillWithPagedKVCachePyTorchWrapper::Forward -- paged_kv_t batch_size should not be on cpu side
3. BatchPrefillWithPagedKVCacheWrapperDispatched -- make cuda device function or get rid of it
4. BatchPrefillWithPagedKVCacheWrapperDispatched -- num_frags_x, num_qo_tiles, batch size need to be 
5. BatchPrefillWithPagedKVCacheWrapperDispatched -- do not access handler state directly in the function
6. BatchPrefillWithPagedKVCacheDispatched -- make cuda device function
7. BatchPrefillWithPagedKVCacheDispatched -- put num_qo_tiles on device accessible memory
8. BatchPrefillWithPagedKVCacheDispatched -- Make validations gpu friendly
9. Batch size should be an explicit input parameter, not derived from the length of indptr, so that inputs can be padded.

@yzh119, please let me know what would be the best way to proceed.

yzh119 commented 5 months ago

Hi @AgrawalAmey, thanks for bringing this up. I have some ideas about the CUDA graph integration with flashinfer:

The kernels to be executed can be determined before a decode/prefill step (for all layers) by analyzing the shapes. We can compile the CUDA graphs for all possible combinations (not too many) ahead of time and dispatch to one of them according to the shapes (see the sketch below).
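
To make the idea concrete, here is a minimal, hypothetical sketch (the kernel name, shape buckets, and buffer sizes are mine, not flashinfer's): one graph is captured and instantiated per batch-size bucket ahead of time, and at serving time the matching graph is replayed.

#include <cuda_runtime.h>
#include <map>

// Hypothetical kernel standing in for one decode/prefill step.
__global__ void stepKernel(float* buf) {
    buf[blockIdx.x * blockDim.x + threadIdx.x] += 1.0f;
}

int main() {
    cudaStream_t stream;
    cudaStreamCreate(&stream);

    float* buf;
    cudaMalloc(&buf, 4096 * sizeof(float));

    // Capture and instantiate one graph per shape bucket ahead of time.
    std::map<int, cudaGraphExec_t> graph_per_batch_size;
    for (int bs : {1, 8, 32}) {
        cudaGraph_t graph;
        cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
        stepKernel<<<bs, 128, 0, stream>>>(buf);  // launch config depends on the shape
        cudaStreamEndCapture(stream, &graph);
        cudaGraphExec_t exec;
        cudaGraphInstantiate(&exec, graph, NULL, NULL, 0);
        cudaGraphDestroy(graph);
        graph_per_batch_size[bs] = exec;
    }

    // At serving time, dispatch to the pre-captured graph that matches the shape.
    int batch_size = 8;
    cudaGraphLaunch(graph_per_batch_size.at(batch_size), stream);
    cudaStreamSynchronize(stream);

    for (auto& kv : graph_per_batch_size) cudaGraphExecDestroy(kv.second);
    cudaFree(buf);
    cudaStreamDestroy(stream);
    return 0;
}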

Regarding dynamic parallelism:

introduce a launcher kernel that factors in the input metadata and launches the actual CUDA kernel using dynamic parallelism

It sounds tricky to me because the required shared memory size and grid size vary across different schedules.

AgrawalAmey commented 5 months ago

Hi @yzh119!

I have an implementation in sarathi-serve which tries to enumerate the different combinations and capture them. But with increasing batch sizes and large variance in input sequence lengths, the number of possibilities seems to explode. On top of that, prefill and decode requests clubbed together make it even more challenging. The memory cost of CUDA graphs becomes too high as the number of combinations increases.

The child kernel/dynamic parallelism proposal aims to solve the challenge of varying grid sizes, etc. Essentially, the launcher kernel would be triggered with a single warp. Inside the launcher kernel, we can determine all the launch params and launch the actual attention kernel.

AgrawalAmey commented 5 months ago

A sample program to explain what I mean:

#include <cuda_runtime.h>
#include <iostream>
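
// Dynamic parallelism (the device-side kernel launches below) requires building
// with relocatable device code, e.g. `nvcc -rdc=true -lcudadevrt`.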

__global__ void subKernel(int *data) {
    printf("Data before sub kernel: %d\n", *data);
    (*data) -= 1;
}

__global__ void addKernel(int *data) {
    printf("Data before add kernel: %d\n", *data);
    (*data) += 1;
}

struct UserData {
    int data;
    bool op;
};

__global__ void launchChildKernelFromDevice(void *_userData) {
    UserData *userData = (UserData *)_userData;
    bool op = userData->op;

    // Pick the child kernel on the device and launch it via dynamic parallelism.
    if (op) {
        addKernel<<<1, 1>>>(&userData->data);
    } else {
        subKernel<<<1, 1>>>(&userData->data);
    }
}

int main() {
    cudaStream_t stream;
    cudaStreamCreate(&stream);

    UserData *userData;
    cudaMallocHost(&userData, sizeof(UserData));

    userData->data = 10;
    userData->op = true;

    // Run the launcher once outside the graph as a sanity check (op == true, so addKernel runs)

    cudaStreamSynchronize(stream);
    std::cout << "Data before kernel: " << userData->data << std::endl;
    launchChildKernelFromDevice<<<1, 1, 0, stream>>>(userData);
    cudaStreamSynchronize(stream);
    std::cout << "Data after kernel: " << userData->data << std::endl;

    cudaGraph_t graph;
    cudaGraphExec_t instance;

    // Begin graph capture
    cudaStreamSynchronize(stream);
    cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);

    // Capture the launcher kernel; the child kernel is launched from the device
    launchChildKernelFromDevice<<<1, 1, 0, stream>>>(userData);

    // End graph capture
    cudaStreamEndCapture(stream, &graph);
    cudaGraphInstantiate(&instance, graph, NULL, NULL, 0);

    cudaStreamSynchronize(stream);

    printf("Data after graph: %d\n", userData->data);

    // Run the graph
    cudaGraphLaunch(instance, stream);
    cudaStreamSynchronize(stream);

    printf("Data after graph replay: %d\n", userData->data);

    userData->op = false;
    cudaGraphLaunch(instance, stream);
    cudaStreamSynchronize(stream);

    printf("Data after graph replay with different op: %d\n", userData->data);

    cudaGraphExecDestroy(instance);
    cudaGraphDestroy(graph);
    cudaStreamDestroy(stream);
    cudaFreeHost(userData);  // allocated with cudaMallocHost

    return 0;
}

yzh119 commented 5 months ago

Thanks for your explanation, that sounds reasonable.

To proceed, I'd like to write some documentation of our dispatching rules and see whether we can describe them with dynamic parallelism. Before that, I have to finish #75 because it will affect our dispatching strategy.

I'd be glad to follow up next week, and we can schedule a meeting on Zoom (you can drop me an email at zhye@cs.washington.edu).

AgrawalAmey commented 5 months ago

Yes, that would be great. I will send out a when2meet link over email, thank you!

ZSL98 commented 5 months ago

Hi, @AgrawalAmey, will your sarathi or sarathi-serve be open-sourced?

AgrawalAmey commented 5 months ago

Hey @ZSL98, we are working with the vLLM team to get Sarathi-Serve scheduler support inside vLLM.

yzh119 commented 2 months ago

The CUDA graph compatibility issue was resolved in https://github.com/flashinfer-ai/flashinfer/releases/tag/v0.0.5

The current strategy is:

  1. If the batch size is large enough, we don't use split-k and the kernel can be traced properly (we need to reuse the indptr/indices/last_page_len buffers; this is handled properly in the PyTorch wrapper APIs).
  2. If the batch size is small, we use split-k, but the grid size must be fixed. We set the grid size to a fixed value computed from the number of SMs on the GPU, and a kernel argument block_valid_mask determines whether a given threadblock should skip its computation, which gives us runtime dynamism (see the sketch after this list).
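
To illustrate item 2, here is a minimal sketch of the fixed-grid plus block-valid-mask pattern (the kernel and buffer names are hypothetical, not flashinfer's actual implementation): the grid size is fixed from the SM count at capture time, and each threadblock checks a mask in global memory at replay time to decide whether it has real work to do.

#include <cuda_runtime.h>

// Hypothetical kernel: the grid size is fixed at capture time; the mask decides
// at replay time which threadblocks actually do work.
__global__ void maskedWorkKernel(const bool* block_valid_mask, float* out) {
    if (!block_valid_mask[blockIdx.x]) return;  // padded block: skip computation
    out[blockIdx.x * blockDim.x + threadIdx.x] += 1.0f;  // stand-in for real work
}

int main() {
    int num_sms = 0;
    cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, 0);
    const int fixed_grid = num_sms;  // grid size derived from the SM count, not the inputs
    const int threads = 128;

    bool* mask;
    float* out;
    cudaMalloc(&mask, fixed_grid * sizeof(bool));
    cudaMalloc(&out, fixed_grid * threads * sizeof(float));
    cudaMemset(out, 0, fixed_grid * threads * sizeof(float));
    cudaMemset(mask, 1, fixed_grid * sizeof(bool));  // mark every block valid for this step

    // The launch configuration never changes between replays; only the mask
    // contents (and other device-side buffers) are updated.
    maskedWorkKernel<<<fixed_grid, threads>>>(mask, out);
    cudaDeviceSynchronize();

    cudaFree(mask);
    cudaFree(out);
    return 0;
}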

AgrawalAmey commented 2 months ago

@yzh119 thanks a lot for all the amazing work! I wanted to understand how split-k behaves when the sequence length differs significantly between capture and replay time. For instance, if during capture we have a sequence length of 1k and during replay a sequence of length 100k, would the parallelization parameters be applied appropriately?

yzh119 commented 2 months ago

Yes, they will be properly handled.

When cudagraph is enabled, we decide whether to use split-k based only on the batch size (for decode) and query lengths (for append), not on the kv-cache length, so it's safe to capture when the kv-cache length is small (we have test cases that capture with a small kv length and replay with a long one: https://github.com/flashinfer-ai/flashinfer/blob/231b1dc89cd5cdaa485ac3c701cbd93ab9c05a90/python/tests/test_batch_decode_kernels.py#L136-L286). Once the batch size/query lengths are determined, the kernel grid size is fixed, and we use the block valid mask for dynamic parallelism.

yzh119 commented 2 months ago

There is one tricky part about the prefill kernels: we pass kv_chunk_size as an input argument, which doesn't work in CUDA graph mode because input arguments are fixed at capture time. So we pass a pointer to a global memory address that stores this value instead: https://github.com/flashinfer-ai/flashinfer/blob/231b1dc89cd5cdaa485ac3c701cbd93ab9c05a90/include/flashinfer/attention/prefill.cuh#L1362

and we update its value in the BeginForward functions during the generation process.
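
As a hedged illustration of that pattern (the kernel and variable names are hypothetical, not the actual prefill.cuh code): the kernel reads the chunk size through a device pointer, so the host can overwrite the stored value between replays of the same captured graph.

#include <cuda_runtime.h>
#include <cstdio>

// Hypothetical kernel: reads kv_chunk_size from global memory at run time instead
// of taking it as a by-value argument that would be frozen into the captured graph.
__global__ void chunkedKernel(const int* kv_chunk_size_ptr) {
    int kv_chunk_size = *kv_chunk_size_ptr;
    if (blockIdx.x == 0 && threadIdx.x == 0) {
        printf("kv_chunk_size seen by kernel: %d\n", kv_chunk_size);
    }
    // ... the chunked attention computation would use kv_chunk_size here ...
}

int main() {
    cudaStream_t stream;
    cudaStreamCreate(&stream);

    int* kv_chunk_size;
    cudaMalloc(&kv_chunk_size, sizeof(int));
    int host_val = 256;
    cudaMemcpy(kv_chunk_size, &host_val, sizeof(int), cudaMemcpyHostToDevice);

    // Capture a graph whose kernel argument is the *pointer*, not the value.
    cudaGraph_t graph;
    cudaGraphExec_t instance;
    cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
    chunkedKernel<<<1, 32, 0, stream>>>(kv_chunk_size);
    cudaStreamEndCapture(stream, &graph);
    cudaGraphInstantiate(&instance, graph, NULL, NULL, 0);

    cudaGraphLaunch(instance, stream);
    cudaStreamSynchronize(stream);

    // Rewrite the device-side value (as a BeginForward-style call would) and
    // replay the same graph -- the kernel now sees the new chunk size.
    host_val = 4096;
    cudaMemcpy(kv_chunk_size, &host_val, sizeof(int), cudaMemcpyHostToDevice);
    cudaGraphLaunch(instance, stream);
    cudaStreamSynchronize(stream);

    cudaGraphExecDestroy(instance);
    cudaGraphDestroy(graph);
    cudaFree(kv_chunk_size);
    cudaStreamDestroy(stream);
    return 0;
}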

AgrawalAmey commented 2 months ago

This is great! Thanks a lot for the in-depth description. I will go ahead and add cuda graph support in sarathi-serve based on this.