flashinfer-ai / flashinfer

FlashInfer: Kernel Library for LLM Serving
https://flashinfer.ai
Apache License 2.0

perf: initial cuda graph support #256

Closed yzh119 closed 4 months ago

yzh119 commented 4 months ago

As requested in #187, this PR adds initial CUDAGraph compatibility to FlashInfer's batch decode attention kernels. This PR is the first step towards full CUDAGraph support; CUDAGraph-compatible prefill operators will be implemented in later PRs.

Proposed APIs

We add a new wrapper, CUDAGraphBatchDecodeWithPagedKVCacheWrapper. The user needs to pre-allocate the paged KV data structure buffers to initialize this wrapper class. Once the wrapper is initialized, these buffers stay pinned on the GPU for the lifetime of the wrapper.
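A minimal sketch of the intended initialization pattern (buffer sizes and constructor arguments below are illustrative assumptions, not the exact signature; see the unit test for the authoritative usage):

```python
import torch
import flashinfer

# Hypothetical upper bounds chosen for illustration only.
max_batch_size = 128   # largest batch size the captured graph will serve
max_num_pages = 4096   # largest total number of KV pages across the batch

# Buffers are allocated once up front and stay pinned on the GPU
# for the lifetime of the wrapper.
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
kv_indptr_buffer = torch.empty(max_batch_size + 1, dtype=torch.int32, device="cuda")
kv_indices_buffer = torch.empty(max_num_pages, dtype=torch.int32, device="cuda")
kv_last_page_len_buffer = torch.empty(max_batch_size, dtype=torch.int32, device="cuda")

# Assumed constructor shape: the wrapper receives the pre-allocated buffers
# so that kernel input pointer addresses stay fixed across CUDAGraph replays.
decode_wrapper = flashinfer.CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
    workspace_buffer,
    kv_indptr_buffer,
    kv_indices_buffer,
    kv_last_page_len_buffer,
    kv_layout="NHD",
)
```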

The behavior of CUDAGraphBatchDecodeWithPagedKVCacheWrapper differs slightly from BatchDecodeWithPagedKVCacheWrapper's: in CUDAGraph mode we always run a fixed set of kernels regardless of the input shape, whereas the original implementation dispatches to different kernels depending on the input shape.

This PR also fixes the addresses of all kernel input pointers to accommodate the constraints of CUDAGraph capture.

Examples

See test_cuda_graph_batch_decode_with_paged_kv_cache in the unit tests. The begin_forward functions should not be captured, as some of the operators they invoke are not allowed inside CUDAGraph capture.
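A rough sketch of the capture flow described above (argument names and call signatures are assumptions for illustration; the unit test is the authoritative reference):

```python
# begin_forward runs *outside* graph capture: it plans the kernel launch and
# fills the pinned buffers, using operations that CUDAGraph cannot capture.
decode_wrapper.begin_forward(
    kv_indptr, kv_indices, kv_last_page_len,
    num_qo_heads, num_kv_heads, head_dim, page_size,
)

# Capture only the decode attention kernel itself.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    out = decode_wrapper.forward(q, paged_kv_cache)

# On each decoding step, update q and the paged KV cache in place
# (their addresses are fixed), then replay the captured graph.
g.replay()

decode_wrapper.end_forward()
```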

cc @AgrawalAmey @LiuXiaoxuanPKU @comaniac

yzh119 commented 4 months ago

Let's merge this PR first, and then iterate on updating this feature.