Open jonb377 opened 1 week ago
@gobbleturk Is there a way to detect a multihost environment? I see GPU and inference tests failing due to this change.
We have to initialize the jax distributed system before the runtime backend, so it doesn't seem this can be inferred. However, we could add a boolean option to the MaxText config, something like "is_single_host" or "should_initialize_jax_distributed_system". wdyt?
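For illustration, a minimal sketch of how such a flag could gate initialization. The `Config` dataclass and the `maybe_initialize_jax_distributed_system` helper are hypothetical and not part of MaxText; the flag name is the one floated above, and in real use `initialize_fn` would be something like `jax.distributed.initialize`.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass
class Config:
    # Hypothetical flag, named after the suggestion in this thread.
    should_initialize_jax_distributed_system: bool = False


def maybe_initialize_jax_distributed_system(
    config: Config, initialize_fn: Callable[[], None]
) -> bool:
    """Run initialize_fn only when the config opts in.

    Returns True if initialization was attempted, False if it was skipped.
    The flag is checked before anything touches the runtime backend.
    """
    if not config.should_initialize_jax_distributed_system:
        return False
    initialize_fn()
    return True
```

With the default set to False, single-host GPU and inference runs would skip initialization unless they opt in.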
Do we still need to initialize the JDI even for single host? We may want to support supplying the correct args in this case: process_id=0, num_processes=1, coordinator_ip=get_own_ip()
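A rough sketch of what those single-host args could look like. `get_own_ip` and `single_host_init_kwargs` are hypothetical helpers, and the port is an arbitrary placeholder. The keys match the `coordinator_address`, `num_processes` and `process_id` parameters of `jax.distributed.initialize`, which takes the address as `host:port`.

```python
import socket


def get_own_ip() -> str:
    """Best-effort lookup of this host's IP, falling back to loopback."""
    try:
        return socket.gethostbyname(socket.gethostname())
    except OSError:
        # Hostname did not resolve; loopback is enough for a single host.
        return "127.0.0.1"


def single_host_init_kwargs(port: int = 1234) -> dict:
    """Build the arguments for a one-process 'distributed' setup."""
    return {
        "coordinator_address": f"{get_own_ip()}:{port}",
        "num_processes": 1,
        "process_id": 0,
    }
```

These could then be passed along as `jax.distributed.initialize(**single_host_init_kwargs())`, before anything initializes the backend.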
For the TPU tests, perhaps I should try to fix the backend being initialized. The error is in pyconfig, so I don't expect the backend to be up at that point.
The GPU failures can be fixed by setting the coordinator address in the environment, but that seems like overkill for single-host...
Nightly tests are failing due to jax.distributed not being initialized in the synchronous checkpointing case.