elixir-nx / bumblebee

Pre-trained Neural Network models in Axon (+ 🤗 Models integration)
Apache License 2.0

Serving the `OpenAssistant/oasst-sft-1-pythia-12b` model fails #223

Closed: mathieul closed this issue 1 year ago

mathieul commented 1 year ago

I tried to test the `oasst-sft-1-pythia-12b` Open Assistant model as described by Sean Moriarty in the "Open-Source Elixir Alternatives to ChatGPT" post. However, when I get to creating the serving instance, I get the following error:

Error:

```
(FunctionClauseError) no function clause matching in Bumblebee.Text.Generation.generation/4

  The following arguments were given to Bumblebee.Text.Generation.generation/4:

  #1
  %{model: #Axon<...>, params: %{...}, spec: %Bumblebee.Text.GptNeoX{...}}

  #2
  %Bumblebee.Text.GptNeoXTokenizer{...}

  #3
  [defn_options: [compiler: EXLA]]

  #4
  []

  Attempted function clauses (showing 1 out of 1):

  def generation(model_info, tokenizer, %Bumblebee.Text.GenerationConfig{} = generation_config, opts)

(bumblebee 0.3.0) lib/bumblebee/text/generation.ex:54: Bumblebee.Text.Generation.generation/4
(sandbox 0.1.0) lib/sandbox.ex:12: Sandbox.test_open_assistant/1
```

Here's the code:

```elixir
Mix.install([
  {:bumblebee, github: "elixir-nx/bumblebee"},
  {:exla, "~> 0.5"}
])

defmodule Sandbox do
  def test_open_assistant(prompt) do
    {:ok, model} = Bumblebee.load_model({:hf, "OpenAssistant/oasst-sft-1-pythia-12b"})
    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "OpenAssistant/oasst-sft-1-pythia-12b"})

    serving = Bumblebee.Text.generation(model, tokenizer, defn_options: [compiler: EXLA])
    Nx.Serving.run(serving, "<|prompter|>#{prompt}<|endoftext|><|assistant|>")
  end
end
```

`Bumblebee.Text.Generation.generation/4` expects a `%Bumblebee.Text.GenerationConfig{}` struct as its 3rd argument, but instead it gets the defn options keyword list. I'm still a newbie in deep learning (but not for long, as I'm currently enjoying Sean's great "Machine Learning in Elixir" book ;)), so I'm not sure where to go from there.

mathieul commented 1 year ago

Now that I've re-read the issue, I realize that the call `Bumblebee.Text.generation(model, tokenizer, defn_options: [compiler: EXLA])` goes straight to `Bumblebee.Text.Generation.generation/4` (sorry, newbie), so I'm going to spend a bit more time on this before closing it.
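To see why a three-argument call ends up in `generation/4` with `[]` as the fourth argument, here is a minimal sketch using hypothetical `Demo.*` modules (not Bumblebee's code). It assumes `Bumblebee.Text.generation` is a `defdelegate` with a default `opts \\ []`, which is consistent with argument #4 being `[]` in the error above: the keyword list gets bound to the config slot, and the struct pattern in the only clause then fails to match.

```elixir
# Hypothetical stand-ins for illustration only; these are NOT Bumblebee modules.
defmodule Demo.GenerationConfig do
  defstruct max_new_tokens: 20
end

defmodule Demo.Generation do
  # Mirrors the single clause in the error: the third argument must be a config struct.
  def generation(model_info, tokenizer, %Demo.GenerationConfig{} = config, opts) do
    {:ok, model_info, tokenizer, config, opts}
  end
end

defmodule Demo.Text do
  # A delegate with a default `opts \\ []` defines both generation/3 and generation/4.
  defdelegate generation(model_info, tokenizer, config, opts \\ []), to: Demo.Generation
end

# Old-style call: the keyword list is bound to `config` and `opts` defaults to [],
# so no clause of Demo.Generation.generation/4 matches.
try do
  Demo.Text.generation(:model, :tokenizer, defn_options: [compiler: EXLA])
rescue
  e in FunctionClauseError ->
    IO.puts("FunctionClauseError in #{inspect(e.module)}.#{e.function}/#{e.arity}")
end

# New-style call: pass a config struct as the third argument, then the options.
{:ok, _model, _tokenizer, config, opts} =
  Demo.Text.generation(:model, :tokenizer, %Demo.GenerationConfig{}, defn_options: [compiler: EXLA])

IO.inspect({config, opts})
```

In the real API the struct comes from loading the model's generation config rather than being built by hand.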

jonatanklosko commented 1 year ago

Hey @mathieul, we changed the API after that article was published: the generation options are now encapsulated in a generation config struct, which you pass as the third argument. So what you want is this:

```elixir
{:ok, model} = Bumblebee.load_model({:hf, "OpenAssistant/oasst-sft-1-pythia-12b"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "OpenAssistant/oasst-sft-1-pythia-12b"})
{:ok, generation_config} = Bumblebee.load_generation_config({:hf, "OpenAssistant/oasst-sft-1-pythia-12b"})

serving = Bumblebee.Text.generation(model, tokenizer, generation_config, defn_options: [compiler: EXLA])

# `prompt` is the user's input string, as in Sandbox.test_open_assistant/1 above
Nx.Serving.run(serving, "<|prompter|>#{prompt}<|endoftext|><|assistant|>")
```

josevalim commented 1 year ago

@seanmor5 in case you want to update the article :)

mathieul commented 1 year ago

Thanks a lot for your quick answer @jonatanklosko!