YuanGongND / whisper-at

Code and Pretrained Models for Interspeech 2023 Paper "Whisper-AT: Noise-Robust Automatic Speech Recognizers are Also Strong Audio Event Taggers"
BSD 2-Clause "Simplified" License

Exception using word_timestamps=True in model.transcribe #4

Open · tallzilla opened this issue 1 year ago

tallzilla commented 1 year ago

Hi there! I was hoping to use whisper's ability to provide timestamps around the audio events your work captures.

I'm currently getting an exception when I pass word_timestamps=True to model.transcribe().

Thanks!

import whisper_at as whisper

model = whisper.load_model("small")
# audio_path is the path to a local audio file
result = model.transcribe(audio_path, at_time_res=10, word_timestamps=True)

Traceback

/usr/local/lib/python3.10/dist-packages/whisper_at/transcribe.py in transcribe(model, audio, verbose, temperature, compression_ratio_threshold, logprob_threshold, no_speech_threshold, condition_on_previous_text, initial_prompt, word_timestamps, prepend_punctuations, append_punctuations, at_time_res, **decode_options)
    344 
    345             if word_timestamps:
--> 346                 add_word_timestamps(
    347                     segments=current_segments,
    348                     model=model,

/usr/local/lib/python3.10/dist-packages/whisper_at/timing.py in add_word_timestamps(segments, model, tokenizer, mel, num_frames, prepend_punctuations, append_punctuations, **kwargs)
    310 
    311     text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
--> 312     alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
    313     merge_punctuations(alignment, prepend_punctuations, append_punctuations)
    314 

/usr/local/lib/python3.10/dist-packages/whisper_at/timing.py in find_alignment(model, tokenizer, text_tokens, mel, num_frames, medfilt_width, qk_scale)
    193 
    194     with torch.no_grad():
--> 195         logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
    196         sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
    197         token_probs = sampled_logits.softmax(dim=-1)

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1499                 or _global_backward_pre_hooks or _global_backward_hooks
   1500                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501             return forward_call(*args, **kwargs)
   1502         # Do not call functions when jit is used
   1503         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.10/dist-packages/whisper_at/model.py in forward(self, mel, tokens)
    271         self, mel: torch.Tensor, tokens: torch.Tensor
    272     ) -> Dict[str, torch.Tensor]:
--> 273         return self.decoder(tokens, self.encoder(mel))
    274 
    275     @property

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1499                 or _global_backward_pre_hooks or _global_backward_hooks
   1500                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501             return forward_call(*args, **kwargs)
   1502         # Do not call functions when jit is used
   1503         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.10/dist-packages/whisper_at/model.py in forward(self, x, xa, kv_cache)
    210             + self.positional_embedding[offset : offset + x.shape[-1]]
    211         )
--> 212         x = x.to(xa.dtype)
    213 
    214         for block in self.blocks:

AttributeError: 'tuple' object has no attribute 'dtype'
YuanGongND commented 1 year ago

hi there,

This seems to be a bug and will be fixed in the next version.

Unfortunately, I am working on another deadline and will only have time to look at this after 8/10.

-Yuan

congtuong commented 1 year ago

Hi, I faced the same issue and fixed it.

Just change line 273 in `package/whisper-at/whisper_at/model.py` to: `return self.decoder(tokens, self.encoder(mel)[0])`
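
For context, this is roughly what the patched method looks like, reconstructed from the traceback above (the surrounding class body is elided, and the explanation in the comments is my reading of the error, not the author's):

```python
import torch
from typing import Dict

class Whisper(torch.nn.Module):  # excerpt; encoder and decoder are defined elsewhere in model.py
    def forward(
        self, mel: torch.Tensor, tokens: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        # whisper-at's encoder appears to return a tuple (the usual audio
        # features plus extra outputs used for audio tagging), so only the
        # first element should be handed to the text decoder. Passing the raw
        # tuple is what raises "'tuple' object has no attribute 'dtype'" when
        # the decoder does x.to(xa.dtype).
        return self.decoder(tokens, self.encoder(mel)[0])
```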

tallzilla commented 1 year ago

This change worked for me!


YuanGongND commented 1 year ago

thanks @congtuong and @tallzilla, I will check this and put it in the next version.

-Yuan

YuanGongND commented 1 year ago

thanks! this is fixed in whisper-at==0.5.

-Yuan
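
For anyone who lands here later: upgrading to whisper-at 0.5 or newer and re-running the original call should be enough. A minimal sketch, assuming a local audio file:

```python
# pip install --upgrade "whisper-at>=0.5"   # release that includes the fix
import whisper_at as whisper

model = whisper.load_model("small")
# at_time_res is the audio-tagging time resolution in seconds;
# word_timestamps=True should no longer raise the tuple/dtype AttributeError.
result = model.transcribe("audio.wav", at_time_res=10, word_timestamps=True)
print(result["text"])
```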