aiola-lab / whisper-medusa

Whisper with Medusa heads
MIT License
800 stars, 49 forks

The sample code in the README is not working #3

Closed. adi closed this issue 3 months ago.

adi commented 3 months ago
from whisper_medusa import WhisperMedusa
model = WhisperMedusa.from_pretrained("aiola/whisper-medusa-v1")
model_output = model.generate(
    input_features,
    language=language,
)
predict_ids = model_output[0]

gives

---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
Cell In[1], line 1
----> 1 from whisper_medusa import WhisperMedusa
      2 model = WhisperMedusa.from_pretrained("aiola/whisper-medusa-v1")
      3 model_output = model.generate(
      4     input_features,
      5     language=language,
      6 )

ImportError: cannot import name 'WhisperMedusa' from 'whisper_medusa'

Checking the whisper-medusa/whisper_medusa/__init__.py file, I can see it's empty.

AvivNavon commented 3 months ago

Hi @adi, thanks for spotting this. PR #4 fixes the issue and provides a full generation example in the README.

AvivNavon commented 3 months ago

Merged to main. Please try again with the updated example code.

adi commented 3 months ago

I think it should be

from whisper_medusa.models import WhisperMedusaModel

instead of

from whisper_medusa import WhisperMedusaModel

AvivNavon commented 3 months ago

It should be fine now: https://github.com/aiola-lab/whisper-medusa/blob/main/whisper_medusa/__init__.py#L1 Let me know if it's not working for you.

adi commented 3 months ago

Yes. Now everything works fine. And very very fast. Congrats guys!
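
For readers hitting the same ImportError elsewhere, here is a minimal sketch of the mechanism discussed in this thread: a name defined in a submodule can only be imported from the package top level if the package's __init__.py re-exports it. The package `demo_pkg`, the class `DemoModel`, and the helper `top_level_import_works` are hypothetical stand-ins made up for this illustration. They are not the actual contents of whisper_medusa/__init__.py.

```python
import importlib
import sys
import tempfile
from pathlib import Path


def _forget_demo_pkg() -> None:
    """Drop cached copies of the throwaway package from sys.modules."""
    for name in [m for m in sys.modules if m == "demo_pkg" or m.startswith("demo_pkg.")]:
        del sys.modules[name]


def top_level_import_works(init_source: str) -> bool:
    """Build a throwaway package whose __init__.py contains `init_source`
    and report whether `from demo_pkg import DemoModel` succeeds."""
    with tempfile.TemporaryDirectory() as tmp:
        pkg = Path(tmp) / "demo_pkg"
        pkg.mkdir()
        # The class lives in a submodule, as in the layout described in the issue.
        (pkg / "models.py").write_text("class DemoModel:\n    pass\n")
        (pkg / "__init__.py").write_text(init_source)

        sys.path.insert(0, tmp)
        _forget_demo_pkg()
        importlib.invalidate_caches()
        try:
            from demo_pkg import DemoModel  # noqa: F401
            return True
        except ImportError:
            return False
        finally:
            sys.path.remove(tmp)
            _forget_demo_pkg()


if __name__ == "__main__":
    # Empty __init__.py: the top-level import fails.
    print(top_level_import_works(""))
    # __init__.py re-exports the class: the top-level import succeeds.
    print(top_level_import_works("from .models import DemoModel\n"))
```

With an empty __init__.py, the submodule path (here `from demo_pkg.models import DemoModel`) is the only import that works, which matches the workaround suggested earlier in the thread.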