google / maxtext

A simple, performant and scalable Jax LLM!
Apache License 2.0

Implement restore with Orbax emergency checkpoint manager #740

Closed by xuefgu 3 weeks ago

xuefgu commented 3 weeks ago

This change fully unblocks scale testing on both TPUs and CPUs. However, the mechanisms used to communicate the process IDs and the coordinator address (local files, plus distributed reads and writes through GCS) are not production-ready. Follow-up changes are required to make recovery fully automatic.
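As a rough illustration of the "local files" handoff mentioned above, here is a minimal sketch of how a supervisor could write a host's process ID and the coordinator address to a local file, and how the restarted training process could read them back. The JSON layout and the names `write_process_info` / `read_process_info` are hypothetical and are not taken from this PR or from Orbax; the GCS side is not modeled.

```python
import json
import os
import tempfile
from pathlib import Path


def write_process_info(path: Path, process_id: int, coordinator_address: str) -> None:
    """Write this host's process ID and the coordinator address (hypothetical format)."""
    payload = {"process_id": process_id, "coordinator_address": coordinator_address}
    path.parent.mkdir(parents=True, exist_ok=True)
    # Write to a temporary file in the same directory, then rename it over the
    # target, so a concurrent reader never sees a half-written file.
    fd, tmp_name = tempfile.mkstemp(dir=path.parent)
    with os.fdopen(fd, "w") as f:
        json.dump(payload, f)
    os.replace(tmp_name, path)


def read_process_info(path: Path) -> dict:
    """Read and validate the file written by write_process_info."""
    with open(path) as f:
        payload = json.load(f)
    process_id = payload.get("process_id")
    coordinator_address = payload.get("coordinator_address")
    # Reject bools explicitly, since bool is a subclass of int in Python.
    if type(process_id) is not int or process_id < 0:
        raise ValueError(f"invalid process_id in {path}: {process_id!r}")
    if not isinstance(coordinator_address, str) or ":" not in coordinator_address:
        raise ValueError(f"invalid coordinator_address in {path}: {coordinator_address!r}")
    return {"process_id": process_id, "coordinator_address": coordinator_address}
```

On restart, the returned values could be passed to `jax.distributed.initialize(coordinator_address=..., num_processes=..., process_id=...)`, with `num_processes` supplied separately. A static file like this is exactly the kind of manual plumbing the PR flags as not production-ready: something still has to decide and publish the new assignments after a failure.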