Hi @AgrawalAmey, thanks for bringing this up. I have some ideas about CUDA graph integration with flashinfer:
The kernels to be executed can be determined before a decode/prefill step (for all layers) by analyzing the shapes. We can compile a CUDA graph for each possible combination (not too many) ahead of time and dispatch to one of them according to the shapes.
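For illustration, a minimal sketch of that multi-graph approach (runDecodeStep, nextBucket, and the bucketing scheme are hypothetical stand-ins, not flashinfer APIs):
#include <cuda_runtime.h>
#include <map>

__global__ void dummyLayerKernel(float *buf) { buf[threadIdx.x] += 1.0f; }

// Hypothetical stand-in for the real per-step launch sequence (all layers).
void runDecodeStep(float *buf, cudaStream_t stream) {
  dummyLayerKernel<<<1, 32, 0, stream>>>(buf);
}

// Round a batch size up to the nearest captured bucket (powers of two here).
int nextBucket(int batchSize) {
  int b = 1;
  while (b < batchSize) b <<= 1;
  return b;
}

std::map<int, cudaGraphExec_t> graphs;  // one executable graph per bucket

void captureForBucket(int bucket, float *buf, cudaStream_t stream) {
  cudaGraph_t graph;
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  runDecodeStep(buf, stream);  // every kernel of one step is recorded
  cudaStreamEndCapture(stream, &graph);
  cudaGraphExec_t exec;
  cudaGraphInstantiate(&exec, graph, NULL, NULL, 0);
  cudaGraphDestroy(graph);
  graphs[bucket] = exec;
}

void replayStep(int batchSize, cudaStream_t stream) {
  cudaGraphLaunch(graphs[nextBucket(batchSize)], stream);  // dispatch by shape
}

int main() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  float *buf;
  cudaMalloc(&buf, 32 * sizeof(float));
  for (int b = 1; b <= 8; b <<= 1) captureForBucket(b, buf, stream);
  replayStep(5, stream);  // rounds up to the bucket-8 graph
  cudaStreamSynchronize(stream);
  cudaFree(buf);
  cudaStreamDestroy(stream);
  return 0;
}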
Regarding dynamic parallelism:
introduce a launcher kernel which would factor in the input metadata and launch the actual CUDA kernel using dynamic parallelism
It sounds tricky to me because the required shared memory size/grid size varies for different schedules.
Hi @yzh119!
I have an implementation in sarathi-serve which tries to list the different combinations and capture them. But with increasing batch sizes and large variance in input sequence lengths, the number of possibilities seemed to explode. Prefill and decode requests clubbed together make it even more challenging, and the memory cost of CUDA graphs becomes too high as the number of combinations increases.
The child-kernel/dynamic-parallelism proposal is aimed at solving the challenge with differing grid sizes etc. Essentially, the launcher kernel is triggered with a single warp. Inside the launcher kernel, we can determine all the launch params and launch the actual attention kernel.
A sample program to explain what I mean:
#include <cstdio>
#include <cuda_runtime.h>
#include <iostream>

// Dynamic parallelism requires relocatable device code: nvcc -rdc=true

__global__ void subKernel(int *data) {
  printf("Data before sub kernel: %d\n", *data);
  (*data) -= 1;
}

__global__ void addKernel(int *data) {
  printf("Data before add kernel: %d\n", *data);
  (*data) += 1;
}

struct UserData {
  int data;
  bool op;
};

// Launcher kernel: a single thread inspects the metadata and launches the
// actual kernel from the device. The branch is re-evaluated on every graph
// replay, so one captured graph can dispatch to different child kernels.
__global__ void launchChildKernelFromDevice(void *_userData) {
  UserData *userData = (UserData *)_userData;
  if (userData->op) {
    addKernel<<<1, 1>>>(&userData->data);
  } else {
    subKernel<<<1, 1>>>(&userData->data);
  }
}

int main() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // Pinned host memory, accessible from both host and device.
  UserData *userData;
  cudaMallocHost(&userData, sizeof(UserData));
  userData->data = 10;
  userData->op = true;

  // Run the launcher kernel once outside a graph as a sanity check.
  cudaStreamSynchronize(stream);
  std::cout << "Data before kernel: " << userData->data << std::endl;
  launchChildKernelFromDevice<<<1, 1, 0, stream>>>(userData);
  cudaStreamSynchronize(stream);
  std::cout << "Data after kernel: " << userData->data << std::endl;

  cudaGraph_t graph;
  cudaGraphExec_t instance;

  // Begin graph capture: only the launcher kernel is recorded; the child
  // launch happens on the device at replay time.
  cudaStreamSynchronize(stream);
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  launchChildKernelFromDevice<<<1, 1, 0, stream>>>(userData);
  cudaStreamEndCapture(stream, &graph);
  cudaGraphInstantiate(&instance, graph, NULL, NULL, 0);
  cudaStreamSynchronize(stream);
  printf("Data after graph: %d\n", userData->data);

  // Replay the captured graph.
  cudaGraphLaunch(instance, stream);
  cudaStreamSynchronize(stream);
  printf("Data after graph replay: %d\n", userData->data);

  // Flip the op and replay again: the device-side branch now picks subKernel.
  userData->op = false;
  cudaGraphLaunch(instance, stream);
  cudaStreamSynchronize(stream);
  printf("Data after graph replay with different op: %d\n", userData->data);

  cudaGraphExecDestroy(instance);
  cudaGraphDestroy(graph);
  cudaStreamDestroy(stream);
  cudaFreeHost(userData);  // matches cudaMallocHost
  return 0;
}
Thanks for your explanation, that sounds reasonable.
To proceed, I'd love to write some documentation on our dispatching rules and see if we can describe them with dynamic parallelism. Before that I have to finish #75 because it will affect our dispatching strategy.
I'll be glad to follow up next week and we can schedule a meeting on Zoom (you can drop me an email at zhye@cs.washington.edu).
Yes, that would be great, I will send out a when2meet link over email, thank you!
Hi, @AgrawalAmey, will your sarathi or sarathi-serve be open-sourced?
Hey @ZSL98, we are working with the vLLM team to get Sarathi-Serve scheduler support inside vLLM
The CUDA graph compatibility was resolved in https://github.com/flashinfer-ai/flashinfer/releases/tag/v0.0.5
The current strategy is to use a block_valid_mask to determine whether to skip a threadblock's computation, achieving runtime dynamism.

@yzh119 thanks a lot for all the amazing work! I wanted to understand how split-k behaves when the sequence length is significantly different between capture and replay time. For instance, if during capture we have a sequence length of 1k and during replay we have a sequence of length 100k, would the parallelization parameters get applied appropriately?
Yes, they will be properly handled.
When CUDA graph is enabled, we decide whether to split-k based only on the batch size (for decode) and query lengths (for append), not on the kv-cache length, so it's safe to capture when the kv-cache length is small (we have test cases that capture with a short kv length and replay with a long one: https://github.com/flashinfer-ai/flashinfer/blob/231b1dc89cd5cdaa485ac3c701cbd93ab9c05a90/python/tests/test_batch_decode_kernels.py#L136-L286). Once the batch size/query lengths are determined, the kernel grid size is fixed, and we use the block valid mask for runtime dynamism.
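For illustration, a minimal sketch of the block-valid-mask pattern (names are hypothetical, not flashinfer's actual kernel): the grid is sized for the worst case at capture time, and each threadblock consults a mask in device memory at replay time.
#include <cuda_runtime.h>

// The grid is fixed at capture time for the maximum problem size; before each
// replay the mask is rewritten in device memory, and threadblocks with no
// work assigned exit immediately.
__global__ void maskedKernelStub(const bool *block_valid_mask, float *out) {
  if (!block_valid_mask[blockIdx.x]) return;  // padded block: skip all work
  out[blockIdx.x * blockDim.x + threadIdx.x] += 1.0f;  // stand-in for attention work
}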
There is one tricky part about the prefill kernels: we pass kv_chunk_size as an input argument, and that doesn't work in CUDA graph mode because input arguments are frozen at capture time. So we instead pass a pointer to a global memory address that stores this value: https://github.com/flashinfer-ai/flashinfer/blob/231b1dc89cd5cdaa485ac3c701cbd93ab9c05a90/include/flashinfer/attention/prefill.cuh#L1362 and we change its value in the BeginForward functions during the generation process.
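For illustration, a minimal sketch of that pointer-indirection trick (names are hypothetical): the kernel dereferences a device pointer rather than taking the value as a kernel argument frozen at capture time.
#include <cstdint>
#include <cuda_runtime.h>

// kv_chunk_size cannot be a plain kernel argument under CUDA graphs, because
// arguments are baked in at capture time. Reading it through a pointer lets a
// BeginForward-style host call rewrite the value between replays.
__global__ void prefillStub(const uint32_t *kv_chunk_size_ptr, float *out) {
  uint32_t kv_chunk_size = *kv_chunk_size_ptr;  // fresh value at every replay
  // ... derive this block's KV range from kv_chunk_size and do the work ...
  out[threadIdx.x] = static_cast<float>(kv_chunk_size);
}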
This is great! Thanks a lot for the in-depth description. I will go ahead and add cuda graph support in sarathi-serve based on this.
Thanks for creating these awesome kernels! I am trying to get flashinfer kernels to work with CUDA graphs. But it appears that several parallelism decisions (block size, num_q_tiles, etc.) are made on the fly based on the input data in the forward function. This makes it difficult to capture flashinfer kernels in CUDA graphs in a generic manner. I think one solution to the problem would be to introduce a launcher kernel which would factor in the input metadata and launch the actual CUDA kernel using dynamic parallelism. Towards that, following are the items I have identified --
@yzh119, please let me know what would be the best way to proceed.