AI-Hypercomputer / maxtext

A simple, performant and scalable Jax LLM!
Apache License 2.0

Initialize jax distributed when checkpointing is enabled #895

Open jonb377 opened 1 week ago

jonb377 commented 1 week ago

Nightly tests are failing due to jax.distributed not being initialized in the synchronous checkpointing case.

jonb377 commented 5 days ago

@gobbleturk Is there a way to detect a multihost environment? I see GPU and inference tests failing due to this change.

gobbleturk commented 5 days ago

> @gobbleturk Is there a way to detect a multihost environment? I see GPU and inference tests failing due to this change.

We have to initialize the jax distributed system before the runtime backend, so it doesn't seem this can be inferred. However, we can add a boolean option to the MaxText config, something like "is_single_host" or "should_initialize_jax_distributed_system". wdyt?
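A rough sketch of how such a flag could gate initialization. `SketchConfig` is a hypothetical stand-in for the MaxText config object, and the field names are assumptions taken from this discussion, not existing options. `jax.distributed.initialize()` is the real JAX entry point and is imported lazily so the decision can be made before any backend comes up.

```python
from dataclasses import dataclass


@dataclass
class SketchConfig:
    """Hypothetical stand-in for the MaxText config; field names are assumptions."""
    enable_checkpointing: bool = True
    # Proposed boolean from this thread; not an existing MaxText option.
    should_initialize_jax_distributed_system: bool = False


def wants_distributed_init(config: SketchConfig) -> bool:
    """Decide from config alone, without touching the JAX backend."""
    return (config.enable_checkpointing
            and config.should_initialize_jax_distributed_system)


def maybe_initialize_jax_distributed(config: SketchConfig) -> bool:
    """Initialize jax.distributed only when the config asks for it.

    Returns True if initialization was attempted, False if it was skipped.
    """
    if not wants_distributed_init(config):
        return False
    import jax  # deferred so the skip path never loads JAX

    # With no arguments, JAX tries to auto-detect the cluster environment.
    jax.distributed.initialize()
    return True
```

With the flag defaulting to false, single-host GPU and inference runs would skip initialization unless they opt in.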

gobbleturk commented 5 days ago

Do we still need to initialize the JDI even for single host? We may want to support supplying correct args in this case: process_id=0, num_processes=1, coordinator_ip=get_own_ip().
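A minimal sketch of what supplying those args for a single host might look like. `get_own_ip` here is a hypothetical helper, and the port is an arbitrary placeholder. The keyword names follow `jax.distributed.initialize`, which takes `coordinator_address` (host:port) rather than a bare IP.

```python
import socket


def get_own_ip() -> str:
    """Hypothetical helper: best-effort lookup of this host's IP address."""
    try:
        return socket.gethostbyname(socket.gethostname())
    except OSError:
        # Fall back to loopback, which is enough when there is only one process.
        return "127.0.0.1"


def single_host_distributed_args(port: int = 1234) -> dict:
    """Build kwargs for jax.distributed.initialize in the single-host case.

    The port is an arbitrary placeholder, not a MaxText default.
    """
    return {
        "coordinator_address": f"{get_own_ip()}:{port}",
        "num_processes": 1,
        "process_id": 0,
    }
```

These would be passed as `jax.distributed.initialize(**single_host_distributed_args())`, before any JAX computation runs.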

jonb377 commented 5 days ago

For the TPU tests, perhaps I should try to fix the backend being initialized - the error is in pyconfig, so I don't expect the backend to be up.

The GPU failures can be fixed by setting the coordinator address in the environment, but that seems like overkill for single-host...