google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Slurm initialization only supports one device per host #16788

Open Findus23 opened 1 year ago

Findus23 commented 1 year ago

I have access to an HPC cluster with multiple nodes that each have two GPUs. As I want to do computations that need access to the memory of many GPUs, I was looking into the multi-host setup.

The HPC cluster uses Slurm, so using that for initialisation seems easiest. I created a simple test script:

import os

import jax

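# Auto-detect coordinator address, process count and local devices from the Slurm environment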
jax.distributed.initialize()
jax.config.update("jax_enable_x64", True)

print(os.environ.get("CUDA_VISIBLE_DEVICES"))
print(jax.devices(), jax.local_devices())

And a slurm job script:

#!/bin/bash
#SBATCH --mail-type=ALL
#SBATCH --nodes=2
#SBATCH --tasks-per-node=1
#SBATCH --exclusive
#SBATCH --job-name=jax-multinode-test
#SBATCH --gpus=4

source $DATA/venv-jax/bin/activate
cd ~/jax-testing/
srun python distributed.py

I would then expect jax.devices() to be a list of 4 GPU devices. But the output is:

0,1
[gpu(id=0), gpu(id=1)] [gpu(id=0)]
0,1
[gpu(id=0), gpu(id=1)] [gpu(id=1)]

So each of the two hosts only contributes one local device.

If I comment out jax.distributed.initialize() and therefore let both hosts do their own thing, both hosts detect both GPUs properly:

0,1
[gpu(id=0), gpu(id=1)] [gpu(id=0), gpu(id=1)]
0,1
[gpu(id=0), gpu(id=1)] [gpu(id=0), gpu(id=1)]

Technically, this is documented here in the function arguments: https://github.com/google/jax/blob/f94104f71a041def61ea5b22676bbbecbfbe0a9b/jax/_src/distributed.py#L147-L149

But I am not sure what the reason for this limitation is, as in my experience having two GPUs per host is quite common.

And patching https://github.com/google/jax/blob/f94104f71a041def61ea5b22676bbbecbfbe0a9b/jax/_src/clusters/slurm_cluster.py#L60-L62 to instead return None makes everything work the way one would expect:

0,1
[gpu(id=0), gpu(id=1), gpu(id=2), gpu(id=3)] [gpu(id=0), gpu(id=1)]
0,1
[gpu(id=0), gpu(id=1), gpu(id=2), gpu(id=3)] [gpu(id=2), gpu(id=3)]

(see https://github.com/google/jax/discussions/16789 for another issue I am having with doing this)

vwxyzjn commented 11 months ago

Same thing. I tried multiple jax versions, including 0.4.14.

#!/bin/bash
#SBATCH --job-name=trl
#SBATCH --partition=production-cluster
#SBATCH --gpus-per-task=8
#SBATCH --cpus-per-gpu=10
#SBATCH --ntasks=2 # 1 nodes
#SBATCH --output=slurm/logs/%x_%j.out

srun python jax_test.py

jax_test.py:

import jax
import os
jax.distributed.initialize()

print(os.environ.get("CUDA_VISIBLE_DEVICES"))

print("jax.device_count()", jax.device_count())
print("jax.local_device_count()", jax.local_device_count())
print("jax.devices()", jax.devices())
0,1,2,3,4,5,6,7
jax.device_count() 2
jax.local_device_count() 1
jax.devices() [StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=1, process_index=1, slice_index=1)]
0,1,2,3,4,5,6,7
jax.device_count() 2
jax.local_device_count() 1
jax.devices() [StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=1, process_index=1, slice_index=1)]

If I set

  @classmethod
  def get_local_process_id(cls) -> Optional[int]:
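    # Patched so that each process picks up all of its visible GPUs instead of just one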
    return None

0,1,2,3,4,5,6,7
jax.device_count() 16
jax.local_device_count() 8
jax.devices() [StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=1, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=2, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=3, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=4, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=5, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=6, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=7, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=8, process_index=1, slice_index=1), StreamExecutorGpuDevice(id=9, process_index=1, slice_index=1), StreamExecutorGpuDevice(id=10, process_index=1, slice_index=1), StreamExecutorGpuDevice(id=11, process_index=1, slice_index=1), StreamExecutorGpuDevice(id=12, process_index=1, slice_index=1), StreamExecutorGpuDevice(id=13, process_index=1, slice_index=1), StreamExecutorGpuDevice(id=14, process_index=1, slice_index=1), StreamExecutorGpuDevice(id=15, process_index=1, slice_index=1)]
0,1,2,3,4,5,6,7
jax.device_count() 16
jax.local_device_count() 8
jax.devices() [StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=1, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=2, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=3, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=4, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=5, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=6, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=7, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=8, process_index=1, slice_index=1), StreamExecutorGpuDevice(id=9, process_index=1, slice_index=1), StreamExecutorGpuDevice(id=10, process_index=1, slice_index=1), StreamExecutorGpuDevice(id=11, process_index=1, slice_index=1), StreamExecutorGpuDevice(id=12, process_index=1, slice_index=1), StreamExecutorGpuDevice(id=13, process_index=1, slice_index=1), StreamExecutorGpuDevice(id=14, process_index=1, slice_index=1), StreamExecutorGpuDevice(id=15, process_index=1, slice_index=1)]

nouiz commented 11 months ago

There are two ways of doing multi-GPU on one node: one process that handles all the GPUs, or a separate process for each GPU.

When doing multi-node, we end up running multiple processes. One process per GPU is faster in some cases, so this is the recommended way if your script supports it.

The current default assumes that if you use Slurm, you will tell Slurm to start one process per GPU. But this isn't what you are doing. Maybe the code can be updated to detect that automatically, but that isn't the case right now.

So I would suggest using initialize(), but modifying your Slurm job to have one process for each GPU: --tasks-per-node=2 in the original question.
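To make the two setups concrete, here is a minimal sketch of what each process would run. The SBATCH values in the comments are assumptions based on the job script from the original question (two nodes with two GPUs each), not tested settings:

import jax

# Recommended: one process per GPU. Launch with, e.g.,
#   #SBATCH --nodes=2
#   #SBATCH --tasks-per-node=2
#   #SBATCH --gpus=4
# and let the Slurm auto-configuration assign one device to each process.
jax.distributed.initialize()

# Alternative: one process per node that manages both GPUs. Keep
# --tasks-per-node=1 and tell JAX explicitly which local devices this
# process owns:
# jax.distributed.initialize(local_device_ids=range(2))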

TODO: Let's keep this issue open to verify whether we can update initialize() to automatically detect how many GPUs should be used by the process.

vwxyzjn commented 11 months ago

Ok, so I figured out a quick fix. Just add jax.distributed.initialize(local_device_ids=range(8)). Works like a charm.

import jax
import os
jax.distributed.initialize(local_device_ids=range(8))
print("jax.__version__", jax.__version__)

print(os.environ.get("CUDA_VISIBLE_DEVICES"))

print("jax.device_count()", jax.device_count())
print("jax.local_device_count()", jax.local_device_count())
print("jax.devices()", jax.devices())
jax.__version__ 0.4.13
0,1,2,3,4,5,6,7
jax.device_count() 16
jax.local_device_count() 8
jax.devices() [gpu(id=0), gpu(id=1), gpu(id=2), gpu(id=3), gpu(id=4), gpu(id=5), gpu(id=6), gpu(id=7), gpu(id=8), gpu(id=9), gpu(id=10), gpu(id=11), gpu(id=12), gpu(id=13), gpu(id=14), gpu(id=15)]
jax.__version__ 0.4.13
0,1,2,3,4,5,6,7
jax.device_count() 16
jax.local_device_count() 8
jax.devices() [gpu(id=0), gpu(id=1), gpu(id=2), gpu(id=3), gpu(id=4), gpu(id=5), gpu(id=6), gpu(id=7), gpu(id=8), gpu(id=9), gpu(id=10), gpu(id=11), gpu(id=12), gpu(id=13), gpu(id=14), gpu(id=15)]

Findus23 commented 11 months ago

@nouiz You are right, running multiple processes of jax on each host (with each one responsible for just one GPU) is one way to handle this. And indeed with --tasks-per-node=2 jax correctly initializes the two processes to each handle one GPU.

@vwxyzjn Thank you for that idea. Using local_device_ids=range(2) does the exact same thing as the patch I mentioned above, but is of course a lot more elegant (and doesn't require modifying the jax source code).

nouiz commented 11 months ago

@Findus23 do you have a patch to JAX to have it handle this correctly? If so, where is it? It would be good to update JAX to handle this.

Findus23 commented 11 months ago

@nouiz Sorry, by "patch" I just mean editing https://github.com/google/jax/blob/f94104f71a041def61ea5b22676bbbecbfbe0a9b/jax/_src/clusters/slurm_cluster.py#L60-L62 to instead return None.

But of course that isn't correct in the general case.

I just compared the environment variables in a --tasks-per-node=1 and a --tasks-per-node=2 run and saw that there is a SLURM_NTASKS_PER_NODE variable. So the Slurm detection could maybe check whether that is set to 1 and, in that case, pick all GPUs (i.e. use range(num_gpus) as the local device IDs if nothing else is specified), roughly as sketched below.
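For illustration, a rough sketch of what such a heuristic could look like (the environment-variable handling and the fallback are my assumptions, not a proposed patch to the JAX source):

import os

import jax

# Hypothetical heuristic: if Slurm started only one task on this node,
# let this process claim every GPU it can see; otherwise keep the
# default Slurm behaviour of one device per process.
ntasks_per_node = int(os.environ.get("SLURM_NTASKS_PER_NODE", "1"))
visible = os.environ.get("CUDA_VISIBLE_DEVICES", "")

if ntasks_per_node == 1 and visible:
    local_device_ids = [int(i) for i in visible.split(",")]
else:
    local_device_ids = None  # fall back to the default behaviour

jax.distributed.initialize(local_device_ids=local_device_ids)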

But then again, maybe that is a bit too much implicit magic, and it would be better to just update the documentation to say that both one-GPU-per-process and one-process-per-host are possible, with a short example of how to specify each.

tulvgengenr commented 3 months ago

I've run into this problem as well. But I would like to ask: when running distributed computations with jax, which is more elegant, assigning one process per node that manages all the local devices, or assigning multiple processes per node, each managing one local device?

alexlyttle commented 2 weeks ago

Quoting @vwxyzjn:

Ok, so I figured out a quick fix. Just add jax.distributed.initialize(local_device_ids=range(8)). Works like a charm.

import jax
import os
jax.distributed.initialize(local_device_ids=range(8))
print("jax.__version__", jax.__version__)

print(os.environ.get("CUDA_VISIBLE_DEVICES"))

print("jax.device_count()", jax.device_count())
print("jax.local_device_count()", jax.local_device_count())
print("jax.devices()", jax.devices())

Thanks for the solutions, I was having the same issue. To automate this, why not do the following? It should work both when there is only one device per host and when there are multiple GPUs per task.

import os
import jax

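# CUDA_VISIBLE_DEVICES lists the GPUs that Slurm assigned to this task, e.g. "0,1"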
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
local_device_ids = [int(i) for i in cuda_visible_devices.split(",")]
print(local_device_ids)

jax.distributed.initialize(local_device_ids=local_device_ids)
print("jax.__version__", jax.__version__)
print(jax.device_count())
print(jax.local_device_count())

For a job with the following options:

#SBATCH --nodes 2
#SBATCH --ntasks-per-node 1
#SBATCH --gpus-per-task 2

the output is:

[0, 1]
jax.__version__ 0.4.30
4
2
[0, 1]
jax.__version__ 0.4.30
4
2