As requested in #187, this PR adds initial CUDAGraph support to flashinfer's batch decode attention kernels. It is the first step toward full CUDAGraph support; CUDAGraph-compatible prefill operators will follow in later PRs.
Proposed APIs
We add a new wrapper, CUDAGraphBatchDecodeWithPagedKVCacheWrapper. To initialize it, users must pre-allocate the buffers for the page data structures and pass them in. Once initialized, these buffers stay pinned on the GPU for the lifetime of the wrapper.
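To make the buffer requirement concrete, here is a minimal pure-Python sketch, assuming a CSR-style page table made of indptr, indices and last_page_len arrays. HypotheticalDecodeBuffers and its update method are illustrative names, not flashinfer's API, and array.array stands in for GPU memory. The buffers are allocated once at maximum capacity and afterwards only overwritten in place, so their address (as reported by buffer_info()) never changes.

```python
from array import array
from typing import Sequence


class HypotheticalDecodeBuffers:
    """Pre-allocated page-table buffers (illustrative, not flashinfer's API).

    Buffers are sized for the largest batch up front and only ever
    overwritten in place, so their memory addresses stay fixed.
    """

    def __init__(self, max_batch_size: int, max_num_pages: int) -> None:
        # CSR-style offsets: request i owns indices[indptr[i]:indptr[i + 1]].
        self.indptr = array("i", [0] * (max_batch_size + 1))
        # Page ids of all requests, concatenated.
        self.indices = array("i", [0] * max_num_pages)
        # Number of valid tokens in each request's last page.
        self.last_page_len = array("i", [0] * max_batch_size)

    def update(
        self,
        indptr: Sequence[int],
        indices: Sequence[int],
        last_page_len: Sequence[int],
    ) -> None:
        """Copy a new batch's page table into the existing buffers."""
        pairs = (
            (self.indptr, indptr),
            (self.indices, indices),
            (self.last_page_len, last_page_len),
        )
        # Validate everything first so a failed update leaves buffers intact.
        if any(len(new) > len(buf) for buf, new in pairs):
            raise ValueError("batch exceeds the pre-allocated capacity")
        for buf, new in pairs:
            # Same-length slice assignment copies in place; no reallocation.
            # Entries past len(new) keep stale values from earlier batches.
            buf[: len(new)] = array("i", new)


bufs = HypotheticalDecodeBuffers(max_batch_size=4, max_num_pages=16)
addr_before = bufs.indptr.buffer_info()[0]
# Two requests: request 0 owns pages 0 and 1, request 1 owns page 2.
bufs.update(indptr=[0, 2, 3], indices=[0, 1, 2], last_page_len=[5, 7])
print(bufs.indptr.tolist())                         # [0, 2, 3, 0, 0]
print(bufs.indptr.buffer_info()[0] == addr_before)  # True
```

The design choice the sketch mirrors is "allocate for the worst case, then copy": resizing a buffer per batch would move it, while a fixed-capacity buffer can be refilled without changing its address.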
CUDAGraphBatchDecodeWithPagedKVCacheWrapper behaves slightly differently from BatchDecodeWithPagedKVCacheWrapper: in CUDAGraph mode it always runs a fixed set of kernels regardless of the input shape, whereas the original implementation dispatches to different kernels depending on the input shape.
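Why a fixed kernel set matters can be shown with a toy capture/replay model. The kernel names, the batch-size threshold and ToyGraph are all hypothetical and do not reflect flashinfer's real dispatch heuristics; the sketch only illustrates that a captured graph replays whichever kernel launched at capture time, so shape-dependent dispatch can silently diverge from what eager mode would run.

```python
from typing import Callable

# Hypothetical kernel names and threshold, for illustration only.
SPLIT_KV = "split_kv_kernel"
PLAIN = "plain_kernel"


def dynamic_dispatch(batch_size: int) -> str:
    """Eager mode: choose a kernel variant from the current input shape."""
    return SPLIT_KV if batch_size < 8 else PLAIN


def fixed_dispatch(batch_size: int) -> str:
    """CUDAGraph mode: always choose the same kernel variant."""
    return SPLIT_KV


class ToyGraph:
    """Toy capture/replay: capture records the kernel that launched, and
    replay re-issues that same launch whatever the new input is."""

    def __init__(self, dispatch: Callable[[int], str], capture_batch_size: int) -> None:
        self.dispatch = dispatch
        self.recorded_kernel = dispatch(capture_batch_size)

    def replay(self, batch_size: int) -> str:
        # batch_size is deliberately ignored: the launch was fixed at capture.
        return self.recorded_kernel

    def matches_eager(self, batch_size: int) -> bool:
        """Does replay run what eager mode would have run for this shape?"""
        return self.replay(batch_size) == self.dispatch(batch_size)


dyn = ToyGraph(dynamic_dispatch, capture_batch_size=4)
fix = ToyGraph(fixed_dispatch, capture_batch_size=4)
print(dyn.matches_eager(32))                                # False
print(all(fix.matches_eager(b) for b in (1, 4, 32, 128)))  # True
```

In this toy, making dispatch independent of the input shape is what keeps replay consistent with eager execution.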
This PR also fixes the addresses of all kernel input pointers, because a captured CUDAGraph replays its kernels with the pointer values recorded at capture time.
Examples
See test_cuda_graph_batch_decode_with_paged_kv_cache in the unit tests.
Note that begin_forward functions should not be captured, because some of the operators they run cannot be captured.
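A toy model of the intended call order, written in pure Python so it runs without a GPU. ToyStream, HypotheticalWrapper and CaptureError are hypothetical stand-ins. This PR does not say which operators inside begin_forward are non-capturable, so the sketch simply assumes they are host-synchronizing steps. In real PyTorch code the capture region would be a torch.cuda.graph(...) block around the forward call; the sketch also assumes begin_forward is re-run outside the graph before each replay to refresh the buffer contents.

```python
class CaptureError(RuntimeError):
    """Raised when a non-capturable operation runs during capture."""


class ToyStream:
    """Toy stream that records kernel launches while capturing."""

    def __init__(self) -> None:
        self.capturing = False
        self.recorded: list[str] = []

    def host_sync(self, what: str) -> None:
        # Assumption: host-synchronizing work is treated as non-capturable.
        if self.capturing:
            raise CaptureError(f"{what} cannot be captured")

    def launch(self, kernel: str) -> None:
        if self.capturing:
            self.recorded.append(kernel)


class HypotheticalWrapper:
    """Illustrative wrapper mirroring only the begin_forward/forward split."""

    def __init__(self, stream: ToyStream) -> None:
        self.stream = stream
        self.planned = False

    def begin_forward(self) -> None:
        # Host-side planning that synchronizes, so it must run outside capture.
        self.stream.host_sync("begin_forward")
        self.planned = True

    def forward(self) -> None:
        assert self.planned, "call begin_forward first"
        self.stream.launch("batch_decode")


stream = ToyStream()
wrapper = HypotheticalWrapper(stream)

wrapper.begin_forward()   # 1. plan outside capture
stream.capturing = True   # 2. capture only forward
wrapper.forward()
stream.capturing = False
print(stream.recorded)    # ['batch_decode']

stream.capturing = True   # Capturing begin_forward as well fails.
try:
    wrapper.begin_forward()
except CaptureError as err:
    print(err)            # begin_forward cannot be captured
finally:
    stream.capturing = False
```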
cc @AgrawalAmey @LiuXiaoxuanPKU @comaniac