MagicDec: Breaking Throughput-Latency Trade-off for Long Context Generation with Speculative Decoding

Jian Chen*1, Vashisth Tiwari*1, Ranajoy Sadhukhan*1, Zhuoming Chen1, Jinyuan Shi2, Ian En-Hsu Yen2, Beidi Chen1,3
1Carnegie Mellon University 2Moffett AI 3Meta AI (FAIR)
[Paper] | [Blog]


Installation

Environment Set Up

conda create -n magicdec python=3.11
conda activate magicdec
pip install -r requirements.txt
pip install flash-attn --no-build-isolation
pip install torch==2.5.0.dev20240813+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121/
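
As a quick sanity check (not part of the original instructions), you can verify that the nightly PyTorch build and Flash-Attention import cleanly; the exact version strings will depend on your environment:

python -c "import torch, flash_attn; print(torch.__version__, flash_attn.__version__)"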

Prepare Checkpoints

Currently, we support Llama-2-7b and its long-context variant Llama-2-7b-32k, Llama-2-13b, Llama-2-70b, Llama-3-8b, Llama-3-70b, Llama-3.1-8b, llama-68m, and TinyLlama.

First, download the checkpoints you need with download.py. --repo_id should be set to the Hugging Face repository ID to download from, --hf_token to your Hugging Face API token, and --out_dir to the directory where the checkpoint will be saved.

python download.py --repo_id meta-llama/Meta-Llama-3.1-8B --hf_token "YOUR HUGGINGFACE API TOKEN" --out_dir checkpoints/meta-llama/Meta-Llama-3.1-8B

Then convert the downloaded checkpoint. --checkpoint_dir should be set to the directory where we just saved the checkpoint. The script generates a new model.pth file in that directory.

python convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3.1-8B
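
If you plan to run the standalone-draft experiments below, the target and draft checkpoints used there can be prepared the same way. The repo IDs and output directories here simply mirror the paths that appear in the later commands:

python download.py --repo_id togethercomputer/LLaMA-2-7B-32K --hf_token "YOUR HUGGINGFACE API TOKEN" --out_dir checkpoints/togethercomputer/LLaMA-2-7B-32K
python convert_hf_checkpoint.py --checkpoint_dir checkpoints/togethercomputer/LLaMA-2-7B-32K
python download.py --repo_id TinyLlama/TinyLlama_v1.1 --hf_token "YOUR HUGGINGFACE API TOKEN" --out_dir checkpoints/TinyLlama/TinyLlama_v1.1
python convert_hf_checkpoint.py --checkpoint_dir checkpoints/TinyLlama/TinyLlama_v1.1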

Evaluations

We conducted all experiments in the paper on 8xA100, 8xH100, and 8xL40 GPUs, using PG-19 as the dataset throughout.

Baseline

We used the new one-shot and two-shot all-reduce kernels from the PyTorch nightly build by setting ENABLE_INTRA_NODE_COMM=1. --nproc_per_node should be set to the number of GPUs used for tensor parallelism. --model should point to the model.pth file of the checkpoint we want to serve, and --model_name should be the repo ID of that checkpoint. --rank_group is the list of GPU IDs participating in tensor parallelism. --B is the batch size, --prefix_len is the prefill length, and --gen_len is the number of tokens to generate for each sequence. --printoutput controls whether the output is printed after generation finishes, and --compile controls whether torch.compile is used to accelerate generation.

ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=8 tests/baseline_benchmark.py --model checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --model_name meta-llama/Meta-Llama-3.1-8B --rank_group 0 1 2 3 4 5 6 7 --B 1 --prefix_len 4000 --gen_len 64 --printoutput --compile
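
For example, to run the same baseline with tensor parallelism over only the first 4 GPUs (a hypothetical variant, not a configuration from the paper), adjust --nproc_per_node and --rank_group together:

ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=4 tests/baseline_benchmark.py --model checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --model_name meta-llama/Meta-Llama-3.1-8B --rank_group 0 1 2 3 --B 1 --prefix_len 4000 --gen_len 64 --printoutput --compile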

Standalone Draft

For the standalone-draft experiment, --target and --model set the target and draft checkpoints, respectively. --model_name should be the repo ID of the target model; it is used to load the corresponding tokenizer. --rank_group lists the GPU IDs used for tensor parallelism of the target model, and --draft_ranks lists the GPU IDs used for tensor parallelism of the draft model. Because the TinyLlama-1.1B model has only 4 KV heads, at most 4 GPUs can be used for its tensor parallelism. --gamma is the number of tokens drafted per speculation step, and --streamingllm_budget sets the StreamingLLM KV cache budget for the draft model.

ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=8 tests/longspec_benchmark.py --target checkpoints/togethercomputer/LLaMA-2-7B-32K/model.pth --model checkpoints/TinyLlama/TinyLlama_v1.1/model.pth --model_name togethercomputer/LLaMA-2-7B-32K --rank_group 0 1 2 3 4 5 6 7 --draft_ranks 0 1 2 3 --gamma 3 --B 64 --prefix_len 16000 --gen_len 64 --streamingllm_budget 256 --benchmark --compile
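
As a sketch of how --draft_ranks interacts with --rank_group (assuming any divisor of the draft model's 4 KV heads is a valid TP degree, which is not stated in the paper), the draft could be sharded over only 2 GPUs while the target keeps all 8:

ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=8 tests/longspec_benchmark.py --target checkpoints/togethercomputer/LLaMA-2-7B-32K/model.pth --model checkpoints/TinyLlama/TinyLlama_v1.1/model.pth --model_name togethercomputer/LLaMA-2-7B-32K --rank_group 0 1 2 3 4 5 6 7 --draft_ranks 0 1 --gamma 3 --B 64 --prefix_len 16000 --gen_len 64 --streamingllm_budget 256 --benchmark --compile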

Self-Speculation

Self-speculation is configured similarly to the standalone draft, except that we do not need to configure a separate draft model: the target model drafts for itself, using the StreamingLLM KV budget during drafting.

ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=8 tests/selfspec_benchmark.py --model checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --model_name meta-llama/Meta-Llama-3.1-8B --rank_group 0 1 2 3 4 5 6 7 --gamma 3 --B 64 --prefix_len 16000 --gen_len 64 --streamingllm_budget 256 --benchmark --compile

Environment Issue

We found that installing Flash-Attention directly against the PyTorch nightly build causes performance issues. These issues go away if we first install PyTorch 2.4.0 together with Flash-Attention and then upgrade to the nightly PyTorch build, which is the approach we adopted. We expect the problem to be resolved once PyTorch 2.5.0 is released with an officially supported Flash-Attention version.
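
Concretely, the workaround corresponds to the following install order (a sketch; it assumes you install torch 2.4.0 explicitly rather than via requirements.txt, and it reuses the nightly wheel tag from the Installation section):

pip install torch==2.4.0
pip install flash-attn --no-build-isolation
pip install torch==2.5.0.dev20240813+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121/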

Citation

If you find MagicDec useful or relevant to your project and research, please cite our paper:

@article{chen2024magicdec,
  title={MagicDec: Breaking the Latency-Throughput Tradeoff for Long Context Generation with Speculative Decoding},
  author={Chen, Jian and Tiwari, Vashisth and Sadhukhan, Ranajoy and Chen, Zhuoming and Shi, Jinyuan and Yen, Ian En-Hsu and Chen, Beidi},
  journal={arXiv preprint arXiv:2408.11049},
  year={2024}
}