AI-Hypercomputer / maxtext

A simple, performant and scalable Jax LLM!
Apache License 2.0

Support AoT in 16-vm GPU Llama2 train script #826

Closed jonb377 closed 1 month ago

jonb377 commented 1 month ago

To run the nightly ahead-of-time (AoT) compilation tests, the 16-VM GPU Llama2 train script needs to support being run through train_compile.py. This follows the pattern already set in the TPU train scripts.
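
As a rough illustration of that pattern, the sketch below picks the entry point from an environment variable, so the same launch script can either train or only compile. It is not the code from this PR: the `EXECUTABLE` variable name, the `MaxText/` path prefix, the config path, and the `build_command` helper are assumptions made for the example.

```python
import os

# Entry points the launch script is allowed to dispatch to.
ALLOWED_EXECUTABLES = ("train.py", "train_compile.py")


def build_command(env=None, config="MaxText/configs/base.yml"):
    """Build the launch command, defaulting to a normal training run.

    `env` is a mapping of environment variables (defaults to os.environ).
    The EXECUTABLE variable name and the paths are hypothetical.
    """
    env = os.environ if env is None else env
    executable = env.get("EXECUTABLE", "train.py")
    if executable not in ALLOWED_EXECUTABLES:
        raise ValueError(f"unsupported executable: {executable!r}")
    return ["python3", f"MaxText/{executable}", config]


if __name__ == "__main__":
    # Normal run: dispatches to train.py.
    print(" ".join(build_command({})))
    # AoT compilation run: same script, different entry point.
    print(" ".join(build_command({"EXECUTABLE": "train_compile.py"})))
```

Keeping the switch in a single variable means the nightly test can reuse the training script unchanged and only override the entry point.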