
[Feature]: MLPSpeculator Tensor Parallel support #5809

Open njhill opened 4 days ago

njhill commented 4 days ago

🚀 The feature, motivation and pitch

MLPSpeculator-based speculative decoding was recently added in https://github.com/vllm-project/vllm/pull/4947, but the initial integration only covers single GPU usage.

There will soon be "speculator" models available for larger target models that require multiple GPUs, so we would like to ensure that tensor parallelism (TP) can be used.

The first part of this issue would be testing it in conjunction with https://github.com/vllm-project/vllm/pull/5414 and making the necessary adjustments so that it works with TP=1 for the speculator and TP=N for the target model; a sketch of such a configuration follows.
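A minimal sketch, assuming the draft-TP argument from #5414 lands roughly as proposed: the model names, the `speculative_draft_tensor_parallel_size` argument, and the token count below are illustrative assumptions, not the finalized API.

```python
# Hypothetical configuration: TP=N for the target model, TP=1 for the
# MLPSpeculator draft model. Argument names are assumptions based on #5414.
from vllm import LLM, SamplingParams

llm = LLM(
    model="some-org/large-target-model",                # hypothetical target model
    speculative_model="some-org/mlp-speculator",        # hypothetical MLPSpeculator checkpoint
    num_speculative_tokens=4,                            # illustrative value
    tensor_parallel_size=2,                              # TP=N for the target
    speculative_draft_tensor_parallel_size=1,            # TP=1 for the speculator (assumed arg)
)

outputs = llm.generate(["The capital of France is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```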

Following this we can look at having the speculator itself run with TP>1, but that may be more involved since it will require some distributed coordination of the sampling of each speculated token in the MLPSpeculator loop. It might be possible to avoid additional communication here by having the sampler used by the speculator model use a dedicated torch.Generator for its sampling, and doing this sampling in tandem across the ranks.
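A minimal sketch of the "sampling in tandem" idea, not vLLM's actual sampler: if every TP rank seeds a dedicated torch.Generator identically and feeds it the same probability tensor, torch.multinomial produces the same speculated token on every rank, so no extra collective is needed inside the MLPSpeculator loop.

```python
import torch

def sample_in_tandem(probs: torch.Tensor, generator: torch.Generator) -> torch.Tensor:
    """Sample one token per sequence; identical across ranks given identical
    probs and identically seeded generators (illustrative helper)."""
    return torch.multinomial(probs, num_samples=1, generator=generator)

# Each rank constructs the generator with the same seed (e.g. agreed on once
# at engine start-up), so their sampling streams stay in lock-step.
shared_seed = 1234
gen = torch.Generator(device="cpu").manual_seed(shared_seed)

# Placeholder for the speculator's (replicated) next-token distribution.
probs = torch.softmax(torch.randn(2, 32000), dim=-1)
next_tokens = sample_in_tandem(probs, gen)
print(next_tokens)
```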

@JRosenkranz already used VocabParallelEmbedding in the implementation so the model layers themselves should work fine.

cc @cadedaniel @sirejdua @JRosenkranz @tdoublep

cadedaniel commented 4 days ago

initial thought:

njhill commented 4 days ago

> we can start today with a small model (don't have to wait for new MLPSpeculator), the result should generalize to larger target models.

Yes, sorry, I should have made that clear: the large models are more the motivation, but it can be developed and tested with existing ones.

sirejdua commented 3 days ago

Thanks for writing this up @njhill, I can start working on it.