elixir-nx / bumblebee

Pre-trained Neural Network models in Axon (+ 🤗 Models integration)

Add op-name to rms norm #280

Closed by seanmor5 7 months ago

seanmor5 commented 7 months ago

This partially resolves https://github.com/elixir-nx/axon/issues/544.

Llama always computes RMS norm in f32 (this is probably the case with other models as well). The way to reflect this in an Axon mixed-precision policy is to apply the policy with the modifier except: [:rms_norm], which applies the policy to every layer except those whose op_name matches. So the RMS norm layer here needs an op_name.
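
For reference, a minimal sketch of what that would look like once the layer carries an op_name (the except: option follows the Axon mixed-precision guide; the model name is only an example, not code from this PR):

{:ok, model_info} = Bumblebee.load_model({:hf, "meta-llama/Llama-2-7b-chat-hf"})

bf16 = {:bf, 16}
policy = Axon.MixedPrecision.create_policy(params: bf16, compute: bf16, output: bf16)

# Cast everything to bf16 except layers tagged with op_name :rms_norm,
# which keep their original f32 precision.
mp_model = Axon.MixedPrecision.apply_policy(model_info.model, policy, except: [:rms_norm])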

ityonemo commented 7 months ago

Empirically, it's the embedding layer that needs to be f32! Labelling the layer is probably good anyway, but if we provide documentation on how to quantize these models, we should probably reflect these results.

Disabling quantization on Embedding only:

auth_token = System.get_env("HF_AUTH_TOKEN")
Nx.default_backend({EXLA.Backend, client: :host})
model = {:hf, "meta-llama/Llama-2-7b-chat-hf", auth_token: auth_token}

{:ok, m} = Bumblebee.load_model(model)
{:ok, t} = Bumblebee.load_tokenizer(model)
{:ok, g} = Bumblebee.load_generation_config(model)

bf = {:bf, 16}
policy = Axon.MixedPrecision.create_policy(params: bf, compute: bf, output: bf)

# apply the bf16 policy to every layer except the embedding layer
mp_model = Axon.MixedPrecision.apply_policy(m.model, policy, &(&1.op_name != :embedding))
m2 = %{m | model: mp_model}

serving =
  Bumblebee.Text.generation(m2, t, g,
    defn_options: [compiler: EXLA, compiler_options: [client: :cuda, lazy_transfers: :always]]
  )

%{results: [%{text: text}]} =
  Nx.Serving.run(serving, "[INST] <<SYS>>\nYou are a bot.\n<</SYS>>\n\nHi, bot![/INST]")

text |> dbg

result:

"[INST] <<SYS>>\nYou are a bot.\n<</SYS>>\n\nHi, bot![/INST] Hello there! *greeting* I'm just an AI assistant, here to help you with any questions or tasks you may have. How can I assist you today? 😊"

Disabling quantization on RMS norm only:

auth_token = System.get_env("HF_AUTH_TOKEN")
Nx.default_backend({EXLA.Backend, client: :host})
model = {:hf, "meta-llama/Llama-2-7b-chat-hf", auth_token: auth_token}

{:ok, m} = Bumblebee.load_model(model)
{:ok, t} = Bumblebee.load_tokenizer(model)
{:ok, g} = Bumblebee.load_generation_config(model)

bf = {:bf, 16}
policy = Axon.MixedPrecision.create_policy(params: bf, compute: bf, output: bf)

# apply the bf16 policy to every layer except RMS norm layers
mp_model = Axon.MixedPrecision.apply_policy(m.model, policy, &(&1.op_name != :rms_norm))
m2 = %{m | model: mp_model}

serving =
  Bumblebee.Text.generation(m2, t, g,
    defn_options: [compiler: EXLA, compiler_options: [client: :cuda, lazy_transfers: :always]]
  )

%{results: [%{text: text}]} =
  Nx.Serving.run(serving, "[INST] <<SYS>>\nYou are a bot.\n<</SYS>>\n\nHi, bot![/INST]")

text |> dbg

result:

"[INST] <<SYS>>\nYou are a bot.\n<</SYS>>\n\nHi, bot![/INST] Har Mittelexists\n\nAnd of `` counted...