flashinfer-ai / flashinfer

FlashInfer: Kernel Library for LLM Serving
https://flashinfer.ai
Apache License 2.0

misc: fused kernel for sampling and normalization functions #207

Closed: yzh119 closed this 4 months ago

yzh119 commented 4 months ago

This PR implements fused kernels for sampling functions (categorical sampling, top-k sampling, and top-p sampling) and for normalization functions (RMSNorm).
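For reference, this is the standard per-row RMSNorm computation that a fused kernel would perform in a single launch (a sum-of-squares reduction followed by scaling). It is a minimal pure-Python sketch, not the kernel from this PR. The name `rms_norm` and the default `eps=1e-6` are illustrative assumptions.

```python
import math


def rms_norm(x, weight, eps=1e-6):
    """Reference RMSNorm for one row: y_i = x_i / sqrt(mean(x^2) + eps) * weight_i.

    Hypothetical helper for illustration; eps=1e-6 is an assumed default.
    """
    if len(x) != len(weight):
        raise ValueError("x and weight must have the same length")
    # Reduction step: mean of squares over the row.
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    # Elementwise step: normalize and apply the learned per-channel weight.
    return [v * inv_rms * w for v, w in zip(x, weight)]


y = rms_norm([3.0, 4.0], [1.0, 1.0], eps=0.0)
print(y)
```

With unit weights and `eps=0.0`, the output row has a mean square of 1. Fusing the reduction and the scaling into one kernel avoids writing the intermediate statistic to global memory between two launches.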

The top-k and top-p implementations use rejection sampling, so no GPU sort is needed. The current implementation does not keep the probabilities pinned in shared memory, which could be improved in the future.
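Below is a minimal CPU reference of one way rejection sampling can replace a sort. It assumes a pivot-based scheme and is our own sketch, not the CUDA kernel in this PR. Each round draws a token by inverse-CDF sampling over the tokens whose probability exceeds a pivot. A single reduction then checks whether that token belongs to the top-k or top-p set: it counts, or sums, the probabilities strictly larger than the sampled one. If the token is outside the set, every token inside the set is more likely than it, so the pivot is raised to its probability and the draw is repeated. Accepted draws therefore follow the renormalized top-k or top-p distribution. The function names (`sample_above_pivot`, `top_p_sample`, `top_k_sample`) are hypothetical.

```python
import random


def sample_above_pivot(probs, pivot, u):
    """Inverse-CDF draw over tokens with prob > pivot (renormalized); u in [0, 1)."""
    mass = sum(p for p in probs if p > pivot)
    target = u * mass
    cum, last = 0.0, -1
    for i, p in enumerate(probs):
        if p > pivot:
            cum += p
            last = i
            if cum > target:
                return i
    return last  # guard against floating-point round-off


def top_p_sample(probs, top_p, rng):
    """Sketch of sort-free top-p (nucleus) sampling via rejection; not FlashInfer's code."""
    pivot = 0.0
    while True:
        j = sample_above_pivot(probs, pivot, rng.random())
        # j is in the nucleus iff the mass of strictly more likely tokens is < top_p.
        mass_above = sum(p for p in probs if p > probs[j])
        if mass_above < top_p:
            return j
        # Rejected: all nucleus tokens have prob > probs[j], so raise the pivot.
        pivot = probs[j]


def top_k_sample(probs, k, rng):
    """Sketch of sort-free top-k sampling via rejection (ties may admit > k tokens)."""
    pivot = 0.0
    while True:
        j = sample_above_pivot(probs, pivot, rng.random())
        # Number of tokens strictly more likely than j; accept if fewer than k.
        rank = sum(1 for p in probs if p > probs[j])
        if rank < k:
            return j
        pivot = probs[j]


rng = random.Random(0)
probs = [0.5, 0.3, 0.15, 0.05]
print([top_p_sample(probs, 0.7, rng) for _ in range(10)])
print([top_k_sample(probs, 2, rng) for _ in range(10)])
```

The pivot strictly increases on every rejection and the most likely token is always accepted, so the loop terminates. Each round needs only reductions over the vocabulary (a sum, a count, and a prefix sum for the inverse CDF). These map naturally to parallel block-level primitives on a GPU, whereas a full sort does not.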