elixir-nx / bumblebee

Pre-trained Neural Network models in Axon (+ šŸ¤— Models integration)

Running Whisper using bf16 fails #345

Closed costaraphael closed 4 months ago

costaraphael commented 4 months ago

Hey folks!

I'm working on optimizing a deployment of whisper-large-v3 by moving to float16/bfloat16 instead of the default float32. The problem is that the cross-attention cache layers are initialized with a hardcoded float32 type, which makes compilation fail when the serving starts.

Here's an example in a Livebook using whisper-tiny (same behavior):

Mix.install([
  {:bumblebee, github: "elixir-nx/bumblebee"},
  {:nx, github: "elixir-nx/nx", sparse: "nx", override: true},
  {:exla, github: "elixir-nx/nx", sparse: "exla", override: true},
  {:kino, "~> 0.12.3"}
])

hf_model = "openai/whisper-tiny"

{:ok, model} = Bumblebee.load_model({:hf, hf_model}, type: :bf16) # loading with a plain bf16 type triggers the failure below
{:ok, featurizer} = Bumblebee.load_featurizer({:hf, hf_model})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, hf_model})
{:ok, generation_config} = Bumblebee.load_generation_config({:hf, hf_model})
generation_config = Bumblebee.configure(generation_config, max_new_tokens: 128)

serving =
  Bumblebee.Audio.speech_to_text_whisper(model, featurizer, tokenizer, generation_config,
    chunk_num_seconds: 30,
    compile: [batch_size: 16],
    defn_options: [compiler: EXLA],
    preallocate_params: true
  )

Nx.Serving.run(serving, {:file, Kino.FS.file_path("sample_audio.wav")})
Error stacktrace (`#cell:luahaswjdkkgzubs:1:` is the `Nx.Serving.run` call):

```
** (CompileError) Library/Caches/mix/installs/elixir-1.16.1-erts-14.2/8f5b0d8afb97a8f884e160e25fa6d1f9/deps/bumblebee/lib/bumblebee/text/generation.ex:479:
    the do-block in while must return tensors with the same shape, type, and names as the initial arguments.

{%{length: #Nx.Tensor< s64 >, ignored: #Nx.Tensor< s64[16] >, input_length: #Nx.Tensor< s64 >,
   sequences: #Nx.Tensor< s64[16][129] >, finished_length: #Nx.Tensor< s64[16] >},
 %{"cache" => %{
     offset: #Nx.Tensor< s64 >,
     blocks:
       {%{self_attention: %{
            value:
              <<<<< Body (do-block) <<<<<
              #Nx.Tensor< bf16[16][129][6][64] >
              ==========
              #Nx.Tensor< f32[16][129][6][64] >
              >>>>> Initial >>>>>,
            key:
              <<<<< Body (do-block) <<<<<
              #Nx.Tensor< bf16[16][129][6][64] >
              ==========
              #Nx.Tensor< f32[16][129][6][64] >
              >>>>> Initial >>>>>},
          cross_attention: %{
            value:
              <<<<< Body (do-block) <<<<<
              #Nx.Tensor< bf16[16][1500][6][64] >
              ==========
              #Nx.Tensor< f32[16][1500][6][64] >
              >>>>> Initial >>>>>,
            key:
              <<<<< Body (do-block) <<<<<
              #Nx.Tensor< bf16[16][1500][6][64] >
              ==========
              #Nx.Tensor< f32[16][1500][6][64] >
              >>>>> Initial >>>>>}},
        ...the same bf16 (body) vs f32 (initial) mismatch repeated for the self_attention and
        cross_attention key/value caches of the remaining three decoder blocks...},
     attention_mask: #Nx.Tensor< s64[16][129] >},
   "decoder_attention_mask" => #Nx.Tensor< s64[16][1] >,
   "decoder_input_ids" => #Nx.Tensor< s64[16][1] >,
   "decoder_position_ids" => #Nx.Tensor< s64[16][1] >,
   "encoder_hidden_state" => #Nx.Tensor< f32[batch: 16][frames: 1500][mel: 384] >,
   "input_features" => #Nx.Tensor< f32[batch: 16][frames: 3000][mel: 80] >},
 %{"decoder.blocks.0.ffn.output" => %{"bias" => #Nx.Tensor< bf16[384] >, "kernel" => #Nx.Tensor< bf16[1536][384] >},
   ...the remaining bf16 model parameters elided...}}

    (nx 0.7.0) lib/nx/defn/expr.ex:759: Nx.Defn.Expr.compatible_while!/4
    (nx 0.7.0) lib/nx/defn/expr.ex:520: Nx.Defn.Expr.while_vectorized/7
    (bumblebee 0.4.2) lib/bumblebee/text/generation.ex:479: Bumblebee.Text.Generation."__defn:greedy__"/7
    (bumblebee 0.4.2) lib/bumblebee/text/generation.ex:390: Bumblebee.Text.Generation."__defn:generate_impl__"/8
    (bumblebee 0.4.2) lib/bumblebee/audio/speech_to_text_whisper.ex:55: anonymous fn/4 in Bumblebee.Audio.SpeechToTextWhisper.speech_to_text_whisper/5
    (nx 0.7.0) lib/nx/defn/compiler.ex:173: Nx.Defn.Compiler.runtime_fun/3
    (exla 0.7.0) lib/exla/defn.ex:551: anonymous fn/4 in EXLA.Defn.compile/8
    #cell:luahaswjdkkgzubs:1: (file)
```
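For context, the failure comes from Nx's requirement that a `while` loop's body returns tensors with exactly the same shape and type as its initial state. Here is a contrived sketch (plain Nx.Defn, not Bumblebee's actual code) that raises the same class of error, with an accumulator that starts as f32 while the body produces bf16:

defmodule WhileTypeMismatch do
  import Nx.Defn

  defn run(x) do
    # The accumulator starts out as f32 (Nx.broadcast/2 defaults to {:f, 32}),
    # mirroring the hardcoded cache initialization...
    while acc = Nx.broadcast(0.0, {2}), Nx.sum(acc) < 10 do
      # ...but the body returns bf16, so tracing fails with the same
      # "while must return tensors with the same shape, type, and names" error.
      Nx.as_type(acc + x, :bf16)
    end
  end
end

WhileTypeMismatch.run(Nx.tensor([1.0, 1.0]))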

I can get it working if I do

policy =
  Axon.MixedPrecision.create_policy(
    params: {:bf, 16},
    compute: {:bf, 16},
    output: {:f, 32} # f32 outputs keep the cached keys/values in f32, matching the hardcoded cache init
  )

{:ok, model} = Bumblebee.load_model({:hf, hf_model}, type: policy)

which brought the transcription time for 48 minutes of audio on a GPU down from 240s to 200s.
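For reference, timings like the above can be taken by simply wrapping the serving call in `:timer.tc/1` (using the same sample file as above):

{time_us, _result} =
  :timer.tc(fn ->
    Nx.Serving.run(serving, {:file, Kino.FS.file_path("sample_audio.wav")})
  end)

IO.puts("Transcription took #{div(time_us, 1_000_000)}s")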

But with a couple of tweaks to Bumblebee's codebase (essentially initializing the decoder cache with the configured type instead of the hardcoded float32), I was able to set the output type to bfloat16 as well, bringing the execution time down to 170s. Memory consumption during transcription also fell considerably, although it's hard to measure the gains there accurately because of XLA's memory preallocation.

I would have opened a PR with this, but the problem is that there's currently no easy way to get the type given to load_model from within the init_cache/4 callback šŸ˜…
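As a rough illustration of the kind of change involved (the function below is a hypothetical sketch, not Bumblebee's actual internals), the cache tensors would need to be created with an explicit type, ideally the one given to load_model, rather than Nx.broadcast/2's implicit {:f, 32} default:

defmodule CacheInitSketch do
  # Hypothetical sketch: build the attention cache with an explicit type
  # instead of relying on the default f32.
  def init_attention_cache(batch_size, max_length, num_heads, head_size, opts \\ []) do
    type = Keyword.get(opts, :type, {:f, 32})
    shape = {batch_size, max_length, num_heads, head_size}
    zero = Nx.tensor(0.0, type: type)

    %{
      key: Nx.broadcast(zero, shape),
      value: Nx.broadcast(zero, shape)
    }
  end
end

# With type: {:bf, 16} the cache matches what the bf16 model writes into it
# inside the generation while-loop, so the shapes and types line up.
CacheInitSketch.init_attention_cache(16, 129, 6, 64, type: {:bf, 16})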

So I opened this issue to discuss what a proper fix could look like. I'm more than happy to open a PR for it once we settle on an approach šŸ˜ƒ

jonatanklosko commented 4 months ago

Hey @costaraphael, good catch! Fixed on main and a release coming very soon :)