My guess is that this is an issue with the SLURM config or how you call it. What is the SLURM command that you use?
nvidia-smi shows all the visible GPUs; if I remember correctly, that doesn't mean you actually have access to them.
Did you see this documentation: https://jax.readthedocs.io/en/latest/multi_process.html ?
My SLURM command was:
srun --pty --gres=gpu:1 --cpus-per-gpu=4 --mem-per-cpu=10g --time=0-01:00 /bin/bash
I did miss that document, I'll give it a pass now, cheers!
Ah, that document is about pmap; I'm not trying to distribute the workload in that way. I'm trying to have a learner node with a GPU, while the actor nodes do not need the GPU.
If you request only 1 node with 1 GPU and many CPUs, you can create a bash script that dispatches like this:
CUDA_VISIBLE_DEVICES= python actor.py &  # As many times as needed
python learner.py
The first line hides the GPU from the actor process. This works for all software, not just JAX.
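For concreteness, here is a fuller sketch of such a dispatch script run inside the single-node allocation; actor.py, learner.py, and the --actor_id flag are placeholder names, not part of any particular codebase:

#!/bin/bash
# Hypothetical dispatch script for a single SLURM allocation holding 1 GPU.
# Actors get an empty CUDA_VISIBLE_DEVICES, so JAX falls back to CPU;
# the learner keeps the allocated GPU to itself.
for i in 1 2 3 4; do
  CUDA_VISIBLE_DEVICES= python actor.py --actor_id="$i" &
done
python learner.py
wait  # keep the job alive until all actors exit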
Thanks for the reply.
In the example script that I provided, I'm dispatching two PythonProcess configurations that use these environment variable settings:
"test_cpu": PythonProcess(
env={
"CUDA_VISIBLE_DEVICES": "",
"JAX_PLATFORM_NAME": "cpu",
}
),
"test_gpu": PythonProcess(
env={
"CUDA_VISIBLE_DEVICES": "0",
"XLA_PYTHON_CLIENT_MEM_FRACTION": ".2",
"XLA_PYTHON_CLIENT_PREALLOCATE": "false",
"JAX_PLATFORM_NAME": "gpu",
}
),
However, if I try and spawn two processes that can both see device 0, only the first process is seeing the device. Does that make sense?
> However, if I try and spawn two processes that can both see device 0, only the first process is seeing the device. Does that make sense?
I think the issue isn't with JAX. It is probably related to your scheduler.
I suppose this happens if you spawn two test_gpu processes and make sure they are on the same node.
Can you print this in both processes? print(os.environ.get('CUDA_VISIBLE_DEVICES'))
I do not know launchpad, so I can't help much here. Why do you use it? Can you give me the full output of nvidia-smi? It is possible that the GPU is configured to be usable by only one process; clusters are sometimes configured that way, and it could explain your issue.
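If it helps, one way to check for that configuration (a general nvidia-smi query, nothing launchpad-specific) is to ask for the GPU's compute mode:

nvidia-smi --query-gpu=compute_mode --format=csv
# "Default" means multiple processes may share the GPU;
# "Exclusive_Process" limits it to a single process at a time.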
$ python launchpad_gpu_test.py
Local
0
Launchpad
[test_cpu/0] I1220 18:34:44.112143 23450219082752 courier_utils.py:120] Binding: run
[test_gpu/0] I1220 18:34:44.113588 22711702426624 courier_utils.py:120] Binding: run
[test_gpu/1] I1220 18:34:44.114040 23041964385280 courier_utils.py:120] Binding: run
[test_cpu/0] I1220 18:34:44.114264 23450219082752 courier_utils.py:120] Binding: test
[test_gpu/0] I1220 18:34:44.114415 22711702426624 courier_utils.py:120] Binding: test
[test_gpu/1] I1220 18:34:44.114561 23041964385280 courier_utils.py:120] Binding: test
[test_gpu/1] 0
[test_gpu/0] 0
[test_cpu/0]
Works as expected. GPU nodes print "0", and the local driver program also prints "0". The restricted node does not see the device.
However, if I add print("Devices: ", jax.devices()) right after each print(os.environ.get('CUDA_VISIBLE_DEVICES')), all of the processes correctly print their environment variables but all use CPUs.
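For reference, the check inside each node then amounts to something like this sketch (the function name is a placeholder):

import os
import jax

def report_devices():
    # Print the GPU restriction handed to this process, then the backend JAX initialized.
    print(os.environ.get("CUDA_VISIBLE_DEVICES"))
    print("Devices: ", jax.devices())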
If I prevent the driver program from binding to the GPU, the processes' outputs are:
...
[test_gpu/1] 0
[test_gpu/1] Devices: [CpuDevice(id=0)]
[test_gpu/0] 0
[test_gpu/0] Devices: [StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]
[test_cpu/0]
[test_cpu/0] Devices: [CpuDevice(id=0)]
I'll submit a ticket to the supercomputer team to see if the suggested limitation is in place.
Thanks for being so interactive, Frédéric. I was surprised to see your name pop up here, but I remember you from our rare interactions back at Pavillon André-Aisenstadt in, what, 2016? :)
I've heard back from the supercomputer team and the devices are indeed set to "exclusive process" mode, which is undoubtedly the issue I'm encountering. Thanks again for all the help.
Description
I am having difficulty getting JAX to share a GPU backend across subprocesses in a SLURM job. The observed behavior is that whichever process binds to the GPU first works correctly, and all other processes cannot acquire a GPU backend.
Below is a minimal example. In it I've experimented with the main process and varying numbers of subprocesses attempting to bind to the single GPU. I've tried various permutations of XLA flags for memory preallocation and visible devices.
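As a stripped-down illustration of the pattern (plain multiprocessing rather than the actual launchpad script; the worker function and process names here are illustrative only): a parent spawns two GPU-targeted children and one CPU-only child, and each prints its CUDA_VISIBLE_DEVICES followed by jax.devices(), the same check used in the logs above.

import multiprocessing as mp
import os

def worker(name, env):
    # Apply the per-process environment before importing JAX so XLA sees it.
    os.environ.update(env)
    import jax
    print(name, os.environ.get("CUDA_VISIBLE_DEVICES"), jax.devices())

if __name__ == "__main__":
    gpu_env = {
        "CUDA_VISIBLE_DEVICES": "0",
        "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
        "XLA_PYTHON_CLIENT_MEM_FRACTION": ".2",
    }
    cpu_env = {"CUDA_VISIBLE_DEVICES": "", "JAX_PLATFORM_NAME": "cpu"}
    procs = [
        mp.Process(target=worker, args=("test_gpu/0", gpu_env)),
        mp.Process(target=worker, args=("test_gpu/1", gpu_env)),
        mp.Process(target=worker, args=("test_cpu/0", cpu_env)),
    ]
    for p in procs:
        p.start()
    for p in procs:
        p.join()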
Both GPU nodes have the same env variables, which are:
Additional debugging
If you further print out:
Both have the same settings:
As far as I can tell, all of the system's settings are the same during the hand-off to XLA.
If you add this to the nodes, you can see that the CUDA backend is missing from the platform processing:
If you also add:
They're both None, prompting all of the backend factories to be run.
What jax/jaxlib version are you using?
Which accelerator(s) are you using?
GPU
Additional system info
Red Hat Enterprise Linux 8.4 (Ootpa)
NVIDIA GPU info
NVIDIA-SMI 510.73.08 Driver Version: 510.73.08 CUDA Version: 11.6
NVIDIA A40