pytorch-labs / gpt-fast

Simple and efficient pytorch-native transformer text generation in <1000 LOC of python.

Speculative decoding slows model down, possibly from "skipping cudagraphs due to ['mutated inputs']"? #21

Open jamestwhedbee opened 7 months ago

jamestwhedbee commented 7 months ago

Some context

I am using AMD MI100 GPUs and I can get ~33 tokens/second for Llama 2 70B using

time torchrun --standalone --nproc_per_node=8 generate.py --compile --checkpoint_path checkpoints/Llama-2-70b-chat-hf/model_int8.pth --max_new_tokens 100

Issue

This drops to ~7 tokens/second when I try to use speculative decoding with this command:

time torchrun --standalone --nproc_per_node=8 generate.py --compile --draft_checkpoint_path checkpoints/meta-llama/Llama-2-7b-chat-hf/model_int8.pth --checkpoint_path checkpoints/meta-llama/Llama-2-70b-chat-hf/model_int8.pth --max_new_tokens 100 --speculate_k 5

The only clue I have to go on is this output during compilation, which I only see when using speculative decoding: skipping cudagraphs due to ['mutated inputs']
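For context, --speculate_k 5 means the draft model proposes 5 tokens per step and the 70B model verifies them all in one pass, roughly like the sketch below (greedy acceptance, no KV cache, hypothetical model/draft_model callables; not the repo's actual code):

```python
import torch

@torch.no_grad()
def speculative_step(model, draft_model, tokens, k=5):
    # One speculative decoding step with greedy acceptance (no KV cache).
    # `model` / `draft_model` map token ids [1, seq] -> logits [1, seq, vocab].
    prompt_len = tokens.shape[1]

    # 1) The draft model proposes k tokens autoregressively, one at a time.
    cur = tokens
    for _ in range(k):
        next_tok = draft_model(cur)[:, -1].argmax(dim=-1, keepdim=True)  # [1, 1]
        cur = torch.cat([cur, next_tok], dim=1)
    draft = cur[:, prompt_len:]                                  # [1, k] proposals

    # 2) The verifier scores all k proposals in ONE forward pass, i.e. it sees
    #    k+1 new positions per step instead of a single position.
    target_pred = model(cur)[:, prompt_len - 1:].argmax(dim=-1)  # [1, k+1]

    # 3) Accept the longest prefix where verifier and draft agree, then append
    #    the verifier's token at the first mismatch (or its bonus token if all
    #    k proposals were accepted).
    agree = (target_pred[:, :k] == draft).int().cumprod(dim=-1)
    n_accept = int(agree.sum())
    correction = target_pred[:, n_accept:n_accept + 1]           # [1, 1]
    return torch.cat([tokens, draft[:, :n_accept], correction], dim=1)
```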

Chillee commented 7 months ago

Oh sorry, this is a note I should add to the README. This repo currently cannot efficiently support using an int8-quantized model as the verifier model. Basically, Inductor can only efficiently codegen int8 dequant + mm for BS=1, but with speculative decoding the verifier runs at a larger batch size than that.
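To illustrate, the int8 path is a weight-only quantized linear that dequantizes on the fly, roughly like the sketch below (a minimal sketch, not necessarily the exact module in this repo):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class WeightOnlyInt8Linear(nn.Module):
    """Weight-only int8 linear: int8 weights plus per-channel fp scales.

    For a single decode token (batch * seq == 1) Inductor can fuse the
    dequantization into the matrix-vector product, so the int8 weights are
    read once and never materialized in bf16. With speculative decoding the
    verifier processes several tokens per step, that fused codegen no longer
    applies, and the dequant becomes a separate, slow step.
    """
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.register_buffer("weight", torch.empty(out_features, in_features, dtype=torch.int8))
        self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Dequantize on the fly: cast int8 -> activation dtype, then scale.
        return F.linear(x, self.weight.to(dtype=x.dtype)) * self.scales
```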

By the way, I wasn't actually aware this repo worked with AMD GPUs and tensor parallelism. Did you make any code changes to support that?

jamestwhedbee commented 7 months ago

Yep, tensor parallelism worked for me with no code changes! I'll try using an unquantized Llama 2 70B as the verifier model tomorrow. Int4 quantization is not supported for AMD GPUs in this repo, correct? I couldn't get it to work.

https://github.com/pytorch-labs/gpt-fast/issues/12#issuecomment-1836765086

jamestwhedbee commented 7 months ago

@Chillee Maybe I misunderstood. Could you give me an example command that you think should result in a speed-up?

I can get ~15 tokens/second for an unquantized Llama 2 70B using compile and tensor parallelism across 8 GPUs, but that drops to ~8 tokens/second if I try to use speculative decoding.

time torchrun --standalone --nproc_per_node=8 generate.py --compile --draft_checkpoint_path checkpoints/meta-llama/Llama-2-7b-chat-hf/model_int8.pth --checkpoint_path checkpoints/meta-llama/Llama-2-70b-chat-hf/model.pth --max_new_tokens 100 --speculate_k 5

Chillee commented 7 months ago

@jamestwhedbee There are a couple of scripts in the scripts/ directory that should result in speedups. In particular, you should try ./scripts/speculate_tp_70B_bf16.sh.

EDIT: There seems to be some kind of issue right now with TP + speculative decoding; I'll take a look.

jamestwhedbee commented 7 months ago

@Chillee that unfortunately also just results in ~8 tokens/second

EDIT: just saw your edit

jamestwhedbee commented 6 months ago

Hey, @Chillee were you able to learn more about the issue here?

yafehlis commented 3 months ago

@Chillee @jamestwhedbee I am seeing a similar issue on my side (I work for AMD). With INT8 for both the target model and the draft model, speculative decoding gives slowdowns. Have you resolved this issue yet? Thanks, Yao Fehlis (AMD)