jamestwhedbee opened this issue 11 months ago
Oh sorry, this is a note I should add to the README. This repo currently cannot efficiently support using an int8-quantized model as the verifier model. Basically, Inductor can only efficiently codegen int8 dequant + mm for BS=1, but with speculative decoding the verifier runs at a larger batch size than that.
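For concreteness, the pattern Inductor has to codegen here is roughly a weight-only int8 linear: int8 weights plus per-output-channel scales that get dequantized inside the forward and fed to a matmul. The class below is just an illustrative sketch of that pattern (names and dtypes are mine, not the repo's exact code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Int8WeightOnlyLinear(nn.Module):
    """Sketch of a weight-only int8 linear: int8 weights + per-channel scales,
    dequantized on the fly inside the forward."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        # Only the int8 weights and the floating-point scales are stored.
        self.register_buffer("weight", torch.empty(out_features, in_features, dtype=torch.int8))
        self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # dequant + mm: for a single row (BS=1) Inductor can fuse this into an
        # efficient dequant-gemv; for the multi-token batches that speculative
        # decoding produces, the generated dequant + gemm is much slower.
        return F.linear(x, self.weight.to(dtype=x.dtype)) * self.scales
```

At BS=1 this compiles to a single fused dequant + gemv; once the verifier has to score several speculated tokens at once, the same graph becomes a dequant + gemm that Inductor currently doesn't codegen efficiently.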
By the way, I wasn't actually aware this repo worked with AMD GPUs and tensor parallelism. Did you make any code changes to support that?
Yep, tensor parallelism worked for me with no code changes! I'll try using an unquantized Llama 2 70B as the verifier model tomorrow, since int4 quantization is not supported for AMD GPUs in this repo, correct? I couldn't get it to work.
https://github.com/pytorch-labs/gpt-fast/issues/12#issuecomment-1836765086
@Chillee Maybe I misunderstood. Could you give me an example command that you think should result in a speed-up?
I can get ~15 tokens/second for an unquantized Llama 2 70B using compile and tensor parallelism of 8, but that drops to ~8 tokens/second if I try to use speculative decoding.
```
time torchrun --standalone --nproc_per_node=8 generate.py --compile --draft_checkpoint_path checkpoints/meta-llama/Llama-2-7b-chat-hf/model_int8.pth --checkpoint_path checkpoints/meta-llama/Llama-2-70b-chat-hf/model.pth --max_new_tokens 100 --speculate_k 5
```
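For context on why the verifier stops running at batch size 1: with `--speculate_k 5` the target model scores all 5 draft tokens (plus one bonus position) in a single forward pass. The snippet below is a simplified greedy sketch of that verification step, not the repo's actual implementation (which uses the standard accept/reject sampling scheme), just to show where the larger batch comes from:

```python
import torch

@torch.no_grad()
def verify_draft_greedy(verifier, prefix: torch.Tensor, draft: torch.Tensor) -> torch.Tensor:
    """Hypothetical greedy sketch of the verification step in speculative decoding.

    prefix:   tokens generated so far, shape [T]
    draft:    speculate_k tokens proposed by the draft model, shape [k]
    verifier: callable mapping token ids [1, L] -> logits [1, L, vocab]
    """
    k = draft.numel()
    # One forward pass covers every draft position at once. (With a KV cache only
    # the new positions would run, but that is still k + 1 tokens rather than 1.)
    inp = torch.cat([prefix, draft]).unsqueeze(0)              # [1, T + k]
    logits = verifier(inp)                                     # [1, T + k, vocab]
    # The prediction for draft[i] comes from the logits one position earlier.
    preds = logits[0, -k - 1:-1].argmax(dim=-1)                # [k]
    # Accept the longest run of draft tokens that the verifier agrees with.
    accepted = int((preds == draft).long().cumprod(dim=0).sum())
    # Keep the accepted draft tokens and take one extra token from the verifier.
    bonus = logits[0, prefix.numel() + accepted - 1].argmax().unsqueeze(0)
    return torch.cat([draft[:accepted], bonus])
```

So with `--speculate_k 5` the int8 verifier's linear layers see 6-row matmuls, which is exactly the dequant + mm shape Inductor doesn't handle well yet.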
@jamestwhedbee There are a couple of scripts in the scripts folder that should result in speedups. In particular, you should try `./scripts/speculate_tp_70B_bf16.sh`.
EDIT: There seems to be some kind of issue right now for TP + speculative decoding - I'll take a look.
@Chillee that unfortunately also just results in ~8 tokens/second
EDIT: just saw your edit
Hey @Chillee, were you able to learn more about the issue here?
@Chillee @jamestwhedbee I am seeing a similar issue on my side, and I work for AMD. When I use INT8 for both the target model and the draft model, speculative decoding gives slowdowns. Have you resolved this issue yet? Thanks, Yao Fehlis (AMD)
**Some context**

I am using AMD MI100 GPUs and I can get ~33 tokens/second for Llama 2 70B using:

**Issue**

This drops to ~7 tokens/second when I try to include speculative decoding with this command:

The only clue I have to go on here is this output during compilation (that I only see when using speculative decoding):

```
skipping cudagraphs due to ['mutated inputs']
```
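For what it's worth, that message is Inductor reporting that it is falling back from CUDA graphs because the compiled region mutates one of its input tensors in place (an in-place KV-cache update is the typical example), so every decode step pays full kernel-launch overhead. Here is a minimal, self-contained repro of the same class of warning; this is my own sketch under `mode="reduce-overhead"`, not the code path in this repo, and on some PyTorch versions the message only shows up with `TORCH_LOGS=perf_hints`:

```python
import torch

def decode_step(x: torch.Tensor, kv_cache: torch.Tensor, pos: int) -> torch.Tensor:
    # In-place write into a tensor that is an input to the compiled graph,
    # the same kind of mutation a KV-cache update performs.
    kv_cache[:, pos] = x
    return x @ x.transpose(-1, -2)

# mode="reduce-overhead" turns on cudagraphs; the input mutation above makes
# Inductor fall back and report "skipping cudagraphs due to ['mutated inputs']".
compiled = torch.compile(decode_step, mode="reduce-overhead")

x = torch.randn(4, 64, device="cuda")
kv_cache = torch.zeros(4, 128, 64, device="cuda")
out = compiled(x, kv_cache, 0)
```

Whether that fallback alone explains the drop to ~7 tokens/second I can't say, but it would at least add per-step launch overhead on top of the int8 dequant + mm issue mentioned above.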