Closed: yangyuwei closed this 2 months ago
Make sure we always call initialize_jax_for_gpu() to initialize JAX when running on GPUs. This change will allow us to enable Orbax async checkpointing for MaxText on GPUs.
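For illustration only, here is a minimal sketch of what "always call initialize_jax_for_gpu() on GPUs" could look like. It is not the code from this PR: the environment variable names (JAX_COORDINATOR_IP, JAX_COORDINATOR_PORT, NNODES, NODE_RANK), the helper gpu_distributed_kwargs, the wrapper maybe_initialize_jax_distributed, and the "gpu" hardware string are all assumptions made for the example. The only real library call is jax.distributed.initialize.

```python
import os


def gpu_distributed_kwargs(env=os.environ):
    """Build keyword arguments for jax.distributed.initialize.

    The environment variable names are assumptions for this sketch.
    Returns None when no coordinator is configured (e.g. single-host runs).
    """
    if env.get("JAX_COORDINATOR_IP") is None:
        return None
    return {
        "coordinator_address": f"{env['JAX_COORDINATOR_IP']}:{env['JAX_COORDINATOR_PORT']}",
        "num_processes": int(env["NNODES"]),
        "process_id": int(env["NODE_RANK"]),
    }


def initialize_jax_for_gpu(env=os.environ):
    """Initialize the JAX distributed runtime for GPUs if a coordinator is set."""
    kwargs = gpu_distributed_kwargs(env)
    if kwargs is None:
        return False
    import jax  # imported lazily so the helper above stays usable without JAX

    jax.distributed.initialize(**kwargs)
    return True


def maybe_initialize_jax_distributed(hardware, env=os.environ):
    """Hypothetical entry point: on GPUs, always attempt initialization.

    The point of the sketch is that the GPU branch does not depend on
    whether async checkpointing (or any other feature) is enabled.
    """
    if hardware == "gpu":
        return initialize_jax_for_gpu(env)
    return False
```

In this sketch the GPU path is unconditional with respect to checkpointing settings, so a later step that needs the distributed runtime, such as async checkpointing, can assume it has already been set up.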
Actually, these code changes were already included in https://github.com/google/maxtext/commit/24dc66cc4fbad1c13b093a8cae667f883a347a8e, which has been merged. I'll close this PR.