AI-Hypercomputer / maxtext

A simple, performant and scalable Jax LLM!

Cannot see multiple GPUs when using Slurm (with proposed fix) #865

Status: Open

gabeweisz commented 2 months ago

When using MaxText with Slurm, our jobs only see one GPU per node, because `jax.distributed` assumes one GPU per process when running under Slurm (see the JAX docs).
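For illustration, the symptom can be reproduced with a bare `initialize()` call inside a Slurm allocation. This is a sketch, and the device count assumes a node with four GPUs:

```python
import jax

# Inside a Slurm job, JAX auto-detects the cluster from the SLURM_*
# environment variables and, on GPU, defaults to one local device per process.
jax.distributed.initialize()

# Prints 1 even though nvidia-smi shows 4 GPUs on the node.
print(jax.local_device_count())
```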

This behavior can be overridden by passing `local_device_ids` to `jax.distributed.initialize`, so one way to fix this is to change `initialize_jax_for_gpu` as follows (`max_utils.py`, line 243):

```python
def initialize_jax_for_gpu():
  """Jax distributed initialize for GPUs."""
  if os.environ.get("JAX_COORDINATOR_IP") is not None:
    coordinator_ip = str(os.getenv("JAX_COORDINATOR_IP"))
    coordinator_port = str(os.getenv("JAX_COORDINATOR_PORT"))
    # Slurm exports CUDA_VISIBLE_DEVICES as a comma-separated list, e.g. "0,1,2,3".
    # Parse it into a list of ints; fall back to None (JAX's default) when unset.
    visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
    device_list = [int(d) for d in visible_devices.split(",")] if visible_devices else None
    jax.distributed.initialize(
        coordinator_address=f"{coordinator_ip}:{coordinator_port}",
        num_processes=int(os.getenv("NNODES")),
        process_id=int(os.getenv("NODE_RANK")),
        local_device_ids=device_list,
    )
    max_logging.log(f"JAX global devices: {jax.devices()}")
```
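If the change works as intended, each process should see every GPU on its node rather than just one. A quick sanity check; the device counts here are hypothetical, assuming a 2-node allocation with 4 GPUs per node:

```python
import jax

# After initialize_jax_for_gpu() has run on every process:
print(jax.local_device_count())  # 4 - all GPUs on this node, not just one
print(jax.device_count())        # 8 - GPUs across the whole Slurm job
```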

This can probably use more robust error handling.
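A minimal sketch of what that error handling might look like, assuming the same environment variables MaxText already reads; the `_require_env` helper is hypothetical, not part of MaxText:

```python
import os
import jax


def _require_env(name: str) -> str:
  """Fetch a required environment variable or fail with a clear message."""
  value = os.getenv(name)
  if value is None:
    raise RuntimeError(f"Expected environment variable {name} to be set for multi-node GPU runs")
  return value


def initialize_jax_for_gpu():
  """Jax distributed initialize for GPUs, with explicit env-var validation."""
  if os.environ.get("JAX_COORDINATOR_IP") is None:
    return
  coordinator_ip = _require_env("JAX_COORDINATOR_IP")
  coordinator_port = _require_env("JAX_COORDINATOR_PORT")
  # Parse "0,1,2,3" into [0, 1, 2, 3]; fall back to JAX's default when unset.
  visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
  device_list = [int(d) for d in visible_devices.split(",")] if visible_devices else None
  jax.distributed.initialize(
      coordinator_address=f"{coordinator_ip}:{coordinator_port}",
      num_processes=int(_require_env("NNODES")),
      process_id=int(_require_env("NODE_RANK")),
      local_device_ids=device_list,
  )
```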