njhill opened 4 days ago
Initial thought:
We can start today with a small model (we don't have to wait for the new MLPSpeculator); the results should generalize to larger target models.
Yes, sorry, I should have made that clear: the large models are more the motivation, but it can be developed and tested with existing ones.
Thanks for writing this up, @njhill. I can start working on it.
🚀 The feature, motivation and pitch
MLPSpeculator-based speculative decoding was recently added in https://github.com/vllm-project/vllm/pull/4947, but the initial integration only covers single-GPU usage. There will soon be "speculator" models available for larger target models that require multiple GPUs, so we would like to ensure that TP (tensor parallelism) can be used.
The first part of this issue would be testing it out in conjunction with https://github.com/vllm-project/vllm/pull/5414 and making the necessary adjustments so that it works with TP=1 for the speculator and TP=N for the target model.
Following this, we can look at having the speculator itself run with TP>1, but that may be more involved since it will require some distributed coordination of the sampling of each speculated token in the MLPSpeculator loop. It might be possible to avoid additional communication here by having the sampler used by the speculator model use a dedicated torch.Generator for its sampling, and doing this sampling in tandem across the ranks. @JRosenkranz already used VocabParallelEmbedding in the implementation, so the model layers themselves should work fine.
cc @cadedaniel @sirejdua @JRosenkranz @tdoublep
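The dedicated-generator idea can be illustrated with a small single-process simulation. This is a sketch under stated assumptions, not the vLLM implementation: it assumes every rank already holds bit-identical logits at each speculative step (for example after an all-gather of the vocab-parallel logits), and it uses Python's `random.Random` as a stand-in for a `torch.Generator` seeded identically on each rank. `toy_head_logits` and `speculate_on_rank` are hypothetical helpers that only mimic the shape of the MLPSpeculator loop, where each sampled token feeds the next head.

```python
import math
import random


def softmax(logits):
    # Numerically stable softmax over a list of floats.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample(probs, rng):
    # Inverse-CDF sampling driven only by this rank's own generator.
    r = rng.random()
    cumulative = 0.0
    for token_id, p in enumerate(probs):
        cumulative += p
        if r < cumulative:
            return token_id
    return len(probs) - 1  # guard against floating-point round-off


def toy_head_logits(prev_token, step, vocab_size):
    # Hypothetical stand-in for one speculator head: deterministic logits
    # computed from the previous token and the step index, so they are
    # identical on every simulated rank.
    return [math.sin(0.7 * (prev_token + 1) * (v + 1) + step) for v in range(vocab_size)]


def speculate_on_rank(seed, first_token, num_steps, vocab_size):
    # Stand-in for a dedicated torch.Generator seeded the same way on each rank.
    rng = random.Random(seed)
    tokens = []
    prev = first_token
    for step in range(num_steps):
        probs = softmax(toy_head_logits(prev, step, vocab_size))
        prev = sample(probs, rng)  # no broadcast of the sampled token
        tokens.append(prev)
    return tokens


if __name__ == "__main__":
    SEED = 1234
    # Simulate two TP ranks running the same speculative loop independently.
    proposals = [speculate_on_rank(SEED, first_token=5, num_steps=4, vocab_size=32)
                 for _rank in range(2)]
    print(proposals)  # both inner lists are the same
```

The trade-off is that the per-step broadcast of each sampled token is replaced by a stricter invariant: all ranks must see exactly the same probabilities and must draw from their generators in exactly the same order. Any divergence, such as a floating-point difference in the logits or an extra random draw on one rank, would silently desynchronize the proposals.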