flashinfer-ai / flashinfer

FlashInfer: Kernel Library for LLM Serving
https://flashinfer.ai
Apache License 2.0
1.1k stars · 98 forks

Does flashinfer support float datatype? #191

Closed. ZSL98 closed this issue 1 week ago.

ZSL98 commented 5 months ago

The examples all use tensors of half() type. I wonder whether flashinfer supports the fp32 dtype?

chenzhuofu commented 3 months ago

I have the same question. I was instantiating the SinglePrefillWithKVCacheDispatched function, but found that it has a static_assert(sizeof(DTypeIn) == 2); check. @yzh119 Is this due to some implementation consideration?

yzh119 commented 3 months ago

The decode attention operators support fp32; we just need to add fp32 to this macro: https://github.com/flashinfer-ai/flashinfer/blob/5a38066f171a6c6932fe73693f161e14614e1eea/python/csrc/pytorch_extension_utils.h#L36-L51
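For concreteness, here is a minimal sketch of what that extension might look like. The macro and type names below are illustrative and do not match pytorch_extension_utils.h exactly; the point is the shape of the change, i.e. adding an at::ScalarType::Float case to the dtype dispatch switch:

```cpp
// Illustrative sketch only: names do not match the real macro in
// python/csrc/pytorch_extension_utils.h. It maps a PyTorch dtype to a C++
// type and invokes the wrapped lambda, with an fp32 case added at the end.
#include <torch/extension.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>

#define DISPATCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...)        \
  [&]() -> bool {                                                  \
    switch (pytorch_dtype) {                                       \
      case at::ScalarType::Half: {                                 \
        using c_type = __half;                                     \
        return __VA_ARGS__();                                      \
      }                                                            \
      case at::ScalarType::BFloat16: {                             \
        using c_type = __nv_bfloat16;                              \
        return __VA_ARGS__();                                      \
      }                                                            \
      case at::ScalarType::Float: { /* new: enables fp32 decode */ \
        using c_type = float;                                      \
        return __VA_ARGS__();                                      \
      }                                                            \
      default:                                                     \
        return false;                                              \
    }                                                              \
  }()
```

With a case like that, a decode kernel launch wrapped in the dispatch switch would receive float as its C++ element type whenever fp32 tensors are passed in from Python.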

For prefill/append attention it's a bit trickier, because instructions such as ldmatrix (https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-ldmatrix) only support 16-bit elements, which makes it non-trivial to load fp32 tiles (especially the transposed loads) from shared memory into registers. One option is to convert fp32 input to bf16 and use the bf16 prefill attention kernels; we could design an API in flashinfer that accepts bf16/fp16 input and returns fp32 output.
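A rough sketch of what such a wrapper could look like, under stated assumptions: run_bf16_prefill below is a hypothetical stand-in for an existing 16-bit prefill entry point (for example a bf16 instantiation of SinglePrefillWithKVCacheDispatched), and its signature is invented for illustration:

```cpp
// Sketch under assumptions: `run_bf16_prefill` is a placeholder callable, not
// a real flashinfer API; only the fp32-in / fp32-out casting pattern is shown.
#include <torch/torch.h>

template <typename Bf16PrefillFn>
torch::Tensor PrefillWithFp32IO(const torch::Tensor& q, const torch::Tensor& k,
                                const torch::Tensor& v,
                                Bf16PrefillFn run_bf16_prefill) {
  // Downcast fp32 Q/K/V to bf16 so the ldmatrix-based 16-bit kernel can load
  // the tiles from shared memory into registers.
  auto q16 = q.to(torch::kBFloat16);
  auto k16 = k.to(torch::kBFloat16);
  auto v16 = v.to(torch::kBFloat16);
  // The attention math accumulates in fp32 internally; upcast the kernel's
  // 16-bit output so the caller sees fp32 end to end.
  return run_bf16_prefill(q16, k16, v16).to(torch::kFloat32);
}
```

Accuracy would still be bounded by the bf16 round-trip on the inputs, but the output dtype would match what fp32 callers expect.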

chenzhuofu commented 3 months ago


Got it, my use case is the prefill/append kernel and it does look tricky. Thanks for your kind reply. I think fp32 output support sounds great and would be helpful!