jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Multi-process GPU jobs fail on Slurm #23452

Open ChenAo-Phys opened 1 month ago

ChenAo-Phys commented 1 month ago

Description

I'm submitting multi-process jobs on Slurm. The job script is:

#!/bin/bash
#SBATCH -p gpu
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=2
#SBATCH --gpus-per-task=1
#SBATCH --time=0-00:10:00
#SBATCH --mem-per-cpu=32G
#SBATCH --constraint=a100-40gb

module load cuda/12.3.2

conda activate main

srun --cpu-bind=socket python $1

I tested it with a simple Python script:

import jax
jax.distributed.initialize()
print(jax.devices())

But it can't see the devices and raises the following error:

2024-09-05 07:57:34.555804: W external/xla/xla/service/platform_util.cc:199] unable to create StreamExecutor for CUDA:1: failed initializing StreamExecutor for CUDA device ordinal 1: INTERNAL: Failed call to cuDeviceGet: CUDA_ERROR_INVALID_DEVICE: invalid device ordinal
Traceback (most recent call last):
  File "/mnt/home/achen1/miniconda3/envs/main/lib/python3.9/site-packages/jax/_src/xla_bridge.py", line 879, in backends
    backend = _init_backend(platform)
  File "/mnt/home/achen1/miniconda3/envs/main/lib/python3.9/site-packages/jax/_src/xla_bridge.py", line 970, in _init_backend
    backend = registration.factory()
  File "/mnt/home/achen1/miniconda3/envs/main/lib/python3.9/site-packages/jax/_src/xla_bridge.py", line 676, in factory
    return xla_client.make_c_api_client(
  File "/mnt/home/achen1/miniconda3/envs/main/lib/python3.9/site-packages/jaxlib/xla_client.py", line 200, in make_c_api_client
    return _xla.get_c_api_client(plugin_name, options, distributed_client)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: no supported devices found for platform CUDA

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/mnt/home/achen1/Transformer/test/test.py", line 3, in <module>
    print(jax.devices())
  File "/mnt/home/achen1/miniconda3/envs/main/lib/python3.9/site-packages/jax/_src/xla_bridge.py", line 1082, in devices
    return get_backend(backend).devices()
  File "/mnt/home/achen1/miniconda3/envs/main/lib/python3.9/site-packages/jax/_src/xla_bridge.py", line 1016, in get_backend
    return _get_backend_uncached(platform)
  File "/mnt/home/achen1/miniconda3/envs/main/lib/python3.9/site-packages/jax/_src/xla_bridge.py", line 995, in _get_backend_uncached
    bs = backends()
  File "/mnt/home/achen1/miniconda3/envs/main/lib/python3.9/site-packages/jax/_src/xla_bridge.py", line 895, in backends
    raise RuntimeError(err_msg)
RuntimeError: Unable to initialize backend 'cuda': INTERNAL: no supported devices found for platform CUDA (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)
Traceback (most recent call last):
  File "/mnt/home/achen1/miniconda3/envs/main/lib/python3.9/site-packages/jax/_src/xla_bridge.py", line 879, in backends
    backend = _init_backend(platform)
  File "/mnt/home/achen1/miniconda3/envs/main/lib/python3.9/site-packages/jax/_src/xla_bridge.py", line 970, in _init_backend
    backend = registration.factory()
  File "/mnt/home/achen1/miniconda3/envs/main/lib/python3.9/site-packages/jax/_src/xla_bridge.py", line 676, in factory
    return xla_client.make_c_api_client(
  File "/mnt/home/achen1/miniconda3/envs/main/lib/python3.9/site-packages/jaxlib/xla_client.py", line 200, in make_c_api_client
    return _xla.get_c_api_client(plugin_name, options, distributed_client)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Getting local topologies failed: Error 1: GetKeyValue() timed out with key: cuda:local_topology/cuda/1 and duration: 2m

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/mnt/home/achen1/Transformer/test/test.py", line 3, in <module>
    print(jax.devices())
  File "/mnt/home/achen1/miniconda3/envs/main/lib/python3.9/site-packages/jax/_src/xla_bridge.py", line 1082, in devices
    return get_backend(backend).devices()
  File "/mnt/home/achen1/miniconda3/envs/main/lib/python3.9/site-packages/jax/_src/xla_bridge.py", line 1016, in get_backend
    return _get_backend_uncached(platform)
  File "/mnt/home/achen1/miniconda3/envs/main/lib/python3.9/site-packages/jax/_src/xla_bridge.py", line 995, in _get_backend_uncached
    bs = backends()
  File "/mnt/home/achen1/miniconda3/envs/main/lib/python3.9/site-packages/jax/_src/xla_bridge.py", line 895, in backends
    raise RuntimeError(err_msg)
RuntimeError: Unable to initialize backend 'cuda': INTERNAL: Getting local topologies failed: Error 1: GetKeyValue() timed out with key: cuda:local_topology/cuda/1 and duration: 2m (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)
srun: error: workergpu001: tasks 0-1: Exited with exit code 1

I have tested the single-process code and it works well, so the problem seems to be with the multi-process setup. I have also tested on several different clusters: the multi-process program works on some and fails on others. For example, it fails on the largest Jülich cluster in Germany.

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.30
jaxlib: 0.4.30
numpy:  1.26.3
python: 3.9.18 (main, Sep 11 2023, 13:41:44) [GCC 11.2.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='workergpu001', release='6.1.97.1.fi', version='#1 SMP Tue Jul 9 06:21:23 EDT 2024', machine='x86_64')

superbobry commented 1 month ago

I'm not familiar with Slurm, but it looks like this is an environment issue. Did you follow the installation instructions in https://jax.readthedocs.io/en/latest/installation.html?

ChenAo-Phys commented 1 month ago

I'm not familiar with Slurm, but it looks like this is an environment issue. Did you follow the installation instructions in https://jax.readthedocs.io/en/latest/installation.html?

I don't think it's an installation issue, because the code runs fine in a single process if I don't call jax.distributed.initialize().
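
For example, a minimal single-process check (the same test script without the distributed initialization) runs without problems:

import jax

# Same test, but without jax.distributed.initialize(): works fine on one GPU.
print(jax.devices())  # [cuda(id=0)]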

PhilipVinc commented 1 month ago

How many GPUs do the nodes have here? More than 2? The error

INTERNAL: Failed call to cuDeviceGet: CUDA_ERROR_INVALID_DEVICE: invalid device ordinal

suggests that CUDA is trying to use the wrong device, possibly one that is not exposed.

Possibly it's because the local device JAX initializes on every rank is based on the local rank. If Slurm assigned you GPUs 0 and 1, then all is good. But if it assigns 2 and 3, initialization will fail because JAX's SlurmCluster assumes that the devices to be used start from 0.
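
To see what each task is actually given, one could print the relevant environment variables at the top of the script (a quick diagnostic sketch; SLURM_PROCID, SLURM_LOCALID and CUDA_VISIBLE_DEVICES are the variables Slurm normally sets):

import os

# Print what Slurm and CUDA expose to this task, before touching JAX.
print(
    "SLURM_PROCID =", os.environ.get("SLURM_PROCID"),
    "| SLURM_LOCALID =", os.environ.get("SLURM_LOCALID"),
    "| CUDA_VISIBLE_DEVICES =", os.environ.get("CUDA_VISIBLE_DEVICES"),
)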

PhilipVinc commented 1 month ago

The issue stems from

https://github.com/google/jax/blob/8feab682097b0949d0504ec0ee73f4637aeb1f57/jax/_src/clusters/slurm_cluster.py#L66

being called from

https://github.com/google/jax/blob/8feab682097b0949d0504ec0ee73f4637aeb1f57/jax/_src/clusters/cluster.py#L90

JAX should instead use the local process ID (which Slurm usually sets) to index into the CUDA-visible devices.
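
A workaround sketch along those lines (not a fix in JAX itself, just something to try, assuming Slurm exports SLURM_LOCALID and all of the node's GPUs are visible to every task):

import os
import jax

# Workaround sketch: choose the local device from Slurm's per-node task index
# (SLURM_LOCALID) instead of JAX's default assumption that devices start at 0.
local_rank = int(os.environ.get("SLURM_LOCALID", "0"))
jax.distributed.initialize(local_device_ids=[local_rank])
print(jax.devices())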

ChenAo-Phys commented 1 month ago

How many GPUs do the nodes have here? More than 2? The error

INTERNAL: Failed call to cuDeviceGet: CUDA_ERROR_INVALID_DEVICE: invalid device ordinal

suggests that CUDA is trying to use the wrong device, possibly one that is not exposed.

Possibly it's because the local device JAX initializes on every rank is based on the local rank. If Slurm assigned you GPUs 0 and 1, then all is good. But if it assigns 2 and 3, initialization will fail because JAX's SlurmCluster assumes that the devices to be used start from 0.

Thanks Filippo! I think this is the problem. But it's still a bit weird, because there is still an error when I use all GPUs in a node, whereas it runs when I call jax.distributed.initialize(local_device_ids=[0]). It seems that every process thinks its local_device_ids is 0.

I did some further tests with the following code

import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding
jax.distributed.initialize(local_device_ids=[0])

global_mesh = jax.sharding.Mesh(jax.devices(), 'x')
pspecs = jax.sharding.PartitionSpec('x')
replicate_pspecs = jax.sharding.PartitionSpec()
sharding = NamedSharding(global_mesh, pspecs)
replicate_sharding = NamedSharding(global_mesh, replicate_pspecs)

@jax.jit
def f():
    out = jnp.arange(8)
    return jax.lax.with_sharding_constraint(out, replicate_sharding)

y = f()
print(jnp.sum(y))

This works perfectly. But when I change it to return jax.lax.with_sharding_constraint(out, sharding), I get the error

srun: error: workergpu036: task 0: Segmentation fault (core dumped)
srun: error: workergpu036: task 1: Segmentation fault (core dumped)

(workergpu036 is the node name.) It seems the GPUs can't communicate with each other.

Furthermore, when I print jax.devices(), there is no error and I get [cuda(id=0), cuda(id=1)].

I guess the problem is that the local_device_ids somehow change from 0 to other numbers after jax.distributed.initialize, so the devices can't access each other any more. But I'm not familiar with how JAX works internally. What do you think, @PhilipVinc? Do you have any idea how to solve this issue for now?

ChenAo-Phys commented 1 month ago

I solved this problem after consulting the HPC support at the Flatiron Institute. It was due to some mistakes I made when submitting the jobs. Here I post the answer from the HPC support for other users' reference.

But a bit of clarification based on your allocation: because you're using "--gpus-per-task" and not explicitly changing "--gpu-bind", each task (that is, each of the 2 processes launched by srun) will only have access to 1 GPU (which will indeed show up as id 0). If you want processes to be able to access GPUs assigned to other tasks, you need to use something like "--gpu-bind=none" or "--gpus" instead of "--gpus-per-task".

jax.distributed.initialize() works nicely after adding #SBATCH --gpu-bind=none to my job script.
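
For reference, the full working job script is then the original one with a single added line (only the --gpu-bind line is new):

#!/bin/bash
#SBATCH -p gpu
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=2
#SBATCH --gpus-per-task=1
#SBATCH --gpu-bind=none
#SBATCH --time=0-00:10:00
#SBATCH --mem-per-cpu=32G
#SBATCH --constraint=a100-40gb

module load cuda/12.3.2

conda activate main

srun --cpu-bind=socket python $1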