When using MaxText with Slurm, our jobs only see one GPU per node because jax.distributed assumes one GPU per process when used with Slurm (see the JAX docs).
This behavior can be overridden by passing local_device_ids to jax.distributed.initialize, so one way to fix this is to change initialize_jax_for_gpu as follows (max_utils.py line 243):
def initialize_jax_for_gpu():
  """Jax distributed initialize for GPUs."""
  if os.environ.get("JAX_COORDINATOR_IP") is not None:
    coordinator_ip = str(os.getenv("JAX_COORDINATOR_IP"))
    coordinator_port = str(os.getenv("JAX_COORDINATOR_PORT"))
    # Parse CUDA_VISIBLE_DEVICES (e.g. "0,1,2,3") into a list of integer device ids.
    # Assumes numeric indices; falls back to None (JAX's default) if unset or empty.
    visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", "")
    device_list = [int(d) for d in visible_devices.split(",") if d.strip()] or None
    jax.distributed.initialize(
        coordinator_address=f"{coordinator_ip}:{coordinator_port}",
        num_processes=int(os.getenv("NNODES")),
        process_id=int(os.getenv("NODE_RANK")),
        local_device_ids=device_list,
    )
    max_logging.log(f"JAX global devices: {jax.devices()}")
This could probably use more robust error handling, for example for unset NNODES or NODE_RANK values or non-numeric entries in CUDA_VISIBLE_DEVICES.
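As a rough sketch of what that error handling might look like, the environment parsing could be pulled into two small helpers. The names parse_local_device_ids and require_int are hypothetical and not part of MaxText or JAX. The sketch assumes CUDA_VISIBLE_DEVICES holds comma-separated numeric indices and rejects anything else (such as GPU UUIDs) rather than guessing.

```python
import os
from typing import List, Mapping, Optional


def parse_local_device_ids(env: Mapping[str, str] = os.environ) -> Optional[List[int]]:
  """Return CUDA_VISIBLE_DEVICES as a list of ints, or None if unset or empty."""
  raw = env.get("CUDA_VISIBLE_DEVICES", "").strip()
  if not raw:
    # None leaves the choice of local devices to jax.distributed.initialize.
    return None
  ids = []
  for item in raw.split(","):
    item = item.strip()
    if not item.isdigit():
      # Assumption: only numeric indices are supported in this sketch.
      raise ValueError(f"CUDA_VISIBLE_DEVICES entry {item!r} is not a numeric index")
    ids.append(int(item))
  return ids


def require_int(name: str, env: Mapping[str, str] = os.environ) -> int:
  """Read a required integer environment variable, failing with a clear message."""
  value = env.get(name)
  if value is None:
    raise RuntimeError(f"{name} must be set to initialize jax.distributed on GPU")
  try:
    return int(value)
  except ValueError as err:
    raise RuntimeError(f"{name} must be an integer, got {value!r}") from err
```

Inside initialize_jax_for_gpu, these could stand in for the inline parsing, e.g. num_processes=require_int("NNODES") and local_device_ids=parse_local_device_ids().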