google / maxtext

A simple, performant and scalable Jax LLM!
Apache License 2.0

Enable async checkpointing for GPU. #697

Closed: yangyuwei closed this 2 months ago

yangyuwei commented 2 months ago

Make sure we always call initialize_jax_for_gpu() to initialize JAX when running on GPUs. This change will allow us to enable Orbax async checkpointing for MaxText on GPUs.
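
The thread does not show the body of initialize_jax_for_gpu(), so here is only a rough sketch of what such an initialization step might look like. It assumes the coordinator address, process count, and process rank come from environment variables; the variable names (JAX_COORDINATOR_IP, JAX_COORDINATOR_PORT, NNODES, NODE_RANK), the default port, and both helper names are assumptions for illustration, not MaxText's confirmed implementation. The only real library call is jax.distributed.initialize.

```python
import os
from typing import Mapping, Optional


def gpu_distributed_kwargs(env: Mapping[str, str]) -> Optional[dict]:
    """Build arguments for jax.distributed.initialize from environment variables.

    The variable names are assumptions made for this sketch.
    Returns None when no coordinator is configured (e.g. a single-process run).
    """
    coordinator_ip = env.get("JAX_COORDINATOR_IP")
    if coordinator_ip is None:
        return None
    # The default port is an arbitrary placeholder for this sketch.
    coordinator_port = env.get("JAX_COORDINATOR_PORT", "1234")
    return {
        "coordinator_address": f"{coordinator_ip}:{coordinator_port}",
        "num_processes": int(env["NNODES"]),
        "process_id": int(env["NODE_RANK"]),
    }


def initialize_jax_for_gpu_sketch(env: Mapping[str, str] = os.environ) -> bool:
    """Initialize JAX's distributed runtime when a coordinator is configured.

    Returns True if jax.distributed.initialize was called, False otherwise.
    """
    kwargs = gpu_distributed_kwargs(env)
    if kwargs is None:
        return False
    import jax  # imported lazily so the helper above has no JAX dependency

    jax.distributed.initialize(**kwargs)
    return True
```

In a sketch like this, the training entry point would call initialize_jax_for_gpu_sketch() once, before any other JAX call, whenever the hardware is GPU. Splitting the environment parsing from the actual initialization keeps the parsing easy to check without a multi-process setup.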

yangyuwei commented 2 months ago

Actually, these code changes were already included in https://github.com/google/maxtext/commit/24dc66cc4fbad1c13b093a8cae667f883a347a8e, which has been merged. I'll close this PR.