elixir-nx / bumblebee

Pre-trained Neural Network models in Axon (+ 🤗 Models integration)
Apache License 2.0
1.27k stars 90 forks source link

Upgrading bumblebee from 0.3.0 to 0.3.1 breaks with the `FunctionClauseError` #235

Closed munjalpatel closed 11 months ago

munjalpatel commented 11 months ago

I am using Bumblebee for sentiment and emotion analysis.

Everything works fine when using Bumblebee v0.3.0 and EXLA v0.5.3:

mix.exs

# Working setup (excerpt): Bumblebee 0.3.0 with EXLA. Although EXLA is
# unconstrained here, the accompanying mix.lock resolves it to 0.5.3, which in
# turn requires Nx ~> 0.5.1.
defp deps do
    [
      # other project dependencies elided in this excerpt
      ...
      {:bumblebee, "~> 0.3.0"},
      {:exla, ">= 0.0.0"}
    ]
end

mix.lock

[
    "bumblebee": {:hex, :bumblebee, "0.3.0", "ad6294b39b8fb2212620e9ed9fbebc936c574ae146d3f49b3e855e1254f7c981", [:mix], [{:axon, "~> 0.5.0", [hex: :axon, repo: "hexpm", optional: false]}, {:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:jason, "~> 1.4.0", [hex: :jason, repo: "hexpm", optional: false]}, {:nx, "~> 0.5.0", [hex: :nx, repo: "hexpm", optional: false]}, {:nx_image, "~> 0.1.0", [hex: :nx_image, repo: "hexpm", optional: false]}, {:nx_signal, "~> 0.1.0", [hex: :nx_signal, repo: "hexpm", optional: false]}, {:progress_bar, "~> 2.0", [hex: :progress_bar, repo: "hexpm", optional: false]}, {:tokenizers, "~> 0.3.1", [hex: :tokenizers, repo: "hexpm", optional: false]}, {:unpickler, "~> 0.1.0", [hex: :unpickler, repo: "hexpm", optional: false]}, {:unzip, "0.8.0", [hex: :unzip, repo: "hexpm", optional: false]}], "hexpm", "477ed5e15d4a5b18343086bed83e2990ca2ba67e0dc9e2d57518bda4cca4c95e"},
    "exla": {:hex, :exla, "0.5.3", "f9496980a447ec2564b1646f89ee64379faef608c4a3ab6059f6a55117235e63", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nx, "~> 0.5.1", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.4.4", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "41c0523217e9ed4007c005be08ec5692b76f509f0103a2c4f6b2ed016213a48e"}
]

But when I upgrade to Bumblebee v0.3.1 and EXLA v0.6.0:

mix.exs

# Failing setup (excerpt): Bumblebee pinned to exactly 0.3.1 and EXLA >= 0.6.0.
# The accompanying mix.lock shows both packages requiring Nx ~> 0.6.0.
# Per the maintainer's reply, the resulting FunctionClauseError is an Nx bug
# (fixed in elixir-nx/nx PR #1283). The suggested workaround is adding
# {:nx, github: "elixir-nx/nx", sparse: "nx", override: true} as a dependency.
defp deps do
    [
      # other project dependencies elided in this excerpt
      ...
      {:bumblebee, "== 0.3.1"},
      {:exla, ">= 0.6.0"}
    ]
end

mix.lock

[
    "bumblebee": {:hex, :bumblebee, "0.3.1", "311f21930019a7702306ff50a5d9eb10b766d8b1e1c7a1076c0f8b0a10e13bf3", [:mix], [{:axon, "~> 0.6.0", [hex: :axon, repo: "hexpm", optional: false]}, {:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:jason, "~> 1.4.0", [hex: :jason, repo: "hexpm", optional: false]}, {:nx, "~> 0.6.0", [hex: :nx, repo: "hexpm", optional: false]}, {:nx_image, "~> 0.1.0", [hex: :nx_image, repo: "hexpm", optional: false]}, {:nx_signal, "~> 0.1.0", [hex: :nx_signal, repo: "hexpm", optional: false]}, {:progress_bar, "~> 3.0", [hex: :progress_bar, repo: "hexpm", optional: false]}, {:safetensors, "~> 0.1.1", [hex: :safetensors, repo: "hexpm", optional: false]}, {:tokenizers, "~> 0.4", [hex: :tokenizers, repo: "hexpm", optional: false]}, {:unpickler, "~> 0.1.0", [hex: :unpickler, repo: "hexpm", optional: false]}, {:unzip, "0.8.0", [hex: :unzip, repo: "hexpm", optional: false]}], "hexpm", "a3e00e0fb0a3c6d99a78afcfd539d37f6bdff09da197ce27a9b2af41bae6a0cc"},
    "exla": {:hex, :exla, "0.6.0", "af63e45ce41ad25630967923147d14292a0cc48e507b8a3cf3bf3d5483099a28", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nx, "~> 0.6.0", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.5.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "5f6a4a105ea9ab207b9aa4de5a294730e2bfe9639f4b8d37a7c00da131090d7a"},
]

I get the following error:

** (FunctionClauseError) no function clause matching in Nx.BinaryBackend.to_binary/1
    (nx 0.6.0) lib/nx/binary_backend.ex:183: Nx.BinaryBackend.to_binary(#Nx.Tensor<
  u32
  EXLA.Backend<host:0, 0.1673869229.3388342291.236609>
  0
>)
    (nx 0.6.0) lib/nx/binary_backend.ex:355: Nx.BinaryBackend.pad/4
    (nx 0.6.0) lib/nx.ex:5431: Nx.apply_vectorized/2
    (exla 0.6.0) lib/exla/defn/buffers.ex:107: EXLA.Defn.Buffers.from_nx!/3
    (exla 0.6.0) lib/exla/defn/buffers.ex:25: EXLA.Defn.Buffers.filter_by_indexes_map/4
    (exla 0.6.0) lib/exla/defn/buffers.ex:25: EXLA.Defn.Buffers.filter_by_indexes_map/4
    (exla 0.6.0) lib/exla/defn.ex:342: EXLA.Defn.maybe_outfeed/7
    (stdlib 4.2) timer.erl:235: :timer.tc/1
Function: #Function<28.43156718/0 in Nx.Serving.server_maybe_task/1>
    Args: []

** (stop) exited in: Nx.Serving.local_batched_run(MyApp.SentimentAnalyzer, "")
    ** (EXIT) an exception was raised:
        ** (FunctionClauseError) no function clause matching in Nx.BinaryBackend.to_binary/1
            (nx 0.6.0) lib/nx/binary_backend.ex:183: Nx.BinaryBackend.to_binary(#Nx.Tensor<
  u32
  EXLA.Backend<host:0, 0.1673869229.3388342291.236610>
  0
>)
            (nx 0.6.0) lib/nx/binary_backend.ex:355: Nx.BinaryBackend.pad/4
            (nx 0.6.0) lib/nx.ex:5431: Nx.apply_vectorized/2
            (exla 0.6.0) lib/exla/defn/buffers.ex:107: EXLA.Defn.Buffers.from_nx!/3
            (exla 0.6.0) lib/exla/defn/buffers.ex:25: EXLA.Defn.Buffers.filter_by_indexes_map/4
            (exla 0.6.0) lib/exla/defn/buffers.ex:25: EXLA.Defn.Buffers.filter_by_indexes_map/4
            (exla 0.6.0) lib/exla/defn.ex:342: EXLA.Defn.maybe_outfeed/7
            (stdlib 4.2) timer.erl:235: :timer.tc/1
    (nx 0.6.0) lib/nx/serving.ex:817: Nx.Serving.local_batched_run!/3
    (elixir 1.14.2) lib/task/supervised.ex:89: Task.Supervised.invoke_mfa/2
    (elixir 1.14.2) lib/task/supervised.ex:34: Task.Supervised.reply/4
    (stdlib 4.2) proc_lib.erl:240: :proc_lib.init_p_do_apply/3

application.ex

defmodule MyAppActions.Application do
  # See https://hexdocs.pm/elixir/Application.html
  # for more information on OTP Applications
  @moduledoc false

  use Application

  @impl true
  def start(_type, _args) do
    # Supervisor starts children in list order and terminates them in reverse
    # order. The Endpoint therefore comes last: it only begins accepting
    # requests once the Repo, PubSub and both Nx servings are running, and at
    # shutdown it stops serving before the servings go away. Otherwise an early
    # request calling Nx.Serving.batched_run/2 could hit an unregistered name.
    children = [
      # Start the Telemetry supervisor
      MyAppActions.Telemetry,
      MyAppActions.Repo,
      {Phoenix.PubSub, name: MyAppActions.PubSub},
      # Both servings are built (model and tokenizer loaded) eagerly here,
      # while this list is constructed, before the supervisor starts.
      {Nx.Serving, serving: sentiment_analyzer_serving(), name: MyApp.SentimentAnalyzer},
      {Nx.Serving, serving: emotion_analyzer_serving(), name: MyApp.EmotionAnalyzer},
      # Start the Endpoint (http/https) last, once everything it serves is up
      MyAppActions.Endpoint
    ]

    opts = [strategy: :one_for_one, name: MyAppActions.Supervisor]
    Supervisor.start_link(children, opts)
  end

  @doc """
  Builds the EXLA-compiled text-classification serving for sentiment analysis.

  Raises a `MatchError` if the model or tokenizer fails to load.
  """
  def sentiment_analyzer_serving do
    # NOTE(review): the tokenizer comes from the base BERTweet repository rather
    # than the fine-tuned model's repository; presumably intentional — confirm.
    text_classification_serving(
      "finiteautomata/bertweet-base-sentiment-analysis",
      "vinai/bertweet-base"
    )
  end

  @doc """
  Builds the EXLA-compiled text-classification serving for emotion analysis.

  Raises a `MatchError` if the model or tokenizer fails to load.
  """
  def emotion_analyzer_serving do
    repository = "j-hartmann/emotion-english-distilroberta-base"
    text_classification_serving(repository, repository)
  end

  @impl true
  def config_change(changed, _new, removed) do
    MyAppActions.Endpoint.config_change(changed, removed)
    :ok
  end

  # Loads a Hugging Face model and tokenizer (model first, then tokenizer) and
  # wraps them in a text-classification serving compiled with EXLA for batches
  # of 10 inputs padded to 512 tokens. Asserts {:ok, _} on both loads.
  defp text_classification_serving(model_repository, tokenizer_repository) do
    {:ok, model_info} = Bumblebee.load_model({:hf, model_repository})
    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, tokenizer_repository})

    Bumblebee.Text.text_classification(model_info, tokenizer,
      compile: [batch_size: 10, sequence_length: 512],
      defn_options: [compiler: EXLA]
    )
  end
end

analyzer.ex

# Starts one sentiment and one emotion classification per item and returns
# `[{id, {sentiment_task, emotion_task}}]` without awaiting anything.
#
# Each item must be a map with `:id` and `:text` keys; any other shape raises a
# FunctionClauseError. A nil `:text` is sent to the servings as "". The tasks
# come from Task.async/1, so the calling process must await them itself.
def analyze(data) do
  Enum.map(data, fn %{id: id, text: text} ->
    input = text || ""

    tasks = {
      Task.async(fn -> Nx.Serving.batched_run(MyApp.SentimentAnalyzer, input) end),
      Task.async(fn -> Nx.Serving.batched_run(MyApp.EmotionAnalyzer, input) end)
    }

    {id, tasks}
  end)
end
jonatanklosko commented 11 months ago

Hey @munjalpatel, this is an Nx bug that surfaced because of a recent change in Bumblebee. It is fixed in https://github.com/elixir-nx/nx/pull/1283. You can try the fix by adding `{:nx, github: "elixir-nx/nx", sparse: "nx", override: true}` as a dependency :)

munjalpatel commented 11 months ago

Thanks @jonatanklosko, that works!