huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Speculative Decoding Snippet Not Working #29869

Closed: hieunguyenquoc closed this 1 month ago

hieunguyenquoc commented 3 months ago

System Info

transformers==4.39.1 python==3.8.17 torch==2.0.1+cpu

Who can help?

@sanchit-gandhi

Reproduction

from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
import torch
from operator import itemgetter
import time
print(torch.__version__)

class PhoWhisper_Finetune_Model:
    def __init__(self) -> None:
        # fp16 on GPU, fp32 on CPU
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        # Main model: local fine-tuned PhoWhisper-medium checkpoint
        self.MODEL_ID = "phowhisper_medium_finetuned"
        self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
            self.MODEL_ID, torch_dtype=self.torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
        )
        self.model.to(self.device)
        self.processor = AutoProcessor.from_pretrained("vinai/PhoWhisper-medium")

    def infer(self, audiopath: str) -> str:
        # Assistant (draft) model for speculative decoding; reloaded on every call
        assistant_model_id = "vinai/PhoWhisper-tiny"
        assistant_model = AutoModelForSpeechSeq2Seq.from_pretrained(
            assistant_model_id,
            torch_dtype=self.torch_dtype,
            low_cpu_mem_usage=True,
        )
        assistant_model.to(self.device)

        pipe = pipeline(
            "automatic-speech-recognition",
            model=self.model,
            tokenizer=self.processor.tokenizer,
            feature_extractor=self.processor.feature_extractor,
            chunk_length_s=20,
            batch_size=16,
            return_timestamps=True,
            torch_dtype=self.torch_dtype,
            device=self.device,
            # Passing assistant_model to generate() enables speculative decoding
            generate_kwargs={"task": "transcribe", "language": "vi", "assistant_model": assistant_model},
        )
        prediction = pipe(audiopath)
        # Join the text of every timestamped chunk into one transcript
        result_string = " ".join(map(itemgetter("text"), prediction["chunks"]))
        return result_string

if __name__ == "__main__":
    whisper = PhoWhisper_Finetune_Model()
    start = time.time()
    result = whisper.infer("-184354569133200865_104448_105080.wav")
    print(result)
    print("Time :", time.time() - start)

Expected behavior

I have tried speculative decoding with two versions of PhoWhisper (a fine-tuned version of Whisper), following this post: https://huggingface.co/blog/whisper-speculative-decoding. I get this error: ValueError: Whisper expects the mel input features to be of length 3000, but found 1500. Make sure to pad the input mel features to 3000. Could you help me? Thank you @sanchit-gandhi
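For context on the numbers in this error, the arithmetic below assumes the default WhisperFeatureExtractor settings (16 kHz audio, a hop length of 160 samples, 30-second windows) and the standard Whisper encoder front end (two kernel-3 convolutions with padding 1, the second with stride 2). The helper names are illustrative only and are not part of transformers.

```python
# Assumed defaults of the Whisper feature extractor
SAMPLING_RATE = 16_000  # Hz
HOP_LENGTH = 160        # samples between successive mel frames
CHUNK_LENGTH_S = 30     # seconds per padded input window


def mel_frames_per_window(sampling_rate=SAMPLING_RATE, hop_length=HOP_LENGTH,
                          chunk_length_s=CHUNK_LENGTH_S):
    """Number of mel frames the encoder expects for one padded window."""
    n_samples = chunk_length_s * sampling_rate  # 480000 samples for 30 s
    return n_samples // hop_length


def conv_out_len(length, kernel=3, stride=1, padding=1):
    """Output length of a 1-D convolution."""
    return (length + 2 * padding - kernel) // stride + 1


def encoder_positions(n_frames):
    """conv1 (stride 1) keeps the length; conv2 (stride 2) halves it."""
    return conv_out_len(conv_out_len(n_frames, stride=1), stride=2)


frames = mel_frames_per_window()
print(frames, encoder_positions(frames))  # 3000 1500
```

Under these assumptions, each window yields 3000 mel frames, and the encoder maps them to 1500 positions.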

amyeroberts commented 2 months ago

Gentle ping @sanchit-gandhi

jdvin commented 2 months ago

I believe the problem lies here:

https://github.com/huggingface/transformers/blob/73014b561d5f88d728e46a57d346f516fefe3f2d/src/transformers/generation/candidate_generator.py#L119-L129

The check if "encoder_outputs" in model_kwargs on line 128 should come before the check if assistant_model.config.is_encoder_decoder on line 121. Otherwise, the outputs of the main model's encoder are fed as inputs to the assistant model's encoder, when they should only be used as inputs to the assistant's decoder.
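To make the ordering argument concrete, here is a toy model of the two dispatch orders. It assumes, as the comment implies, that the two checks form an if/elif chain. The function names, the string stand-ins for tensors, and the fake encoder are all hypothetical; this is a simplified sketch, not the code in candidate_generator.py. With the current order, an encoder-decoder assistant never reaches the encoder_outputs branch, so it always re-runs its own encoder on inputs_tensor; with the checks swapped, encoder outputs already present in model_kwargs take precedence.

```python
from typing import Any, Callable, Dict

# Hypothetical stand-in for the assistant's encoder
Encoder = Callable[[Any], Any]


def assistant_kwargs_current(is_encoder_decoder: bool, model_kwargs: Dict[str, Any],
                             inputs_tensor: Any, assistant_encoder: Encoder) -> Dict[str, Any]:
    """Order described in the comment: is_encoder_decoder is checked first."""
    kwargs: Dict[str, Any] = {}
    if is_encoder_decoder:
        # Always re-encodes, so precomputed encoder outputs are never reused
        kwargs["encoder_outputs"] = assistant_encoder(inputs_tensor)
    elif "encoder_outputs" in model_kwargs:
        kwargs["encoder_outputs"] = model_kwargs["encoder_outputs"]
    return kwargs


def assistant_kwargs_proposed(is_encoder_decoder: bool, model_kwargs: Dict[str, Any],
                              inputs_tensor: Any, assistant_encoder: Encoder) -> Dict[str, Any]:
    """Proposed order: reuse encoder outputs when they are already available."""
    kwargs: Dict[str, Any] = {}
    if "encoder_outputs" in model_kwargs:
        kwargs["encoder_outputs"] = model_kwargs["encoder_outputs"]
    elif is_encoder_decoder:
        kwargs["encoder_outputs"] = assistant_encoder(inputs_tensor)
    return kwargs


calls = []


def fake_encoder(x):
    # Records each call so we can see when the assistant encoder runs
    calls.append(x)
    return f"assistant_encoder({x})"


main_kwargs = {"encoder_outputs": "main_encoder_outputs"}
print(assistant_kwargs_current(True, main_kwargs, "inputs", fake_encoder))
# {'encoder_outputs': 'assistant_encoder(inputs)'}
print(assistant_kwargs_proposed(True, main_kwargs, "inputs", fake_encoder))
# {'encoder_outputs': 'main_encoder_outputs'}
print(len(calls))  # 1
```

Reusing the main model's encoder outputs is only sensible when the assistant's decoder expects the same encoder hidden size that the main model produces; the toy does not check shapes.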

Happy to submit a PR for this.

sanchit-gandhi commented 2 months ago

Thanks for reporting, @hieunguyenquoc! A PR would be most welcome, @jdvin, if you have the bandwidth. Otherwise, cc @kamilakesbi: could you take a look?

kamilakesbi commented 1 month ago

This issue has been solved with PR #30637 :)