flashinfer-ai / flashinfer

FlashInfer: Kernel Library for LLM Serving
https://flashinfer.ai
Apache License 2.0

[Question] Sampling kernel only support FP32 now? #531

Open yz-tang opened 1 week ago

yz-tang commented 1 week ago

I found that test_sampling.cu only tests FP32. I tried using FP16, but it does not work.

yzh119 commented 6 days ago

It's easy to add support for fp16:

In https://github.com/flashinfer-ai/flashinfer/blob/d81af9775e56bb30152b17770e804823cddfc279/python/csrc/sampling.cu#L39-L40 (and all the other functions in this file), we cast all inputs to fp32: https://github.com/flashinfer-ai/flashinfer/blob/d81af9775e56bb30152b17770e804823cddfc279/python/csrc/sampling.cu#L33-L40. To use fp16 kernels, we just need to dispatch on the input data type with the dispatch macro (https://github.com/flashinfer-ai/flashinfer/blob/d81af9775e56bb30152b17770e804823cddfc279/python/csrc/pytorch_extension_utils.h#L26).

But as you mentioned, fp16 might fail in some extreme cases because the fp16 probabilities might no longer sum to 1.