elixir-nx / bumblebee

Pre-trained Neural Network models in Axon (+ 🤗 Models integration)
Apache License 2.0

`TextEmbedding` crashes when both mean pooling and `compile` opts are specified #315

Closed thiagopromano closed 6 months ago

thiagopromano commented 6 months ago

If the output pool is set to :mean_pooling and batch_size/sequence_length are specified under compile, the dimensions for the mean pooling don't match, as the example below shows. I tested on the HEAD of the main branch today and it still crashes.

Mix.install(
  [
    {:kino, "~> 0.12.0"},
    {:nx, "~> 0.5"},
    {:bumblebee, "~> 0.4.2"},
    {:exla, "~> 0.6.4"}
  ],
  config: [
    nx: [default_backend: EXLA.Backend]
  ]
)

model_repository = {:hf, "sentence-transformers/all-MiniLM-L6-v2"}
{:ok, model_info} = Bumblebee.load_model(model_repository)
{:ok, tokenizer} = Bumblebee.load_tokenizer(model_repository)

serving =
  Bumblebee.Text.text_embedding(model_info, tokenizer,
    output_pool: :mean_pooling,
    compile: [batch_size: 16, sequence_length: 128]
  )

Nx.Serving.run(serving, "Cats are cute.")

Result:

** (ArgumentError) cannot broadcast tensor of dimensions {16, 384} to {16, 128, 1}
    (nx 0.6.4) lib/nx/shape.ex:345: Nx.Shape.binary_broadcast/4
    (nx 0.6.4) lib/nx.ex:5500: Nx.devectorized_element_wise_bin_op/4
    (bumblebee 0.4.2) lib/bumblebee/text/text_embedding.ex:56: anonymous fn/6 in Bumblebee.Text.TextEmbedding.text_embedding/3
    (nx 0.6.4) lib/nx/defn/compiler.ex:158: Nx.Defn.Compiler.runtime_fun/3
    (nx 0.6.4) lib/nx/defn/evaluator.ex:83: Nx.Defn.Evaluator.precompile/3
    (nx 0.6.4) lib/nx/defn/evaluator.ex:61: Nx.Defn.Evaluator.__compile__/4
    (nx 0.6.4) lib/nx/defn.ex:305: Nx.Defn.compile/3
    /Users/.../Library/Application Support/livebook/autosaved/2023_12_13/20_48_6ddh/untitled_notebook.livemd#cell:csdm2pjr6ugsry4ylprjvwh36kkq6vlj:10: (file)

jonatanklosko commented 6 months ago

The model outputs several tensors, and by default the serving picks :pooled_state, which is already pooled (in a different way), so its shape is {batch_size, hidden_size} and there is no sequence axis left to average over. I improved the error in 96935d8e88fbaed82370bfd684c348e29831ccd2.
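To make the shape mismatch concrete, here is a minimal sketch of masked mean pooling in plain Nx. It is not taken from the Bumblebee source; the tensors are dummy values and the variable names are assumptions, with shapes chosen to match the error above (batch 16, sequence length 128, hidden size 384).

```elixir
Mix.install([{:nx, "~> 0.6"}])

# Dummy stand-ins for the model outputs (values are irrelevant, only shapes matter).
hidden_state = Nx.broadcast(1.0, {16, 128, 384})  # per-token states
pooled_state = Nx.broadcast(1.0, {16, 384})       # already pooled, no sequence axis
attention_mask = Nx.broadcast(1, {16, 128})

# Expand the mask to {16, 128, 1} so it broadcasts over the hidden axis.
mask = Nx.new_axis(attention_mask, -1)

# Mean pooling: zero out padded tokens, sum over the sequence axis,
# then divide by the number of real tokens.
pooled =
  hidden_state
  |> Nx.multiply(mask)
  |> Nx.sum(axes: [1])
  |> Nx.divide(Nx.sum(mask, axes: [1]))

IO.inspect(Nx.shape(pooled), label: "mean-pooled shape")

# The same multiplication on the already pooled tensor cannot broadcast
# {16, 384} against {16, 128, 1}, so Nx raises an ArgumentError.
result =
  try do
    Nx.multiply(pooled_state, mask)
    :ok
  rescue
    e in ArgumentError -> {:error, Exception.message(e)}
  end

IO.inspect(result, label: "pooling :pooled_state")
```

The first computation works because the per-token tensor still has a sequence axis; the second fails in the same way as the stack trace in the report.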

So in this case what you want is output_attribute: :hidden_state and output_pool: :mean_pooling :)
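For reference, a sketch of the original reproduction with that suggestion applied. The dependency versions, model, and compile options are copied from the report; the only change is the added output_attribute option, and the variable name result is just for illustration.

```elixir
Mix.install(
  [
    {:nx, "~> 0.5"},
    {:bumblebee, "~> 0.4.2"},
    {:exla, "~> 0.6.4"}
  ],
  config: [
    nx: [default_backend: EXLA.Backend]
  ]
)

model_repository = {:hf, "sentence-transformers/all-MiniLM-L6-v2"}
{:ok, model_info} = Bumblebee.load_model(model_repository)
{:ok, tokenizer} = Bumblebee.load_tokenizer(model_repository)

serving =
  Bumblebee.Text.text_embedding(model_info, tokenizer,
    # Pool over the per-token hidden states rather than the default :pooled_state.
    output_attribute: :hidden_state,
    output_pool: :mean_pooling,
    compile: [batch_size: 16, sequence_length: 128]
  )

result = Nx.Serving.run(serving, "Cats are cute.")
IO.inspect(result)
```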

thiagopromano commented 6 months ago

Nice, thank you!