vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Help wanted] [Spec decode]: Increase acceptance rate via Medusa's typical acceptance #5015

Closed cadedaniel closed 3 months ago

cadedaniel commented 4 months ago

🚀 The feature, motivation and pitch

Speculative decoding allows emitting multiple tokens per sequence by speculating future tokens, scoring their likelihood with the LLM, and then accepting each speculative token based on its likelihood. This process is laid out in the following diagram: [diagram: propose → score → accept loop of speculative decoding]

The problem with rejection sampling is that it holds a very high bar for quality: it is lossless, guaranteeing that the output follows the target model's distribution, even if that means rejecting plausible speculative tokens.
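For context, here is a minimal sketch of the per-token acceptance rule that rejection sampling applies (illustrative names only, not vLLM's actual rejection sampler): a proposed token x is accepted with probability min(1, p_target(x) / p_draft(x)).

```python
import torch

def rejection_sample_accept_mask(
    target_probs: torch.Tensor,     # [k, vocab_size]: target-model probs at each speculative position
    draft_probs: torch.Tensor,      # [k, vocab_size]: draft-model probs at the same positions
    draft_token_ids: torch.Tensor,  # [k]: token ids proposed by the draft
) -> torch.Tensor:
    """Accept each proposed token x with probability min(1, p_target(x) / p_draft(x))."""
    idx = torch.arange(draft_token_ids.shape[0], device=draft_token_ids.device)
    p = target_probs[idx, draft_token_ids]  # target probability of each proposed token
    q = draft_probs[idx, draft_token_ids]   # draft probability of each proposed token
    accept_prob = torch.clamp(p / q, max=1.0)
    return torch.rand_like(accept_prob) < accept_prob
```

A full implementation keeps only the accepted prefix and, at the first rejection, resamples a recovery token from the normalized residual distribution; that is what makes the scheme lossless with respect to the target model.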

This issue is a request to implement Medusa's typical acceptance routine in vLLM. Typical acceptance trades off output quality for a higher acceptance rate. See "Choice of threshold in typical acceptance" in the Medusa blogpost for more information.
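For comparison, a rough sketch of the typical acceptance criterion as described in the Medusa blogpost (parameter names `posterior_threshold` and `posterior_alpha` mirror the Medusa reference code and are not a proposed vLLM API): a proposed token is accepted when its target-model probability exceeds a threshold that shrinks as the entropy of the target distribution grows, so more candidates pass when many continuations are plausible.

```python
import torch

def typical_acceptance_mask(
    target_probs: torch.Tensor,     # [k, vocab_size]: target-model probs at each speculative position
    draft_token_ids: torch.Tensor,  # [k]: token ids proposed by the speculator
    posterior_threshold: float,     # hard cap on the acceptance threshold
    posterior_alpha: float,         # scale for the entropy-dependent part of the threshold
) -> torch.Tensor:
    """Accept token x when p_target(x) > min(posterior_threshold, posterior_alpha * exp(-H(p_target)))."""
    idx = torch.arange(draft_token_ids.shape[0], device=draft_token_ids.device)
    candidate_probs = target_probs[idx, draft_token_ids]
    # Entropy of the target distribution at each position (small epsilon for numerical stability).
    entropy = -torch.sum(target_probs * torch.log(target_probs + 1e-10), dim=-1)
    threshold = torch.minimum(
        torch.full_like(entropy, posterior_threshold),
        posterior_alpha * torch.exp(-entropy),
    )
    return candidate_probs > threshold
```

Unlike rejection sampling there is no coin flip against the draft distribution, so the output is no longer guaranteed to match the target model; only the longest accepted prefix of the proposal would be kept.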

vLLM users should be able to toggle between different acceptance routines; they can use rejection sampling for tasks that require higher quality, or typical acceptance when speedup is more important.

NOTE: This acceptance routine should work with other proposal types (EAGLE, draft model, ngram, and others), not just Medusa. The speculative decoding framework in vLLM may need improvements to the rejection sampling interface to support this.
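One possible shape for that interface (purely a hypothetical sketch, not vLLM's current code): a small base class that both the existing rejection sampler and a new typical-acceptance sampler implement, so the verification step can be configured with either routine regardless of which proposer produced the tokens.

```python
from abc import ABC, abstractmethod

import torch


class SpecTokenAcceptor(ABC):
    """Hypothetical common interface for token-acceptance routines.

    A lossless rejection sampler and a Medusa-style typical-acceptance
    sampler would both implement this, so any proposer (draft model,
    ngram, Medusa, EAGLE, ...) can be paired with either routine.
    """

    @abstractmethod
    def accept_tokens(
        self,
        target_probs: torch.Tensor,     # [batch, k, vocab_size] target-model probs
        draft_probs: torch.Tensor,      # [batch, k, vocab_size] proposer probs (may be unused)
        draft_token_ids: torch.Tensor,  # [batch, k] proposed token ids
    ) -> torch.Tensor:
        """Return accepted token ids per sequence, padded with -1 after the first rejection."""
        ...
```

A user-facing toggle in the speculative decoding config could then simply select which implementation gets instantiated.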

Alternatives

No response

Additional context

vLLM's rejection sampler is implemented here: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rejection_sampler.py

sroy745 commented 4 months ago

I have started looking at this issue. I will base my implementation on this reference implementation in Medusa: https://sourcegraph.com/github.com/FasterDecoding/Medusa@e2a5d20c048a9b0a4092e6933c34313687422518/-/blob/medusa/model/utils.py?L404