yz-tang opened 1 week ago
It's easy to add support for fp16:
In https://github.com/flashinfer-ai/flashinfer/blob/d81af9775e56bb30152b17770e804823cddfc279/python/csrc/sampling.cu#L39-L40 (and all the other functions in this file), all inputs are currently cast to fp32: https://github.com/flashinfer-ai/flashinfer/blob/d81af9775e56bb30152b17770e804823cddfc279/python/csrc/sampling.cu#L33-L40. To use fp16 kernels, we just need to dispatch on the input data type using the dispatch macro (https://github.com/flashinfer-ai/flashinfer/blob/d81af9775e56bb30152b17770e804823cddfc279/python/csrc/pytorch_extension_utils.h#L26).
But as you mentioned, fp16 might fail in some extreme cases, because the fp16 probabilities might no longer sum up to 1.
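This concern is easy to demonstrate with NumPy (a standalone illustration, not flashinfer code): a uniform 3-way distribution sums to exactly 1.0 in fp32, but after casting to fp16 each entry rounds down and the sum drifts off 1:

```python
import numpy as np

# A uniform distribution over 3 outcomes, normalized in fp32.
p32 = np.full(3, 1.0, dtype=np.float32) / 3
print(p32.sum(dtype=np.float32))  # exactly 1.0 in fp32

# Cast to fp16: each 1/3 rounds to 0.333251953125.
p16 = p32.astype(np.float16)

# Sum the fp16 values exactly (in float64) to expose the drift.
total = float(p16.astype(np.float64).sum())
print(total)  # 0.999755859375, no longer exactly 1
```

A sampling kernel that assumes the probabilities sum to exactly 1 (e.g. when comparing a uniform random draw against the running prefix sum) can therefore misbehave on fp16 inputs unless it renormalizes or clamps.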
I found test_sampling.cu, but it only contains FP32 tests. I tried FP16 and it does not work.