jianfch / stable-ts

Transcription, forced alignment, and audio indexing with OpenAI's Whisper
MIT License

load_hf_whisper fails on mps backend on M2 chip #307

Status: closed (Exr0n closed this issue 9 months ago)

Exr0n commented 9 months ago

When running the following example:

import stable_whisper
model = stable_whisper.load_hf_whisper('base')
result = model.transcribe('toby_cookies.m4a')
result.to_srt_vtt('audio.srt')

I get the following error:

Traceback (most recent call last):
  File "/Users/exr0n/projects/whispernouns/asr/phonemes/main.py", line 2, in <module>
    model = stable_whisper.load_hf_whisper('base')
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/lib/python3.11/site-packages/stable_whisper/whisper_word_level/hf_whisper.py", line 249, in load_hf_whisper
    return WhisperHF(model_name, device, flash=flash)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/lib/python3.11/site-packages/stable_whisper/whisper_word_level/hf_whisper.py", line 78, in __init__
    self._pipe = load_hf_pipe(self._model_name, device, flash=flash)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/lib/python3.11/site-packages/stable_whisper/whisper_word_level/hf_whisper.py", line 50, in load_hf_pipe
    ).to(device)
      ^^^^^^^^^^
  File "/opt/homebrew/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2595, in to
    return super().to(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1145, in to
    raise TypeError('nn.Module.to only accepts floating point or complex '
TypeError: nn.Module.to only accepts floating point or complex dtypes, but got desired dtype=torch.bool

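It looks like this is because the device selection in hf_whisper.py returns True instead of 'mps' on this machine. That boolean is then passed to .to(device), and PyTorch interprets it as a request for dtype=torch.bool, which is exactly what the last line of the traceback reports.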
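To illustrate the failure mode, here is a minimal, hypothetical sketch (not the actual stable-ts code) of how a device selector can leak an availability flag instead of a device string. The function names pick_device_buggy and pick_device_fixed are made up for this example, and the availability flags are passed in as plain booleans so the sketch runs without torch installed.

```python
def pick_device_buggy(cuda_available: bool, mps_available: bool):
    # Hypothetical bug: the MPS availability flag itself is returned,
    # so on an Apple-silicon Mac this yields True instead of 'mps'.
    return "cuda" if cuda_available else (mps_available or "cpu")


def pick_device_fixed(cuda_available: bool, mps_available: bool) -> str:
    # Always return a device *string* that nn.Module.to() understands.
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


# Apple-silicon Mac: no CUDA, MPS available.
print(pick_device_buggy(False, True))   # True
print(pick_device_fixed(False, True))   # mps

# Without MPS, the buggy version happens to work, hiding the bug elsewhere.
print(pick_device_buggy(False, False))  # cpu
```

This would explain why the error only shows up on MPS machines: on CUDA or CPU-only hosts, the buggy expression still evaluates to a valid string.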

McCloudS commented 9 months ago

For anyone else hitting this: in the meantime, you can work around it by passing the device explicitly, e.g. model = stable_whisper.load_hf_whisper('base', device='mps')
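If the same script must run on several machines, one option is to resolve the device string yourself before calling load_hf_whisper. The sketch below assumes a PyTorch version that provides torch.backends.mps (1.12 or later); resolve_device is a hypothetical helper name, and it falls back to 'cpu' if torch cannot be imported.

```python
def resolve_device() -> str:
    """Return 'cuda', 'mps', or 'cpu' as a string (hypothetical helper)."""
    try:
        import torch
    except ImportError:
        # torch not installed: CPU is the only safe answer.
        return "cpu"
    if torch.cuda.is_available():
        return "cuda"
    # torch.backends.mps only exists in newer PyTorch builds, so guard it.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"


print(resolve_device())
```

The result would then be passed explicitly, as in the workaround above: stable_whisper.load_hf_whisper('base', device=resolve_device()).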

jianfch commented 9 months ago

@Exr0n Thanks for pointing this out. It should be fixed in 53272cb0376b2c1e32f0055a5505045681409aac.

Exr0n commented 9 months ago

Thanks for the quick response!