elixir-nx / bumblebee

Pre-trained Neural Network models in Axon (+ 🤗 Models integration)
Apache License 2.0

Add Gemma #358

Closed · seanmor5 closed this 4 months ago

seanmor5 commented 4 months ago

Resolves #357

Gemma has an attention_bias config option, which is similar to our use_qkv_bias, but that name isn't quite accurate here because there is an attention output bias too. I added use_attention_bias instead, but I'm wondering if we should change all instances to use_attention_bias?
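
For illustration, a minimal sketch of how this would surface to users, assuming the option lands as a spec option named use_attention_bias (the exact calls below are an assumption, not part of this PR):

```elixir
# Sketch: load the spec, flip the (proposed) use_attention_bias option,
# then load the model with the updated spec
token = System.fetch_env!("HF_TOKEN")
repo = {:hf, "google/gemma-7b-it", auth_token: token}

{:ok, spec} = Bumblebee.load_spec(repo)
spec = Bumblebee.configure(spec, use_attention_bias: true)
{:ok, model_info} = Bumblebee.load_model(repo, spec: spec)
```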

kurtome commented 4 months ago

I'm not familiar with Bumblebee, so I may be doing something wrong, but when I tried out this branch my Phoenix app crashes when I try to load the model:

```elixir
defmodule NoEx.Application do
  use Application

  @hf_token "abcd..."

  @impl true
  def start(_type, _args) do
    IO.inspect("NoEx.Application.start")

    {:ok, nx_model_info} =
      Bumblebee.load_model(
        {:hf, "google/gemma-7b-it", [auth_token: @hf_token]},
        spec_overrides: [num_labels: 10]
      )

    IO.inspect(nx_model_info)

    {:ok, nx_tokenizer} =
      Bumblebee.load_tokenizer({:hf, "google/gemma-7b-it", [auth_token: @hf_token]})

    IO.inspect(nx_tokenizer)

    {:ok, nx_gen_config} =
      Bumblebee.load_generation_config({:hf, "google/gemma-7b-it", [auth_token: @hf_token]})

    IO.inspect(nx_gen_config)

    # NOTE: nx_serving is used below but was not defined in this snippet;
    # presumably it is built from the loaded model, e.g. (assumed):
    nx_serving = Bumblebee.Text.generation(nx_model_info, nx_tokenizer, nx_gen_config)

    children = [
      NoExWeb.Telemetry,
      NoEx.Repo,
      {DNSCluster, query: Application.get_env(:no_ex, :dns_cluster_query) || :ignore},
      {Phoenix.PubSub, name: NoEx.PubSub},
      # Nx
      {Nx.Serving, serving: nx_serving, name: NoEx.Serving, batch_timeout: 100},
      # Start the Finch HTTP client for sending emails
      {Finch, name: NoEx.Finch},
      # Start a worker by calling: NoEx.Worker.start_link(arg)
      # {NoEx.Worker, arg},
      # Start to serve requests, typically the last entry
      NoExWeb.Endpoint
    ]

    # See https://hexdocs.pm/elixir/Supervisor.html
    # for other strategies and supported options
    opts = [strategy: :one_for_one, name: NoEx.Supervisor]
    Supervisor.start_link(children, opts)
  end
end
```

It never makes it past the Bumblebee.load_model line.

```shell
$ mix phx.server

Compiling 1 file (.ex)
"NoEx.Application.start"
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1709498689.304461 5486422 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
[1]    74844 killed     mix phx.server
```

jonatanklosko commented 4 months ago

@seanmor5 the official checkpoint ties embeddings, so I changed the params loading to map "language_modeling_head.output" => "model.embed_tokens" (we could add a config option, but it's unlikely there's an untied version, and we want to actually address tied embeddings properly eventually). I generated the tiny config so that it does not include the tied embeddings.
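
For context, a rough sketch of what that tied-embeddings mapping looks like in the model's params mapping (the surrounding entry and function shape are assumed; only the language_modeling_head.output line reflects this change):

```elixir
# Sketch: the LM head reuses the token embedding weights, since the
# official checkpoint ties them
def params_mapping(_spec) do
  %{
    # assumed entry for the input embedding
    "embedder.token_embedding" => "model.embed_tokens",
    # tied embeddings: load the output head from the same source params
    "language_modeling_head.output" => "model.embed_tokens"
  }
end
```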

FTR I also used the config values from the hf/transformers tests, which are even smaller; generally we want the tiny checkpoints to be as small as possible :) The utils/create_dummy_models.py script in hf/transformers didn't work for me (it also didn't for Llama), though it did for BERT in the past. So instead of digging into this too much, I just created the checkpoint by hand:

Code

```python
from transformers import GemmaConfig, GemmaModel, GemmaForCausalLM, GemmaForSequenceClassification

config = GemmaConfig(
    # vocab_size=99,
    vocab_size=1024,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    intermediate_size=37,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=512,
    type_vocab_size=16,
    is_decoder=False,
    initializer_range=0.02,
    pad_token_id=0,
    head_dim=8,
)

for c in [GemmaModel, GemmaForCausalLM, GemmaForSequenceClassification]:
    name = c.__name__
    c(config).save_pretrained(
        f"tmp/bumblebee-testing/tiny-random-{name}",
        repo_id=f"bumblebee-testing/tiny-random-{name}",
        push_to_hub=True,
    )
```

@seanmor5 I also added you to the bumblebee-testing org, so you don't have to push random repos into your account :p

jonatanklosko commented 4 months ago

> I added use_attention_bias instead, but I'm wondering if we should change all instances to use_attention_bias?

Yeah, I think it's fine to change use_qkv_bias to use_attention_bias for consistent naming, and then the model implementation passes it either to just the QKV projections or to QKV plus the output projection, depending on the model. I will do this in a separate commit!

jonatanklosko commented 4 months ago

@kurtome the "killed" log probably means that the OS killed the process because it was getting close to OOM. This is on CPU right? How much RAM do you have? You probably want to do load_model(..., type: :f16) to reduce the memory usage, though 7b model may be too much for the CPU to run in a reasonable speed, you can try the 2b checkpoint too.