Closed ZSL98 closed 1 week ago
I got the same question. I am instantiating the SinglePrefillWithKVCacheDispatched function, but found that it has a static_assert(sizeof(DTypeIn) == 2) check. @yzh119 Is this due to some implementation consideration?
The decode attention operators support fp32, we just need to add fp32 to this macro: https://github.com/flashinfer-ai/flashinfer/blob/5a38066f171a6c6932fe73693f161e14614e1eea/python/csrc/pytorch_extension_utils.h#L36-L51
For prefill/append attention, it's a little bit tricky: many instructions such as ldmatrix
(https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-ldmatrix) only support 16-bit types, which makes it non-trivial to load fp32 tiles (especially the transposed load) from shared memory to registers. One option is to convert the fp32 input to bf16 and use the bf16 prefill attention kernels; we can design an API in flashinfer that accepts bf16/fp16 input and returns fp32 output.
Got it, my use case is the prefill/append kernel and it does look tricky. Thanks for your kind reply. I think fp32 output support sounds great and would be very helpful!
The examples are all tensors of half() type. I wonder if flashinfer supports fp32 dtype?