Empirically, it's the embedding layer that needs to be fp32! Labelling the layer is probably good anyway, but if documentation is given for guidance on how to quantize these models, we should probably reflect these results?
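Keeping only the embedding layer out of the bf16 policy: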
auth_token = System.get_env("HF_AUTH_TOKEN")
Nx.default_backend({EXLA.Backend, client: :host})
model = {:hf, "meta-llama/Llama-2-7b-chat-hf", auth_token: auth_token}
{:ok, m} = Bumblebee.load_model(model)
{:ok, t} = Bumblebee.load_tokenizer(model)
{:ok, g} = Bumblebee.load_generation_config(model)
bf = {:bf, 16}
policy = Axon.MixedPrecision.create_policy(params: bf, compute: bf, output: bf)
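# Downcast everything to bf16 except the embedding layer, which stays in f32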
mp_model = Axon.MixedPrecision.apply_policy(m.model, policy, &(&1.op_name != :embedding))
m2 = %{m | model: mp_model}
serving =
Bumblebee.Text.generation(m2, t, g,
defn_options: [compiler: EXLA, compiler_options: [client: :cuda, lazy_transfers: :always]]
)
%{results: [%{text: text}]} =
Nx.Serving.run(serving, "[INST] <<SYS>>\nYou are a bot.\n<</SYS>>\n\nHi, bot![/INST]")
text |> dbg
"[INST] <<SYS>>\nYou are a bot.\n<</SYS>>\n\nHi, bot![/INST] Hello there! *greeting* I'm just an AI assistant, here to help you with any questions or tasks you may have. How can I assist you today? 😊"
auth_token = System.get_env("HF_AUTH_TOKEN")
Nx.default_backend({EXLA.Backend, client: :host})
model = {:hf, "meta-llama/Llama-2-7b-chat-hf", auth_token: auth_token}
{:ok, m} = Bumblebee.load_model(model)
{:ok, t} = Bumblebee.load_tokenizer(model)
{:ok, g} = Bumblebee.load_generation_config(model)
bf = {:bf, 16}
policy = Axon.MixedPrecision.create_policy(params: bf, compute: bf, output: bf)
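# Downcast everything to bf16 except layers with op_name :rms_norm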
mp_model = Axon.MixedPrecision.apply_policy(m.model, policy, &(&1.op_name != :rms_norm))
m2 = %{m | model: mp_model}
serving =
Bumblebee.Text.generation(m2, t, g,
defn_options: [compiler: EXLA, compiler_options: [client: :cuda, lazy_transfers: :always]]
)
%{results: [%{text: text}]} =
Nx.Serving.run(serving, "[INST] <<SYS>>\nYou are a bot.\n<</SYS>>\n\nHi, bot![/INST]")
text |> dbg
"[INST] <<SYS>>\nYou are a bot.\n<</SYS>>\n\nHi, bot![/INST] Har
Mittelexists\n\nAnd of `` counted...
Partially resolves https://github.com/elixir-nx/axon/issues/544
Llama specifically always computes RMS norm in f32 (this is probably the case with other models as well). The way to reflect this in an Axon mixed-precision policy is to apply the policy with a modifier like
except: [:rms_norm]
which applies the policy to every layer whose op_name does not match. For that to work, the RMS norm layers here need an op_name.
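A minimal sketch of what that could look like, assuming a hand-rolled RMS norm built with Axon.layer/3 (the module name, parameter shape, and epsilon below are illustrative, not Bumblebee's actual implementation). Tagging the layer with op_name: :rms_norm is what lets policy filters like &(&1.op_name != :rms_norm) above match it:

defmodule MyLayers do
  import Nx.Defn

  # Illustrative RMS norm: normalize in f32 for numerical stability,
  # then cast back to the input type before applying the learned scale.
  defnp rms_norm_impl(x, weight, _opts \\ []) do
    x32 = Nx.as_type(x, :f32)
    variance = Nx.mean(x32 * x32, axes: [-1], keep_axes: true)
    normed = x32 * Nx.rsqrt(variance + 1.0e-6)
    Nx.multiply(Nx.as_type(normed, Nx.type(x)), weight)
  end

  def rms_norm(%Axon{} = x, hidden_size, opts \\ []) do
    weight = Axon.param("weight", {hidden_size}, initializer: :ones)

    # op_name: :rms_norm makes this layer addressable by mixed-precision filters
    Axon.layer(&rms_norm_impl/3, [x, weight], name: opts[:name], op_name: :rms_norm)
  end
end

With the op_name in place, the Axon.MixedPrecision.apply_policy(m.model, policy, &(&1.op_name != :rms_norm)) call from the second transcript actually matches the norm layers and leaves them at the model's default precision.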