@abcdabcd987 found a test case that can make the sampling kernel fail. The cause is that we use the `BLOCK_SCAN_RAKING_MEMOIZE` algorithm for CUB's `InclusiveScan` API, which does not guarantee that
the output is monotonically non-decreasing when all inputs are greater than or equal to zero, while our sampling
algorithm's correctness relies on the monotonicity of the prefix sums.
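
The following minimal Python sketch (not the kernel's code) shows why monotonicity matters. Inverse-CDF sampling picks the first index whose prefix sum exceeds a uniform draw `u`. A parallel implementation can locate that index in ways that look equivalent, for example by finding the first entry greater than `u` or by counting the entries less than or equal to `u`. The helpers `sample_first_exceeding` and `sample_by_count` are hypothetical stand-ins for two such strategies. They agree whenever the CDF is non-decreasing, but once rounding makes one entry dip below its predecessor, the counting rule can select an element whose probability is zero.

```python
from itertools import accumulate


def sample_first_exceeding(cdf, u):
    """Return the first index whose prefix sum is strictly greater than u."""
    for i, c in enumerate(cdf):
        if c > u:
            return i
    return len(cdf) - 1  # u at or above the total mass: clamp to the last index


def sample_by_count(cdf, u):
    """Return the number of prefix sums <= u, clamped to a valid index.

    This equals sample_first_exceeding only when cdf is non-decreasing.
    """
    return min(sum(1 for c in cdf if c <= u), len(cdf) - 1)


probs = [0.1, 0.2, 0.3, 0.0, 0.4]  # element 3 has zero probability
u = 0.6

# Left-to-right prefix sums: non-decreasing.
monotonic_cdf = list(accumulate(probs))

# Same CDF, but entry 3 is computed with a different grouping of the first
# three terms, (0.1 + (0.2 + 0.3)) + 0.0, which rounds to 0.6 and so falls
# below entry 2, (0.1 + 0.2) + 0.3 == 0.6000000000000001.
broken_cdf = list(monotonic_cdf)
broken_cdf[3] = (probs[0] + (probs[1] + probs[2])) + probs[3]

print(sample_first_exceeding(monotonic_cdf, u), sample_by_count(monotonic_cdf, u))  # 2 2
print(sample_first_exceeding(broken_cdf, u), sample_by_count(broken_cdf, u))        # 2 3
```

With `u = 0.6`, both rules return index 2 on the left-to-right CDF; on the non-monotonic CDF the counting rule returns index 3, whose probability is `0.0`.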
The loss of monotonicity is most likely caused by floating-point rounding: floating-point addition is not associative, so aggregating the same values in a different order can produce a different result. As a consequence, `sum(prob[0..100])` may come out greater than `sum(prob[0..101])` if the two prefixes are aggregated in different ways.
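
A minimal Python illustration of this effect follows. Python floats are 64-bit, which may differ from the precision the kernel uses, and the two groupings below are picked by hand to mimic a scan that reduces different prefixes along different trees; they are not taken from CUB's implementation. Swapping the two operands of a single IEEE 754 addition never changes the result; it is the grouping that does.

```python
# All inputs are non-negative, so the exact prefix sums are non-decreasing.
probs = [0.1, 0.2, 0.3, 1e-17]

# Prefix over the first three elements, summed left to right.
s3 = (probs[0] + probs[1]) + probs[2]

# Prefix over the first four elements, but with the first three grouped
# differently, as a scan using a different reduction tree might do.
s4 = (probs[0] + (probs[1] + probs[2])) + probs[3]

print(s3)       # 0.6000000000000001
print(s4)       # 0.6
print(s3 > s4)  # True: the longer prefix is smaller
```

Even though the fourth input is positive, the four-element prefix ends up smaller than the three-element one, which is the `sum(prob[0..100]) > sum(prob[0..101])` situation described above.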
This PR switches to `BLOCK_SCAN_WARP_SCANS`, which is the only algorithm we observed that preserves the monotonicity of the prefix sum results.

Co-authored-by: Lequn Chen <lqchen@cs.washington.edu>