vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Speculative decoding] [Help wanted] [Performance] Optimize draft-model speculative decoding #4630

Status: Open. cadedaniel opened this issue 1 month ago.

cadedaniel commented 1 month ago

Proposal to improve performance

With the end-to-end correctness tests merged in https://github.com/vllm-project/vllm/pull/3951, we will now optimize the implementation to get a ~50% speedup on a 70B model with temperature 1.0.

Work required:

Legend: P0/P1 = priority; Small/Medium/Large = relative size estimate.

FAQ

What should the target configuration be for 50% speedup?

In the Anyscale fork we saw a 50% speedup at batch size 8 (bs=8) in two configurations: a 68M-parameter draft model on TP1 with a 70B target model on TP8, and a 7B draft model on TP1 or TP8 (TP(1|8)) with a 70B target model on TP8. This was with the optimizations listed above as "P0".

Note that we can do much better than this with multi-query scoring (P1), GQA for target-model scoring, and a dynamic speculation policy. This is just the starting point!

Why not implement Medusa / tree-attention?

We should implement this! The work here will lay the foundation for future improvements in speculative decoding. For example, Eagle uses the Medusa approach (fine-tuned heads plus tree attention) and even claims to beat Medusa. But for Eagle to work well in vLLM we need to optimize the sampler as listed above.

The north star should be a configurable tree size (top-k ... top-1) that uses multi-query attention for scoring (no batch expansion). This issue is about optimizing vLLM in the top-1 speculation case to get a 50% speedup with draft models.
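The difference between batch expansion and multi-query scoring can be sketched on plain token-ID lists. This is a hypothetical illustration only: `batch_expand` and `multi_query_positions` are made-up names, not vLLM functions, and no model is run.

```python
from typing import List


def batch_expand(prefix: List[int], proposal: List[int]) -> List[List[int]]:
    """Top-1 batch expansion (hypothetical sketch).

    One sequence with k proposal tokens becomes k + 1 sequences, each ending
    at a different proposal prefix. A normal single-token decode step over
    all of them yields the target's next-token distribution at every
    proposal position, at the cost of a batch with k + 1 times as many
    sequences.
    """
    return [prefix + proposal[:i] for i in range(len(proposal) + 1)]


def multi_query_positions(prefix: List[int], proposal: List[int]) -> List[int]:
    """Multi-query scoring (hypothetical sketch).

    The sequence stays a single batch entry: prefix + proposal is fed once,
    and the target model produces logits at these k + 1 query positions
    (the last prefix token plus each proposal token).
    """
    first = len(prefix) - 1
    return list(range(first, first + len(proposal) + 1))
```

For prefix `[1, 2, 3]` and proposal `[7, 8]`, `batch_expand` returns three sequences (`[1, 2, 3]`, `[1, 2, 3, 7]`, `[1, 2, 3, 7, 8]`), while multi-query scoring reads logits at positions 2, 3, and 4 of a single sequence.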

youkaichao commented 1 month ago

Support draft model on different tensor-parallel-size than target model

This should be doable; we just need to figure out the UX change, i.e. how users would configure it.

Do spec workers and non-spec workers share a process/device? For example, if we currently have tp=8 and want to add another tp=2 for spec decoding, should the tp=2 workers be 2 additional processes, or a subset of the tp=8 processes?

cadedaniel commented 1 month ago

@youkaichao, see the code linked here: https://github.com/vllm-project/vllm/issues/4632. The spec worker and non-spec workers share the same process.

KexinFeng commented 1 month ago

About tree attention for Medusa/Eagle: one of the core pieces will be tree-attention mask support in FlashAttention, which is not ready yet. I'd like to draw your attention to https://github.com/Dao-AILab/flash-attention/issues/924. It would be great if anyone would like to contribute to it.
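To make the kernel requirement concrete: a tree attention mask lets each candidate token attend only to itself and its ancestors in the speculation tree, whereas a causal mask lets it see every earlier position, including sibling branches. A minimal pure-Python sketch, assuming the tree is given as a parent-index array (a hypothetical representation; this does not show FlashAttention's API):

```python
from typing import List


def tree_attention_mask(parents: List[int]) -> List[List[bool]]:
    """Attention mask among speculative tree nodes (hypothetical sketch).

    parents[i] is the index of node i's parent, or -1 if node i attaches
    directly to the verified prefix. Nodes are assumed to be in topological
    order (parents[i] < i). mask[i][j] is True iff node i may attend to
    node j, i.e. j is i itself or an ancestor of i. Attention to the shared
    prefix, which every node can see, is not represented here.
    """
    n = len(parents)
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        if not -1 <= parents[i] < i:
            raise ValueError(f"node {i} must have a parent in [-1, {i})")
        # Walk from node i up to the root, marking each node on the path.
        j = i
        while j != -1:
            mask[i][j] = True
            j = parents[j]
    return mask
```

For `parents = [-1, 0, 0, 1]` (one root, two top-2 children, one grandchild), node 2 attends only to nodes 0 and 2, not to its sibling 1, which a plain causal mask would allow. For a chain (top-1 speculation) the mask reduces to the ordinary lower-triangular causal mask.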

sighingnow commented 1 month ago

Hi @cadedaniel, I tried the current main branch to evaluate the speedup from speculative decoding, but hit the following assertion error:

https://github.com/vllm-project/vllm/blob/190bc838e17196733526896bf2861f8d05bd3f43/vllm/executor/ray_gpu_executor.py#L28-L32

How was the 50% speedup measured? Are there still pending PRs? And since the draft model is so small (68M), was the 50% speedup measured with greedy sampling or random sampling?

Thanks!

cadedaniel commented 1 month ago

Regarding tree-attention mask support in FlashAttention: @LiuXiaoxuanPKU has more on this.

cadedaniel commented 1 month ago

@sighingnow this issue tracks the work needed to get the 50% speedup. Once the P0s are done, we will get it with temperature 1.0.
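Temperature 1.0 matters because non-greedy sampling needs rejection sampling to keep the output distribution identical to the target model's. A minimal single-token sketch of the standard speculative-sampling accept/resample rule, with distributions as plain dicts; `verify_draft_token` is a hypothetical name, and this is not vLLM's batched sampler.

```python
import random
from typing import Dict, Tuple


def verify_draft_token(
    draft_token: int,
    p: Dict[int, float],
    q: Dict[int, float],
    rng: random.Random,
) -> Tuple[int, bool]:
    """Verify one draft token with the accept/resample rule (sketch).

    p is the target model's next-token distribution and q the draft model's
    (both normalized). draft_token must have been sampled from q, so
    q[draft_token] > 0. Returns (emitted_token, accepted).
    """
    px = p.get(draft_token, 0.0)
    qx = q[draft_token]
    # Accept with probability min(1, p(x) / q(x)).
    if rng.random() < min(1.0, px / qx):
        return draft_token, True
    # On rejection, resample from the residual norm(max(0, p - q)).
    # Rejection only happens when p(x) < q(x), so the residual mass is > 0.
    residual = {t: max(0.0, pt - q.get(t, 0.0)) for t, pt in p.items()}
    threshold = rng.random() * sum(residual.values())
    cumulative = 0.0
    last_positive = draft_token
    for token, weight in residual.items():
        if weight <= 0.0:
            continue
        cumulative += weight
        last_positive = token
        if threshold < cumulative:
            return token, False
    # Guard against floating-point round-off at the upper end.
    return last_positive, False
```

With greedy decoding (one-hot `p` and `q`), this rule reduces to accepting the draft token exactly when it equals the target model's argmax, and otherwise emitting that argmax.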

ChuanhongLi commented 1 month ago

I ran into the same assertion error as @sighingnow. Is there a solution? Also, is there any documentation on how to evaluate the speedup from speculative decoding? Thanks!

sighingnow commented 1 month ago

What was the acceptance rate in the runs that got the 50% speedup? Thanks!

cadedaniel commented 1 month ago

With Llama 2 7B as the draft model and Llama 2 70B as the target, the acceptance rate was around 80% (no fine-tuning). We trained a 68M draft model at Anyscale that gets a ~50% acceptance rate. By the way, you can run acceptance-rate experiments today (I will push a PR tomorrow for TP>1 support).
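To relate an acceptance rate to a speedup, a common back-of-the-envelope model treats each of the k proposal tokens as accepted independently with probability α; one target pass then emits (1 - α^(k+1)) / (1 - α) tokens on average. The sketch below is a simplified model under stated assumptions, not how the numbers in this thread were measured: it assumes scoring k + 1 tokens costs the same as one target decode step, ignores batch-expansion and sampler overhead, and `draft_cost_ratio` (draft latency divided by target latency) is a hypothetical input. The "acceptance rate" reported here may also be defined differently (e.g. accepted / proposed tokens), so plugging it in as α gives only a rough estimate.

```python
def expected_tokens_per_step(alpha: float, k: int) -> float:
    """Expected tokens emitted per target forward pass.

    Counts accepted draft tokens plus the one bonus/correction token, assuming
    each proposal token is accepted independently with probability alpha.
    """
    if alpha >= 1.0:
        return float(k + 1)
    return (1 - alpha ** (k + 1)) / (1 - alpha)


def estimated_speedup(alpha: float, k: int, draft_cost_ratio: float) -> float:
    """Walltime speedup over plain decoding under the simplified cost model.

    One speculative step is modeled as k draft forward passes plus one target
    forward pass; draft_cost_ratio = draft latency / target latency.
    """
    return expected_tokens_per_step(alpha, k) / (k * draft_cost_ratio + 1)
```

For example, with α = 0.5 and k = 4, `expected_tokens_per_step` returns 1.9375 (1 + 0.5 + 0.25 + 0.125 + 0.0625); the final speedup then depends on how cheap the draft model is relative to the target.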

Thanks @ChuanhongLi. FYI, there is no speedup yet; we'll share documentation once there is a useful one.

sighingnow commented 1 month ago

Thanks for the information! Looking forward to the complete speculative decoding support!

ChuanhongLi commented 1 month ago

Thanks for your reply!

caddfa31434 commented 1 month ago

I noticed there's a feature request related to Medusa/Eagle at https://github.com/vllm-project/vllm/issues/4669