AI-Hypercomputer / maxtext

A simple, performant and scalable Jax LLM!
Apache License 2.0

Adding Mixtral-8x22b #845

Closed: rdyro closed this 2 weeks ago

rdyro commented 4 weeks ago

End-to-end tests on ml-auto-solutions are in a separate PR here

peregilk commented 1 week ago

@rdyro I encountered OOM (out of memory) errors when loading a Llama 8B model on a v4-8 after this commit. The problem appears to be related to llama_or_mistral_ckpt.py. I haven’t pinpointed the exact cause, but reverting MaxText/llama_or_mistral_ckpt.py to its version from commit aef1bb0b60c89b6c9876e89ce0b0c35b759235d7 resolves the issue.

I can reproduce the error with the not-yet-merged script from @A9isha, llama_or_mistral_orbax_to_huggingface.py, where the failure occurs at line 97. Other code paths that load the checkpoint in a similar way are likely to trigger the same error.