Hi @steveepreston,
I was able to execute the code using the Gemma model, and it worked without any issues. For the Llama model, however, could you please reach out to the Llama team for further assistance? Please refer to the Gist file for more details.
Thank you.
Thank you for your attention @Gopi-Uppari.
Yes, gemma executed successfully in my test too (although gemma-2-9b-it threw an OOM on the TPU). The problem is with the llama model.
OK, I will try to create another issue there as well.
Could you please confirm whether this issue is resolved for you with the above comment? Please feel free to close the issue if it is resolved.
Thank you.
The problem is not resolved and I've moved to PyTorch.
Maybe I'll come back to follow up and solve this in the future. There is still no example of Llama3CausalLM + XLA on the web.
Hi @steveepreston
Variable paths in Llama are different from Gemma, so the layout map that works for Gemma doesn't work for Llama (see here).
Recently, get_layout_map was added for Llama here. So instead of specifying the layout map manually, you can use this function: layout_map = keras_nlp.models.LlamaBackbone.get_layout_map(device_mesh).
We haven't added get_layout_map for Llama3 yet, but if it has the same architecture as Llama, you can copy the layout map from here.
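For context, a minimal sketch of how this function might be wired up end to end on a TPU v3-8 (8 cores). The mesh shape, axis names, and preset name here are illustrative assumptions, and the exact ModelParallel signature may vary across Keras versions:

```python
import keras
import keras_nlp

# Build a 1 x 8 device mesh over the 8 TPU cores: no data parallelism,
# 8-way model parallelism (shape and axis names are illustrative).
devices = keras.distribution.list_devices()
device_mesh = keras.distribution.DeviceMesh(
    shape=(1, 8),
    axis_names=("batch", "model"),
    devices=devices,
)

# Ask the backbone for a layout map matching Llama's variable paths,
# instead of writing the regex rules by hand.
layout_map = keras_nlp.models.LlamaBackbone.get_layout_map(device_mesh)

# Shard all model variables according to the layout map.
model_parallel = keras.distribution.ModelParallel(
    device_mesh, layout_map, batch_dim_name="batch"
)
keras.distribution.set_distribution(model_parallel)

# Preset name is hypothetical; substitute the checkpoint you actually use.
llama_lm = keras_nlp.models.LlamaCausalLM.from_preset("llama2_7b_en")
```

With set_distribution in place before from_preset, the weights should be sharded across the mesh as they load rather than materialized on a single device.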
@SamanehSaadat Thanks for the explanation. It sounds fine.
Just to confirm: is get_layout_map() currently available only for LlamaBackbone and GemmaBackbone? What should we do for other models, such as distil_bert and phi3, and so on?
@steveepreston That's correct, it's only available for Llama and Gemma right now but we're planning to add support for all other models soon. Here is the tracking issue: https://github.com/keras-team/keras-hub/issues/1689
@SamanehSaadat Got it, thanks.
Trying to load llama-3.2 on a TPU VM v3-8 via this:
but it throws this Error:
Note: I got the layout_map code from This Example. I don't know if the problem is from layout_map or Llama3CausalLM.
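A minimal sketch of the workaround suggested above (borrowing the Llama layout map for Llama3) would look like this, under the unverified assumption that Llama3's variable paths match Llama's; the preset name is hypothetical:

```python
import keras
import keras_nlp

devices = keras.distribution.list_devices()
device_mesh = keras.distribution.DeviceMesh(
    shape=(1, 8), axis_names=("batch", "model"), devices=devices
)

# Llama3Backbone has no get_layout_map yet, so borrow Llama's map; this
# only helps if Llama3's variable paths are identical (an assumption).
layout_map = keras_nlp.models.LlamaBackbone.get_layout_map(device_mesh)

keras.distribution.set_distribution(
    keras.distribution.ModelParallel(
        device_mesh, layout_map, batch_dim_name="batch"
    )
)

# Preset is hypothetical; use the actual Llama 3.2 preset or a local path.
llama3_lm = keras_nlp.models.Llama3CausalLM.from_preset("llama3_8b_en")
```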