ChenAo-Phys opened 1 month ago
I'm not familiar with Slurm, but it looks like this is an environment issue. Did you follow the installation instructions in https://jax.readthedocs.io/en/latest/installation.html?
I don't think it's an installation issue, because the code runs fine in a single process if I don't call `jax.distributed.initialize()`.
How many GPUs do the nodes have here? More than 2? The error
`INTERNAL: Failed call to cuDeviceGet: CUDA_ERROR_INVALID_DEVICE: invalid device ordinal`
suggests that CUDA is trying to use the wrong device, possibly one that is not exposed.
Possibly it's because the local device JAX initializes on every rank is based on the local rank. If SLURM assigned you GPUs 0,1 then all is good. But if it assigns 2,3, initialization will fail because JAX's SlurmCluster assumes that the devices to be used start from 0.
The issue stems from how `SlurmCluster` computes the local device ids during `jax.distributed.initialize`. JAX should instead use the local process id to index into the CUDA visible devices; Slurm usually sets it.
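A minimal sketch of that suggestion, assuming Slurm's usual environment variables (this is not what JAX currently does internally):

```python
import os
import jax

# SLURM_LOCALID is the task's index within its node, set by srun.
# Use it to pick the local device instead of assuming ordinals start at 0.
local_rank = int(os.environ.get("SLURM_LOCALID", "0"))
jax.distributed.initialize(local_device_ids=[local_rank])
```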
Thanks Filippo! I think this is the problem. But it's still a bit odd, because there is still an error when I use all GPUs in a node. Instead, it runs when I call `jax.distributed.initialize(local_device_ids=[0])`. It seems that all machines think their `local_device_ids` are 0.
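One quick way to check what each task actually sees (a minimal diagnostic sketch, assuming Slurm's usual environment variables):

```python
import os

# SLURM_PROCID / SLURM_LOCALID are set by srun for each task;
# CUDA_VISIBLE_DEVICES depends on how GPUs were bound at allocation time.
print(
    "proc", os.environ.get("SLURM_PROCID"),
    "local", os.environ.get("SLURM_LOCALID"),
    "visible", os.environ.get("CUDA_VISIBLE_DEVICES"),
)
```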
I did some further tests with the following code:

```python
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding

jax.distributed.initialize(local_device_ids=[0])

global_mesh = jax.sharding.Mesh(jax.devices(), 'x')
pspecs = jax.sharding.PartitionSpec('x')
replicate_pspecs = jax.sharding.PartitionSpec()
sharding = NamedSharding(global_mesh, pspecs)
replicate_sharding = NamedSharding(global_mesh, replicate_pspecs)

@jax.jit
def f():
    out = jnp.arange(8)
    return jax.lax.with_sharding_constraint(out, replicate_sharding)

y = f()
print(jnp.sum(y))
```
This works perfectly. But when I change it to `return jax.lax.with_sharding_constraint(out, sharding)`, I get the error
```
srun: error: workergpu036: task 0: Segmentation fault (core dumped)
srun: error: workergpu036: task 1: Segmentation fault (core dumped)
```
(workergpu036 is the node name.) It seems the GPUs can't communicate with each other.
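For clarity, the failing variant differs from the script above only in the returned constraint (a hedged reconstruction of the one-line change described above):

```python
@jax.jit
def f():
    out = jnp.arange(8)
    # Sharding across 'x' (instead of replicating) requires the two GPUs to
    # communicate, which is what crashes here.
    return jax.lax.with_sharding_constraint(out, sharding)
```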
Furthermore, when I print `jax.devices()`, there is no error and I get `[cuda(id=0), cuda(id=1)]`.
I guess the problem is that the `local_device_ids` somehow change from 0 to other numbers after `jax.distributed.initialize`, so the devices can't access each other any more. But I'm not familiar with how JAX works exactly. What do you think @PhilipVinc ?
Do you have any idea how to solve this issue for now?
I solved this problem after consulting the HPC support of the Flatiron Institute. It was due to some stupid mistakes I made when submitting jobs. Here I post the answer from the HPC support for other users' reference.
But a bit of clarification based on your allocation: because you're using `--gpus-per-task` and not explicitly changing `--gpu-bind`, each task (that is, each of the 2 processes launched by srun) will only have access to 1 GPU (which will indeed show up as id 0). If you want processes to be able to access GPUs assigned to other tasks, you need to use something like `--gpu-bind=none` or `--gpus` instead of `--gpus-per-task`.
`jax.distributed.initialize()` works nicely after adding `#SBATCH --gpu-bind=none` to my job script.
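For reference, a sketch of the corrected allocation (node and task counts are placeholders):

```bash
#!/bin/bash
#SBATCH --nodes=1
#SBATCH --ntasks=2        # one process per GPU
#SBATCH --gpus=2          # allocate GPUs to the whole job, not per task
#SBATCH --gpu-bind=none   # let every process see all allocated GPUs

srun python main.py       # main.py is a placeholder for your script
```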
Description
I'm submitting multi-process jobs on slurm. The job script is along these lines (a sketch with placeholder resource counts):
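```bash
#!/bin/bash
#SBATCH --nodes=1
#SBATCH --ntasks=2
#SBATCH --gpus-per-task=1   # this per-task binding turned out to be the culprit

srun python test.py
```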
I test with a simple python code, along these lines (a minimal sketch):
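```python
import jax

# Auto-detects the Slurm environment (coordinator address, process count, ...).
jax.distributed.initialize()
print(jax.process_index(), jax.devices())
```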
But it can't see the devices and raises the `CUDA_ERROR_INVALID_DEVICE: invalid device ordinal` error quoted above.
I have tested the single-process code and it works well, so it should be a problem with the multi-process setup. I also tested on many different clusters: the multi-process program works on some clusters and fails on others. For example, it fails on the largest Juelich cluster in Germany.
System info (python version, jaxlib version, accelerator, etc.)
```
jax:    0.4.30
jaxlib: 0.4.30
numpy:  1.26.3
python: 3.9.18 (main, Sep 11 2023, 13:41:44) [GCC 11.2.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='workergpu001', release='6.1.97.1.fi', version='#1 SMP Tue Jul 9 06:21:23 EDT 2024', machine='x86_64')
```