elixir-nx / bumblebee

Pre-trained Neural Network models in Axon (+ 🤗 Models integration)
Apache License 2.0

Cannot use `whisper-*.en` models in bumblebee #312

Closed: John-Goff closed this issue 6 months ago

John-Goff commented 6 months ago

With any of whisper-medium.en, whisper-small.en, or whisper-tiny.en, the following code does not work:

Mix.install([
  {:bumblebee, "~> 0.4.2"},
  {:nx, "~> 0.6.4"},
  {:exla, "~> 0.6.4"},
  {:axon, "~> 0.6.0"}
])

{[{:file, path}], _, _} =
  OptionParser.parse(System.argv(), strict: [file: :string], aliases: [f: :file])

{:ok, model_info} = Bumblebee.load_model({:hf, "openai/whisper-tiny.en"})
{:ok, featurizer} = Bumblebee.load_featurizer({:hf, "openai/whisper-tiny.en"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai/whisper-tiny.en"})
{:ok, generation_config} = Bumblebee.load_generation_config({:hf, "openai/whisper-tiny.en"})
generation_config = Bumblebee.configure(generation_config, max_new_tokens: 100)

serving =
  Bumblebee.Audio.speech_to_text_whisper(model_info, featurizer, tokenizer, generation_config,
    chunk_num_seconds: 30,
    timestamps: :segments,
    stream: true,
    compile: [batch_size: 4],
    defn_options: [compiler: EXLA]
  )

Nx.Serving.run(serving, {:file, Path.expand(path)})
|> Enum.reduce("", fn chunk, acc -> acc <> chunk.text end)
|> IO.puts()

It fails with the following error:

** (RuntimeError) invalid task :transcribe, expected one of: 
    (bumblebee 0.4.2) lib/bumblebee/audio/speech_to_text_whisper.ex:210: Bumblebee.Audio.SpeechToTextWhisper.forced_token_ids/2
    (bumblebee 0.4.2) lib/bumblebee/audio/speech_to_text_whisper.ex:156: Bumblebee.Audio.SpeechToTextWhisper.generate_opts/2
    (bumblebee 0.4.2) lib/bumblebee/audio/speech_to_text_whisper.ex:50: Bumblebee.Audio.SpeechToTextWhisper.speech_to_text_whisper/5
    transcribe.exs:18: (file)
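One way to see what the error is complaining about is to compare the task tokens in the generation config of an English-only checkpoint and a multilingual one. This is a sketch: it reads `generation_config.extra_config.task_to_token_id`, the field named in the newer error message quoted further down this thread, and assumes that field already exists on the 0.4.2 config struct. Given the empty "expected one of:" list above, the `.en` map is expected to be empty.

```elixir
Mix.install([
  {:bumblebee, "~> 0.4.2"}
])

# Only the generation config is needed, so no model weights are downloaded.
task_maps =
  Map.new(["openai/whisper-tiny.en", "openai/whisper-tiny"], fn repo ->
    {:ok, generation_config} = Bumblebee.load_generation_config({:hf, repo})
    # Assumption: field name as given in the error message on Bumblebee main.
    {repo, generation_config.extra_config.task_to_token_id}
  end)

# An empty map for the .en checkpoint leaves nothing for the default
# task (:transcribe) to look up, which is what raises the RuntimeError.
IO.inspect(task_maps)
```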

Switching to the multilingual whisper-medium, whisper-small, or whisper-tiny models works correctly.

John-Goff commented 6 months ago

distil-whisper/distil-small.en and related models also fail with the same error.

jonatanklosko commented 6 months ago

For these models you need to set `task: nil` explicitly. Bumblebee main has an improved error message; it just hasn't been released yet :)

the generation config does not have any tasks defined. If you are dealing with a monolingual model, set :task to nil. Otherwise you may need to update generation_config.extra_config.task_to_token_id
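Applied to the original script, the fix is one extra option on `Bumblebee.Audio.speech_to_text_whisper/5`. The sketch below also picks the task from the loaded config instead of hard-coding it, so the same script works for English-only and multilingual checkpoints. `WhisperTask.pick/1` is a hypothetical helper, not part of Bumblebee, and it assumes (as in the check above) that an English-only config has an empty `task_to_token_id` map.

```elixir
Mix.install([
  {:bumblebee, "~> 0.4.2"},
  {:nx, "~> 0.6.4"},
  {:exla, "~> 0.6.4"},
  {:axon, "~> 0.6.0"}
])

defmodule WhisperTask do
  # Hypothetical helper: no task tokens in the config means a monolingual
  # model, so pass task: nil; otherwise keep Bumblebee's default :transcribe.
  def pick(task_to_token_id) when is_map(task_to_token_id) and map_size(task_to_token_id) == 0,
    do: nil

  def pick(task_to_token_id) when is_map(task_to_token_id), do: :transcribe
end

repo = {:hf, "openai/whisper-tiny.en"}

{:ok, model_info} = Bumblebee.load_model(repo)
{:ok, featurizer} = Bumblebee.load_featurizer(repo)
{:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
{:ok, generation_config} = Bumblebee.load_generation_config(repo)
generation_config = Bumblebee.configure(generation_config, max_new_tokens: 100)

# Assumption: field name taken from the newer error message.
task = WhisperTask.pick(generation_config.extra_config.task_to_token_id)

serving =
  Bumblebee.Audio.speech_to_text_whisper(model_info, featurizer, tokenizer, generation_config,
    # nil for .en checkpoints, :transcribe for multilingual ones
    task: task,
    chunk_num_seconds: 30,
    timestamps: :segments,
    stream: true,
    compile: [batch_size: 4],
    defn_options: [compiler: EXLA]
  )
```

Running it is then the same as in the original report, e.g. `Nx.Serving.run(serving, {:file, Path.expand(path)})`. Deciding from the config rather than from a `.en` suffix in the repo name should also cover checkpoints such as distil-whisper/distil-small.en, which fail with the same empty task list.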

grzuy commented 6 months ago

For the record, this was discussed previously in https://github.com/elixir-nx/bumblebee/issues/267.