Hey folks!
I'm working on optimizing a deployment of `whisper-large-v3` by moving to `float16`/`bfloat16` instead of the default `float32`. The problem is that the cross-attention cache layers are initialized with a hardcoded `float32`, causing compilation to fail when starting the serving.

Here's an example in a Livebook using `whisper-tiny` (same behavior):
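The gist of it (a condensed sketch of the Livebook cells, with a placeholder audio file):

```elixir
repo = {:hf, "openai/whisper-tiny"}

# Load the model with everything in bf16 (params, compute and output)
{:ok, model_info} = Bumblebee.load_model(repo, type: :bf16)
{:ok, featurizer} = Bumblebee.load_featurizer(repo)
{:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
{:ok, generation_config} = Bumblebee.load_generation_config(repo)

serving =
  Bumblebee.Audio.speech_to_text_whisper(
    model_info,
    featurizer,
    tokenizer,
    generation_config,
    defn_options: [compiler: EXLA]
  )

# Compilation blows up here, because the f32-initialized cache no longer
# matches the bf16 tensors produced by the generation loop:
Nx.Serving.run(serving, {:file, "/path/to/audio.wav"})
```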
Error stacktrace:

```
** (CompileError) Library/Caches/mix/installs/elixir-1.16.1-erts-14.2/8f5b0d8afb97a8f884e160e25fa6d1f9/deps/bumblebee/lib/bumblebee/text/generation.ex:479: the do-block in while must return tensors with the same shape, type, and names as the initial arguments.

{%{length: #Nx.Tensor< s64 >, ignored: #Nx.Tensor< s64[16] >, input_length: #Nx.Tensor< s64 >, sequences: #Nx.Tensor< s64[16][129] >, finished_length: #Nx.Tensor< s64[16] >},
 %{"cache" => %{
     offset: #Nx.Tensor< s64 >,
     blocks: {
       %{self_attention: %{
           value: <<<<< Body (do-block) <<<<< #Nx.Tensor< bf16[16][129][6][64] > ========== #Nx.Tensor< f32[16][129][6][64] > >>>>> Initial >>>>>,
           key: <<<<< Body (do-block) <<<<< #Nx.Tensor< bf16[16][129][6][64] > ========== #Nx.Tensor< f32[16][129][6][64] > >>>>> Initial >>>>>},
         cross_attention: %{
           value: <<<<< Body (do-block) <<<<< #Nx.Tensor< bf16[16][1500][6][64] > ========== #Nx.Tensor< f32[16][1500][6][64] > >>>>> Initial >>>>>,
           key: <<<<< Body (do-block) <<<<< #Nx.Tensor< bf16[16][1500][6][64] > ========== #Nx.Tensor< f32[16][1500][6][64] > >>>>> Initial >>>>>}},
       %{self_attention: %{
           value: <<<<< Body (do-block) <<<<< #Nx.Tensor< bf16[16][129][6][64] > ========== #Nx.Tensor< f32[16][129][6][64] > >>>>> Initial >>>>>,
           key: <<<<< Body (do-block) <<<<< #Nx.Tensor< bf16[16][129][6][64] > ========== #Nx.Tensor< f32[16][129][6][64] > >>>>> Initial >>>>>},
         cross_attention: %{
           value: <<<<< Body (do-block) <<<<< #Nx.Tensor< bf16[16][1500][6][64] > ========== #Nx.Tensor< f32[16][1500][6][64] > >>>>> Initial >>>>>,
           key: <<<<< Body (do-block) <<<<< #Nx.Tensor< bf16[16][1500][6][64] > ========== #Nx.Tensor< f32[16][1500][6][64] > >>>>> Initial >>>>>}},
       %{self_attention: %{
           value: <<<<< Body (do-block) <<<<< #Nx.Tensor< bf16[16][129][6][64] > ========== #Nx.Tensor< f32[16][129][6][64] > >>>>> Initial >>>>>,
           key: <<<<< Body (do-block) <<<<< #Nx.Tensor< bf16[16][129][6][64] > ========== #Nx.Tensor< f32[16][129][6][64] > >>>>> Initial >>>>>},
         cross_attention: %{
           value: <<<<< Body (do-block) <<<<< #Nx.Tensor< bf16[16][1500][6][64] > ========== #Nx.Tensor< f32[16][1500][6][64] > >>>>> Initial >>>>>,
           key: <<<<< Body (do-block) <<<<< #Nx.Tensor< bf16[16][1500][6][64] > ========== #Nx.Tensor< f32[16][1500][6][64] > >>>>> Initial >>>>>}},
       %{self_attention: %{
           value: <<<<< Body (do-block) <<<<< #Nx.Tensor< bf16[16][129][6][64] > ========== #Nx.Tensor< f32[16][129][6][64] > >>>>> Initial >>>>>,
           key: <<<<< Body (do-block) <<<<< #Nx.Tensor< bf16[16][129][6][64] > ========== #Nx.Tensor< f32[16][129][6][64] > >>>>> Initial >>>>>},
         cross_attention: %{
           value: <<<<< Body (do-block) <<<<< #Nx.Tensor< bf16[16][1500][6][64] > ========== #Nx.Tensor< f32[16][1500][6][64] > >>>>> Initial >>>>>,
           key: <<<<< Body (do-block) <<<<< #Nx.Tensor< bf16[16][1500][6][64] > ========== #Nx.Tensor< f32[16][1500][6][64] > >>>>> Initial >>>>>}}
     },
     attention_mask: #Nx.Tensor< s64[16][129] >},
   "decoder_attention_mask" => #Nx.Tensor< s64[16][1] >,
   "decoder_input_ids" => #Nx.Tensor< s64[16][1] >,
   "decoder_position_ids" => #Nx.Tensor< s64[16][1] >,
   "encoder_hidden_state" => #Nx.Tensor< f32[batch: 16][frames: 1500][mel: 384] >,
   "input_features" => #Nx.Tensor< f32[batch: 16][frames: 3000][mel: 80] >},
 %{"decoder.blocks.0.ffn.output" => %{"bias" => #Nx.Tensor< bf16[384] >, "kernel" => #Nx.Tensor< bf16[1536][384] >},
   "encoder.blocks.0.self_attention_norm" => %{"beta" => #Nx.Tensor< bf16[384] >, "gamma" => #Nx.Tensor< bf16[384] >},
   "decoder.blocks.2.self_attention.output" => %{"bias" => #Nx.Tensor< bf16[384] >, "kernel" => #Nx.Tensor< bf16[384][384] >},
   "decoder.blocks.3.self_attention.value" => %{"bias" => #Nx.Tensor< bf16[384] >, "kernel" => #Nx.Tensor< bf16[384][384] >},
   "decoder.blocks.3.self_attention.output" => %{"bias" => #Nx.Tensor< bf16[384] >, "kernel" => #Nx.Tensor< bf16[384][384] >},
   "decoder_embedder.position_embedding" => %{"embeddings" => #Nx.Tensor< bf16[448][384] >},
   "decoder.blocks.2.self_attention.value" => %{"bias" => #Nx.Tensor< bf16[384] >, "kernel" => #Nx.Tensor< bf16[384][384] >},
   "encoder.blocks.1.self_attention.output" => %{"bias" => #Nx.Tensor< bf16[384] >, "kernel" => #Nx.Tensor< bf16[384][384] >},
   "decoder.blocks.2.ffn.output" => %{"bias" => #Nx.Tensor< bf16[384] >, "kernel" => #Nx.Tensor< bf16[1536][384] >},
   "decoder.blocks.3.self_attention.key" => %{"kernel" => #Nx.Tensor< bf16[384][384] >},
   "encoder.blocks.1.output_norm" => %{"beta" => #Nx.Tensor< bf16[384] >, "gamma" => #Nx.Tensor< bf16[384] >},
   "encoder.blocks.1.self_attention_norm" => %{"beta" => #Nx.Tensor< bf16[384] >, "gamma" => #Nx.Tensor< bf16[384] >},
   "encoder.blocks.1.ffn.output" => %{"bias" => #Nx.Tensor< bf16[384] >, "kernel" => #Nx.Tensor< bf16[1536][384] >},
   "encoder.blocks.2.self_attention.key" => %{"kernel" => #Nx.Tensor< bf16[384][384] >},
   "encoder.blocks.3.ffn.intermediate" => %{"bias" => #Nx.Tensor< bf16[1536] >, "kernel" => #Nx.Tensor< bf16[384][1536] >},
   "decoder.blocks.3.output_norm" => %{"beta" => #Nx.Tensor< bf16[384] >, "gamma" => #Nx.Tensor< bf16[384] >},
   "decoder.blocks.3.cross_attention.query" => %{"bias" => #Nx.Tensor< bf16[384] >, "kernel" => #Nx.Tensor< bf16[384][384] >},
   "encoder.norm" => %{"beta" => #Nx.Tensor< bf16[384] >, "gamma" => #Nx.Tensor< bf16[384] >},
   "decoder.blocks.0.self_attention.output" => %{"bias" => #Nx.Tensor< bf16[384] >, "kernel" => #Nx.Tensor< bf16[384][384] >},
   "decoder.blocks.2.cross_attention.output" => %{"bias" => #Nx.Tensor< bf16[384] >, "kernel" => #Nx.Tensor< bf16[384][384] >},
   "encoder.blocks.3.ffn.output" => %{"bias" => #Nx.Tensor< bf16[384] >, "kernel" => #Nx.Tensor< bf16[1536][384] >},
   "decoder.blocks.3.ffn.output" => %{"bias" => #Nx.Tensor< bf16[384] >, "kernel" => #Nx.Tensor< bf16[1536][384] >},
   "encoder_embedder.feature_embedding.conv_1" => %{"bias" => #Nx.Tensor< bf16[384] >, "kernel" => #Nx.Tensor< bf16[3][80][384] >},
   "encoder.blocks.0.self_attention.query" => %{"bias" => #Nx.Tensor< bf16[384] >, "kernel" => #Nx.Tensor< bf16[384][384] >},
   "encoder.blocks.3.self_attention.query" => %{"bias" => #Nx.Tensor< bf16[384] >, "kernel" => #Nx.Tensor< bf16[384][384] >},
   "decoder.blocks.2.cross_attention.query" => %{"bias" => #Nx.Tensor< bf16[384] >, "kernel" => #Nx.Tensor< bf16[384][384] >},
   "encoder.blocks.2.output_norm" => %{"beta" => #Nx.Tensor< bf16[384] >, "gamma" => #Nx.Tensor< bf16[384] >},
   "encoder.blocks.0.ffn.intermediate" => %{"bias" => #Nx.Tensor< bf16[1536] >, "kernel" => #Nx.Tensor< bf16[384][1536] >},
   "decoder.blocks.2.cross_attention_norm" => %{"beta" => #Nx.Tensor< bf16[384] >, "gamma" => #Nx.Tensor< bf16[384] >},
   "decoder.blocks.1.self_attention.value" => %{"bias" => #Nx.Tensor< bf16[384] >, "kernel" => #Nx.Tensor< bf16[384][384] >},
   "decoder.blocks.0.output_norm" => %{"beta" => #Nx.Tensor< bf16[384] >, "gamma" => #Nx.Tensor< bf16[384] >},
   "encoder.blocks.3.self_attention.value" => %{"bias" => #Nx.Tensor< bf16[384] >, "kernel" => #Nx.Tensor< bf16[384][384] >},
   "encoder.blocks.2.ffn.intermediate" => %{"bias" => #Nx.Tensor< bf16[1536] >, "kernel" => #Nx.Tensor< bf16[384][1536] >},
   "decoder.blocks.2.self_attention.key" => %{"kernel" => #Nx.Tensor< bf16[384][384] >},
   "decoder.blocks.2.output_norm" => %{"beta" => #Nx.Tensor< bf16[384] >, "gamma" => #Nx.Tensor< bf16[384] >},
   "decoder.blocks.0.cross_attention.value" => %{"bias" => #Nx.Tensor< bf16[384] >, "kernel" => #Nx.Tensor< bf16[384][384] >},
   "decoder.blocks.0.cross_attention.key" => %{"kernel" => #Nx.Tensor< bf16[384][384] >},
   "decoder.blocks.1.cross_attention.key" => %{"kernel" => #Nx.Tensor< bf16[384][384] >},
   "encoder.blocks.0.self_attention.output" => %{"bias" => #Nx.Tensor< bf16[384] >, "kernel" => #Nx.Tensor< bf16[384][384] >},
   "encoder.blocks.0.ffn.output" => %{"bias" => #Nx.Tensor< bf16[384] >, "kernel" => #Nx.Tensor< bf16[1536][384] >},
   "encoder.blocks.2.self_attention_norm" => %{"beta" => #Nx.Tensor< bf16[384] >, "gamma" => #Nx.Tensor< bf16[384] >},
   "encoder.blocks.2.self_attention.output" => %{"bias" => #Nx.Tensor< bf16[384] >, "kernel" => #Nx.Tensor< bf16[384][384] >},
   "decoder.blocks.2.cross_attention.value" => %{"bias" => #Nx.Tensor< bf16[384] >, "kernel" => #Nx.Tensor< bf16[384][384] >},
   "decoder.blocks.0.self_attention.value" => %{"bias" => #Nx.Tensor< bf16[384] >, "kernel" => #Nx.Tensor< bf16[384][384] >},
   "language_modeling_head.output" => %{"kernel" => #Nx.Tensor< bf16[51865][384] >},
   "decoder.blocks.2.self_attention_norm" => %{"beta" => #Nx.Tensor< bf16[384] >, ...},
   "decoder.blocks.1.ffn.output" => %{...},
   ...}}

    (nx 0.7.0) lib/nx/defn/expr.ex:759: Nx.Defn.Expr.compatible_while!/4
    (nx 0.7.0) lib/nx/defn/expr.ex:520: Nx.Defn.Expr.while_vectorized/7
    (bumblebee 0.4.2) lib/bumblebee/text/generation.ex:479: Bumblebee.Text.Generation."__defn:greedy__"/7
    (bumblebee 0.4.2) lib/bumblebee/text/generation.ex:390: Bumblebee.Text.Generation."__defn:generate_impl__"/8
    (bumblebee 0.4.2) lib/bumblebee/audio/speech_to_text_whisper.ex:55: anonymous fn/4 in Bumblebee.Audio.SpeechToTextWhisper.speech_to_text_whisper/5
    (nx 0.7.0) lib/nx/defn/compiler.ex:173: Nx.Defn.Compiler.runtime_fun/3
    (exla 0.7.0) lib/exla/defn.ex:551: anonymous fn/4 in EXLA.Defn.compile/8
    #cell:luahaswjdkkgzubs:1: (file)
```

(`#cell:luahaswjdkkgzubs:1:` is the `Nx.Serving.run` call.)

I can get it working if I do something along these lines:
That got the transcription time for 48 minutes of audio on a GPU from 240s down to 200s.
But then I made a couple of tweaks to Bumblebee's codebase, namely:

- adding an extra `type` option to the `Bumblebee.Layers.Decoder.init_cache/3` function, and creating the caches using the given type;
- passing `type: :bf16` when calling the above function in `Bumblebee.Audio.Whisper.init_cache/4` (sketched below).
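In code, the idea is roughly this (a self-contained sketch with a hypothetical `CacheSketch` module, not the actual Bumblebee diff):

```elixir
defmodule CacheSketch do
  # Hypothetical stand-in for the allocation inside
  # Bumblebee.Layers.Decoder.init_cache/3: accept a :type option and use it
  # when building the key/value tensors, instead of the implicit :f32 that
  # Nx.broadcast(0.0, shape) produces.
  def init_attention_cache(shape, opts \\ []) do
    type = Keyword.get(opts, :type, :f32)
    zero = Nx.tensor(0.0, type: type)
    %{key: Nx.broadcast(zero, shape), value: Nx.broadcast(zero, shape)}
  end
end

# With the whisper-tiny cross-attention shape from the stacktrace:
CacheSketch.init_attention_cache({16, 1500, 6, 64}, type: :bf16)
#=> key/value are now bf16[16][1500][6][64] tensors instead of f32
```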
These tweaks allowed me to set the output type to `bfloat16` as well, bringing the execution time down to 170s. Memory consumption during transcription also fell considerably, but it is hard to measure the gains accurately there due to XLA preallocations.

I would have opened a PR with this already; the problem is that there's currently no easy way to get the `type` given to `load_model` in the `init_cache/4` callback.

So I opened this issue to discuss what a proper fix could look like. I'm more than happy to open a PR for it then!