flashinfer-ai / flashinfer

FlashInfer: Kernel Library for LLM Serving
https://flashinfer.ai
Apache License 2.0

bugfix: fix the potential issue of sampling kernels #226

Closed yzh119 closed 2 months ago

yzh119 commented 2 months ago

In #225 we switched to `BLOCK_SCAN_WARP_SCANS` to make the prefix-sum result monotonic; however, we found that there are still cases where `InclusiveSum` with the `BLOCK_SCAN_WARP_SCANS` algorithm does not return monotonic output.
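The root cause is that floating-point addition is not associative, so two scan implementations that combine partial sums in different orders can legitimately produce slightly different prefix sums; a minimal illustration (plain Python, not the CUDA kernel):

```python
# Floating-point addition is not associative: changing the combine order,
# as different block-scan algorithms do, changes the rounded result.
left = (0.1 + 0.2) + 0.3    # one combine order
right = 0.1 + (0.2 + 0.3)   # another combine order
assert left != right        # 0.6000000000000001 vs 0.6
```

Because of this, picking a particular scan algorithm alone cannot guarantee the exact bit pattern (or strict monotonicity) of the scanned CDF.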

In this PR, we fix the behavior in another way: we apply the prefix sum to the pair `(value, greater_than_0)` instead of `value` alone, and write `i` to the output only when `value[i] > u && value[i - 1] <= u` and `prob[i] > 0`. If multiple `i` satisfy this condition (because of floating-point numerical issues), we select the smallest one.
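A sequential sketch of the selection rule (Python, not the actual CUDA kernel; `sample_from_probs` and its signature are illustrative, not flashinfer's API):

```python
def sample_from_probs(probs, u):
    """Return the smallest index i whose inclusive prefix sum crosses u,
    requiring prob[i] > 0 so that zero-probability entries are never
    selected even if round-off makes the scan non-monotonic."""
    cumsum = 0.0
    for i, p in enumerate(probs):
        prev = cumsum       # exclusive prefix sum, i.e. value[i - 1]
        cumsum += p         # inclusive prefix sum, i.e. value[i]
        # write i only when the scan crosses u AND prob[i] > 0;
        # iterating in order picks the smallest qualifying i
        if cumsum > u and prev <= u and p > 0:
            return i
    return len(probs) - 1   # fallback for u at or above the total mass
```

For example, `sample_from_probs([0.1, 0.0, 0.3, 0.6], 0.2)` returns `2`, and the zero-probability entry at index 1 can never be returned, which is the guarantee this PR adds on top of the prefix-sum crossing test.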