flashinfer-ai / flashinfer

FlashInfer: Kernel Library for LLM Serving
https://flashinfer.ai
Apache License 2.0

bugfix: Fix the correctness issue of sampling kernel #225

Closed yzh119 closed 2 months ago

yzh119 commented 2 months ago

@abcdabcd987 found a test case that makes the sampling kernel fail. The cause is that we use the BLOCK_SCAN_RAKING_MEMOIZE algorithm for CUB's InclusiveScan API, which does not guarantee that the output is monotonically non-decreasing even when all inputs are greater than or equal to zero, and our sampling algorithm's correctness relies on the monotonicity of the prefix sums.

The loss of monotonicity comes from the nature of floating-point arithmetic: floating-point addition is not associative, so different aggregation orders can produce different results, and sum(prob[0..100]) computed in one order might come out greater than sum(prob[0..101]) computed in another.
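A quick Python illustration of this non-associativity (plain IEEE 754 doubles, not FlashInfer code — the values are chosen just to make rounding visible):

```python
# Floating-point addition is not associative: grouping changes the result.
a, b, c = 1e16, -1e16, 1.0

left_to_right = (a + b) + c   # a + b cancels exactly to 0.0, then + 1.0 -> 1.0
right_to_left = a + (b + c)   # b + c rounds back to -1e16 (the 1.0 is lost),
                              # so the final sum is 0.0

print(left_to_right, right_to_left)  # 1.0 0.0
```

A block-wide scan that combines per-thread partial sums in a data-dependent order (as raking scans do) is exactly this situation: two adjacent prefix-sum outputs may be aggregated via different groupings, so the later one can round below the earlier one.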

This PR switches to BLOCK_SCAN_WARP_SCANS, which is the only CUB algorithm we have observed to guarantee monotonic prefix-sum results.
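To sketch why the sampling step depends on monotonicity: inverse-CDF sampling kernels typically have each position check, in parallel, whether the random draw `u` falls into its own prefix-sum bucket. The Python below is an illustration of that scheme under assumed semantics, not FlashInfer's kernel; the function name and values are made up:

```python
def find_bucket_hits(prefix, u):
    """Return every index i whose bucket (prefix[i-1], prefix[i]] contains u.

    This mimics how each GPU thread independently tests its own bucket.
    With a monotone prefix sum exactly one bucket matches; a non-monotone
    prefix can match zero buckets or several.
    """
    return [i for i in range(len(prefix))
            if (prefix[i - 1] if i else 0.0) < u <= prefix[i]]

# Monotone prefix sums: u lands in exactly one bucket.
prefix_ok = [0.2, 0.5, 1.0]
print(find_bucket_hits(prefix_ok, 0.48))   # [1]

# A dip in the prefix sums (as a non-monotone scan can produce):
# the same u now matches two buckets, so two "threads" would claim it.
prefix_bad = [0.2, 0.5, 0.45, 1.0]
print(find_bucket_hits(prefix_bad, 0.48))  # [1, 3]
```

BLOCK_SCAN_WARP_SCANS avoids the dip in the first place by aggregating partials in a fixed left-to-right order across warps, so each output extends the previous one by a nonnegative term.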

Co-authored-by: Lequn Chen <lqchen@cs.washington.edu>