keras-team / keras-hub

Pretrained model hub for Keras 3
Apache License 2.0

[Help][BUG] `KeyError: 'lm_head.weight'` on loading llama 3.2 #1920

Closed: steveepreston closed this issue 2 weeks ago

steveepreston commented 1 month ago

I'm trying to load Llama 3.2 on a TPU VM v3-8 with this:

import keras
import keras_nlp

device_mesh = keras.distribution.DeviceMesh((1, 8), ["batch", "model"], devices=keras.distribution.list_devices())
layout_map = keras.distribution.LayoutMap(device_mesh)
layout_map["token_embedding/embeddings"] = ("model", None)
layout_map["decoder_block.*attention.*(query|key|value)/kernel"] = ("model", None, None)
layout_map["decoder_block.*attention_output/kernel"] = ("model", None, None)
layout_map["decoder_block.*ffw_gating.*/kernel"] = (None, "model")
layout_map["decoder_block.*ffw_linear/kernel"] = ("model", None)
model_parallel = keras.distribution.ModelParallel(layout_map=layout_map, batch_dim_name="batch")
keras.distribution.set_distribution(model_parallel)

model = keras_nlp.models.Llama3CausalLM.from_preset("meta-llama/Llama-3.2-3B-Instruct")

but it throws this error:

KeyError: 'lm_head.weight'

Note: I took the layout_map code from This Example. I don't know whether the problem is with the layout_map or with Llama3CausalLM.

Gopi-Uppari commented 1 month ago

Hi @steveepreston,

I was able to execute the code using the Gemma model, and it worked without any issues. For the Llama model, however, could you please reach out to the Llama team for further assistance? Please refer to the Gist file for more details.

Thank you.

steveepreston commented 1 month ago

Thank you for the attention, @Gopi-Uppari.

Yes, Gemma executed successfully in my test too (although gemma-2-9b-it threw an OOM on the TPU). The problem is specific to the Llama model.

OK, I will try to create an issue there as well.

Gopi-Uppari commented 1 month ago

Could you please confirm whether this issue is resolved for you with the above comment? If it is, please feel free to close the issue.

Thank you.

steveepreston commented 1 month ago

The problem is not resolved, and I've moved to PyTorch. Maybe I'll come back to follow up and solve this in the future. There is still no example of Llama3CausalLM + XLA on the web.

SamanehSaadat commented 2 weeks ago

Hi @steveepreston

Variable paths in Llama are different from Gemma's, so the layout map that works for Gemma doesn't work for Llama (see here).
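
A quick way to see the difference is to print the variable paths that the LayoutMap regexes have to match. A minimal sketch, assuming keras_nlp on a Keras 3 backend; the config values below are arbitrary and only serve to instantiate the layers:

import keras_nlp

# Tiny randomly initialized Llama backbone, just to inspect variable naming.
backbone = keras_nlp.models.LlamaBackbone(
    vocabulary_size=1000,
    num_layers=2,
    num_query_heads=8,
    num_key_value_heads=4,
    hidden_dim=64,
    intermediate_dim=128,
)
for v in backbone.weights:
    print(v.path)  # these paths are what the layout map keys must match

Running the same loop on a GemmaBackbone shows different path prefixes, which is why the Gemma rules don't carry over to Llama.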

Recently, get_layout_map was added for Llama here. So instead of specifying the layout map manually, you can use this function: layout_map = keras_nlp.models.LlamaBackbone.get_layout_map(device_mesh).

We haven't added get_layout_map for Llama3 yet, but if it has the same architecture as Llama, you can copy the layout map from here; a sketch of that flow follows below.
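
Putting the pieces together, here is a minimal sketch of that flow on a v3-8 mesh, under the assumption that Llama3's variable paths really do mirror Llama's (the preset name is the one from the original report):

import keras
import keras_nlp

devices = keras.distribution.list_devices()
device_mesh = keras.distribution.DeviceMesh((1, 8), ["batch", "model"], devices=devices)

# Layout map built from Llama's own variable paths instead of hand-written regexes.
# Assumption: the same rules apply to Llama3 because the architectures match.
layout_map = keras_nlp.models.LlamaBackbone.get_layout_map(device_mesh)

model_parallel = keras.distribution.ModelParallel(layout_map=layout_map, batch_dim_name="batch")
keras.distribution.set_distribution(model_parallel)

model = keras_nlp.models.Llama3CausalLM.from_preset("meta-llama/Llama-3.2-3B-Instruct")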

steveepreston commented 2 weeks ago

@SamanehSaadat Thanks for the explanation; that sounds fine.

Just to confirm: is get_layout_map() currently available only for LlamaBackbone and GemmaBackbone? What should we do for other models, such as distil_bert, phi3, and so on?

SamanehSaadat commented 2 weeks ago

@steveepreston That's correct, it's only available for Llama and Gemma right now, but we're planning to add support for all other models soon. Here is the tracking issue: https://github.com/keras-team/keras-hub/issues/1689
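
In the meantime, for a model that doesn't have get_layout_map yet, the manual route still works: print the backbone's variable paths (as in the earlier sketch) and write LayoutMap rules against them. A rough sketch with placeholder regexes, not the real paths of any particular model:

import keras

devices = keras.distribution.list_devices()
device_mesh = keras.distribution.DeviceMesh((1, 8), ["batch", "model"], devices=devices)
layout_map = keras.distribution.LayoutMap(device_mesh)

# Placeholder rules: replace the keys with regexes that match the variable
# paths you printed, and pick shardings that suit each kernel's rank/shape.
layout_map["token_embedding/embeddings"] = ("model", None)
layout_map[".*(query|key|value).*kernel"] = (None, "model", None)
layout_map[".*output_dense.*kernel"] = ("model", None, None)

keras.distribution.set_distribution(
    keras.distribution.ModelParallel(layout_map=layout_map, batch_dim_name="batch")
)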

steveepreston commented 2 weeks ago

@SamanehSaadat Got it, thanks.