huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Enable speculative decoding with batch size >1 #32165


kamilakesbi commented 2 months ago

Feature request

Speculative decoding isn't currently enabled for batch sizes >1. PR #26875 was previously opened to add this feature to main, but it was never merged. As that PR is quite old and has been closed, I'm opening this issue to motivate adding the feature to Transformers.

Two approaches can be implemented to enable speculative decoding with batch size >1:

  1. First approach: at each step, compare the token IDs proposed by the assistant model with those produced by the main model for every sequence in the batch, and roll all sequences back to the first mismatching token ID found anywhere in the batch.

This approach is simple but rather naive: valid tokens from the better-matching sequences are discarded and have to be regenerated during the next iterations of the assistant model.

  2. Second approach: at each step, dynamically roll each sequence back only to its own first mismatching token ID, re-aligning the batch with padding tokens and updated attention masks.

This way, all valid tokens are kept at each step and never need to be regenerated in later iterations of the assistant model.

The second approach is better, and PR #26875 already implements most of it, so IMO we should focus on that one. A rough sketch of both strategies is given below.
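To make the difference concrete, here is a minimal, self-contained sketch of the two rollback strategies, written against hypothetical tensors of candidate and target token IDs. None of the function names, shapes, or the padding convention come from #26875; they are assumptions for illustration only.

import torch

def n_matches_per_sequence(candidate_ids, target_ids):
    # candidate_ids: (batch, k) token IDs proposed by the assistant model
    # target_ids:    (batch, k + 1) token IDs chosen by the main model for the same
    #                positions, plus one extra "correction" token (assumed shapes)
    matches = (candidate_ids == target_ids[:, :-1]).int()
    # length of the agreeing prefix of each sequence
    return matches.cumprod(dim=-1).sum(dim=-1)

def accept_whole_batch(n_matches):
    # Approach 1: roll every sequence back to the shortest agreeing prefix in the
    # batch, so valid tokens of other sequences are discarded and regenerated later.
    return int(n_matches.min()) + 1  # +1 for the main model's correction token

def accept_per_sequence(input_ids, target_ids, n_matches, pad_token_id):
    # Approach 2: each sequence keeps its own agreeing prefix plus the correction
    # token; shorter sequences are padded and masked out via the attention mask.
    batch_size = input_ids.shape[0]
    max_new = int(n_matches.max()) + 1
    new_tokens = torch.full((batch_size, max_new), pad_token_id, dtype=input_ids.dtype)
    new_mask = torch.zeros((batch_size, max_new), dtype=torch.long)
    for i in range(batch_size):
        keep = int(n_matches[i]) + 1
        new_tokens[i, :keep] = target_ids[i, :keep]
        new_mask[i, :keep] = 1
    return torch.cat([input_ids, new_tokens], dim=-1), new_mask

For example, with four sequences whose agreeing prefixes have lengths [5, 2, 4, 1], approach 1 would keep only 1 + 1 = 2 tokens for every sequence, while approach 2 keeps 6, 3, 5 and 2 tokens respectively, padding the shorter rows up to 6.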

How to reproduce:

from transformers import pipeline, AutoModelForCausalLM, AutoModelForSpeechSeq2Seq, AutoProcessor
import torch
from datasets import load_dataset

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# assistant (draft) model that proposes candidate tokens
assistant_model_id = "distil-whisper/distil-large-v2"

assistant_model = AutoModelForCausalLM.from_pretrained(
    assistant_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
assistant_model.to(device)

# main (target) model that verifies the candidate tokens
model_id = "openai/whisper-large-v2"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)

processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    generate_kwargs={"assistant_model": assistant_model},
    torch_dtype=torch_dtype,
    chunk_length_s=15,
    batch_size=4,  # any batch size > 1 currently raises the error below
    device=device,
)

dataset = load_dataset("distil-whisper/librispeech_long", "default", split="validation")
sample = dataset[0]["audio"]

result = pipe(sample)
print(result["text"])

Current output:

ValueError: assisted generate is only supported for batch_size = 1

Expected output:

" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is Mr. Quilter's manner less interesting than his manner. He tells us that at this festive season of the year with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Leighton's work is really Greek after all, and can discover in it but little of rocky Ithaca. Linnell's pictures are a sort of Upguards and Adam paintings, and Mason's exquisite idylls are as national as a jingo poem. Mr. Burkett Foster's landscapes smile at one much in the same way that Mr. Carker used to flash his teeth, and Mr. John Collier gives his sitter a cheerful slap on the back, before he says, like a shampoo-er in a Turkish bath, Next man!"

Your contribution

I've started iterating on a solution and will open a PR soon to address it :)

cc @sanchit-gandhi @ylacombe @gante

zucchini-nlp commented 2 months ago

Cool, looking forward to the PR! The last time a contributor worked on this, there were latency issues caused by excessive padding (https://github.com/huggingface/transformers/issues/29769); hope your PR solves them :)
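(Not taken from #29769, just a toy illustration of that padding concern: with per-sequence rollback, every step appends up to the per-step maximum number of accepted tokens per row, so the slower sequences accumulate pad positions that still cost attention compute. All numbers below are made up.)

import torch

# hypothetical accepted-prefix lengths for a batch of 4 sequences over 3 steps
accepted = torch.tensor([
    [5, 2, 4, 1],  # step 1
    [3, 5, 1, 2],  # step 2
    [4, 4, 2, 5],  # step 3
])
# per-sequence rollback appends max(accepted) positions per step for every row,
# so anything short of the per-step maximum becomes padding
appended = accepted.max(dim=1, keepdim=True).values.expand_as(accepted)
pad_fraction = 1 - accepted.sum().item() / appended.sum().item()
print(f"{pad_fraction:.0%} of appended positions are padding")  # ~37% in this toy batch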