pytorch-labs / gpt-fast

Simple and efficient pytorch-native transformer text generation in <1000 LOC of python.

Speculative decoding slows model down, possibly from "skipping cudagraphs due to ['mutated inputs']"? #21

Open jamestwhedbee opened 11 months ago

jamestwhedbee commented 11 months ago

Some context

I am using AMD MI100 GPUs, and I can get ~33 tokens/second for Llama 2 70B using:

time torchrun --standalone --nproc_per_node=8 generate.py --compile --checkpoint_path checkpoints/Llama-2-70b-chat-hf/model_int8.pth --max_new_tokens 100

Issue

This drops to ~7 tokens/second when I try to include speculative decoding with this command

time torchrun --standalone --nproc_per_node=8 generate.py --compile --draft_checkpoint_path checkpoints/meta-llama/Llama-2-7b-chat-hf/model_int8.pth --checkpoint_path checkpoints/meta-llama/Llama-2-70b-chat-hf/model_int8.pth --max_new_tokens 100 --speculate_k 5

The only clue I have to go on here is this output during compilation, which I only see when using speculative decoding: skipping cudagraphs due to ['mutated inputs']
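
For context, my rough mental model of a speculative decoding step (a generic sketch of the technique, not gpt-fast's actual code; the model calls, shapes, and helper names here are assumptions) is that the draft model proposes speculate_k tokens and the verifier then scores all of them in a single forward pass:

```python
import torch

def speculative_step(draft_model, verifier_model, ctx: torch.Tensor, k: int = 5):
    """One speculative decoding step: the draft proposes k tokens, the verifier checks them.
    Acceptance/rejection of the proposed tokens is omitted for brevity."""
    # The small draft model proposes k tokens autoregressively (cheap).
    for _ in range(k):
        draft_logits = draft_model(ctx)                    # [1, seq_len, vocab]
        next_tok = draft_logits[:, -1].argmax(-1, keepdim=True)
        ctx = torch.cat([ctx, next_tok], dim=-1)

    # The verifier scores every proposed position in ONE forward pass, so per step
    # it processes k + 1 new tokens instead of the single token of ordinary decoding.
    verifier_logits = verifier_model(ctx)                  # [1, seq_len + k, vocab]
    return ctx, verifier_logits
```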

Chillee commented 11 months ago

Oh sorry, this is a note I should add to the README. This repo currently cannot efficiently support using an int8 quantized model as the verifier model. Basically, Inductor can only efficiently codegen int8 dequant + mm for BS=1, but if you're using speculative decoding it's running with a larger batch size than that.
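
To illustrate what I mean by int8 dequant + mm, here is a simplified sketch of weight-only int8 quantization in general (made-up names; not the repo's actual kernel or Inductor's generated code):

```python
import torch

def int8_weight_only_linear(x: torch.Tensor, w_int8: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    # x:      [B, in_features]            activations (bf16/fp16)
    # w_int8: [out_features, in_features] int8-quantized weights
    # scales: [out_features]              per-output-channel dequant scales
    #
    # For B == 1 (ordinary token-by-token decoding) Inductor can codegen an
    # efficient fused dequant + matmul. With speculative decoding the verifier
    # sees more than one token per step, and that fast path no longer applies.
    w = w_int8.to(x.dtype) * scales[:, None]   # dequantize
    return x @ w.t()                           # matmul
```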

By the way, I wasn't actually aware this repo worked with AMD GPUs and tensor parallelism. Did you make any code changes to support that?

jamestwhedbee commented 11 months ago

Yep, tensor parallelism worked for me with no code changes! I'll try using unquantized Llama 2 70B as the verifier model tomorrow. Int4 quantization is not supported for AMD GPUs in this repo, correct? I couldn't get it to work.

https://github.com/pytorch-labs/gpt-fast/issues/12#issuecomment-1836765086

jamestwhedbee commented 11 months ago

@Chillee Maybe I misunderstood; could you give me an example command that you think should result in a speed-up?

I can get ~15 tokens/second for unquantized Llama 2 70B using --compile and tensor parallelism across 8 GPUs, but that drops to ~8 tokens/second if I try to use speculative decoding.

time torchrun --standalone --nproc_per_node=8 generate.py --compile --draft_checkpoint_path checkpoints/meta-llama/Llama-2-7b-chat-hf/model_int8.pth --checkpoint_path checkpoints/meta-llama/Llama-2-70b-chat-hf/model.pth --max_new_tokens 100 --speculate_k 5

Chillee commented 11 months ago

@jamestwhedbee There are a couple of scripts in the scripts/ directory that should result in speedups. In particular, you should try ./scripts/speculate_tp_70B_bf16.sh.

EDIT: There seems to be some kind of issue right now with TP + speculative decoding; I'll take a look.

jamestwhedbee commented 11 months ago

@Chillee that unfortunately also just results in ~8 tokens/second

EDIT: just saw your edit

jamestwhedbee commented 11 months ago

Hey, @Chillee were you able to learn more about the issue here?

yafehlis commented 7 months ago

@Chillee @jamestwhedbee I am seeing a similar issue on my side (I work for AMD). When I use INT8 for both the target and draft models, speculative decoding gives slowdowns. Have you resolved this issue yet? Thanks, Yao Fehlis (AMD)