google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

RuntimeError: Resource temporarily unavailable due to running out of threads (ulimit -u) #2685

Open · mgbukov opened this issue 4 years ago

mgbukov commented 4 years ago

jax=0.1.62 jaxlib=0.1.43

I bumped into an issue when running multiple processes that call JAX in parallel using Open MPI. I was able to distill the error down to the following:

I'm running the following code snippet (bug.py) with the CPU backend:

from mpi4py import MPI
import jax

seed=7
comm=MPI.COMM_WORLD

print(jax.__version__, seed, comm.Get_rank())

jax.lib.xla_bridge.get_backend().platform  # this call initializes the XLA backend (here, the CPU client)
rng = jax.random.PRNGKey(seed)

in parallel over 26 (independent) processes using mpi4py by executing the command

mpirun -np 26 python bug.py

This mini-program runs on a large 128 GB compute node with 30 Haswell processors [more on the Haswell node specifics here]. All processors are reserved exclusively for this calculation.

Every time I run this job, a random subset of the processes terminates with the following JAX RuntimeError:

Traceback (most recent call last):
  File "./bug.py", line 11, in <module>
    jax.lib.xla_bridge.get_backend().platform
  File "/global/cfs/cdirs/m3444/.conda/envs/jax-debug/lib/python3.8/site-packages/jax/lib/xla_bridge.py", line 163, in get_backend
    return backend(platform)
  File "/global/cfs/cdirs/m3444/.conda/envs/jax-debug/lib/python3.8/site-packages/jax/lib/xla_bridge.py", line 118, in _get_local_backen
d
    backend = xla_client.get_local_backend(platform)
  File "/global/cfs/cdirs/m3444/.conda/envs/jax-debug/lib/python3.8/site-packages/jaxlib/xla_client.py", line 249, in get_local_backend
    backends = _get_local_backends()
  File "/global/cfs/cdirs/m3444/.conda/envs/jax-debug/lib/python3.8/site-packages/jaxlib/xla_client.py", line 225, in _get_local_backend
s
    backend = factory()
  File "/global/cfs/cdirs/m3444/.conda/envs/jax-debug/lib/python3.8/site-packages/jaxlib/xla_client.py", line 169, in _cpu_backend_facto
ry
    client = _xla.get_cpu_client(asynchronous=True)
RuntimeError: Resource temporarily unavailable

I also tried distributing the same 26 MPI processes over two or more identical compute nodes, and I see the same behavior.

Occasionally, I see the same error occur in the PRNGKey function.

I also found out that this error does NOT occur with jax=0.1.58, jaxlib=0.1.37. Does anyone have an idea what might be causing this? It must have been introduced somewhere between jaxlib=0.1.37 and jaxlib=0.1.38.

hawkinsp commented 4 years ago

We're going to have a hard time debugging this without being able to reproduce it, which might itself be difficult. Does it need a particular MPI setup, or does it reproduce in, say, a cloud VM?

I'm not aware of any change in jaxlib that would have intentionally triggered this problem, so we probably need to debug this the hard way.

I think the first step would be to work out which resource we are running out of; I have a few guesses, but nothing definitive.

If I were debugging this myself, I would try to reproduce the behavior under strace and check which system call is failing.
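
For example (a sketch of one possible invocation, assuming a reasonably recent strace is installed on the node), you could trace every rank, follow child processes and threads, and write one log file per process:

mpirun -np 26 strace -f -ff -o strace.out python bug.py
# afterwards, look for the failing call (errno EAGAIN, "Resource temporarily unavailable"):
grep -l EAGAIN strace.out.*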

mgbukov commented 4 years ago

@hawkinsp I'd be happy to try out different suggestions. I can try to reproduce this behavior on a different cluster and see whether it shows up there. It might also be important to understand what the observation that the error does not appear with jaxlib=0.1.37 tells us -- the old version gives us a baseline for comparison.

How do I reproduce the behavior under strace? I'm not familiar with it.

mgbukov commented 4 years ago

I checked two more Linux platforms, as well as macOS on my laptop: the upshot is that two of the three Linux platforms give the same error, while the remaining Linux platform and macOS do not. How can I collect the details of the faulty Linux platforms to share them here?

Is it expected that the pip version of jax shows problems on some Linux platforms?

The good news is that it seems we don't need a node with a large number of cores to reproduce the error. All one needs to do is run the minimal code snippet above with more than 24 MPI processes:

mpiexec -np 26 python -W ignore bug.py

hawkinsp commented 4 years ago

Could you share the output of uname -a and ulimit -a for the Linux platforms, both those that failed and those that succeeded?

I'm unable to reproduce this problem on a Google Cloud Platform VM running Debian 10. I created an n1-standard-64 instance (64 vCPUs), installed Python 3.7 and mpi4py via apt, and everything works fine with your repro.

If you are able to reproduce this on a reasonably standard Linux platform, it would be helpful if you could share the details. I need to be able to reproduce the problem to debug it, and the best case would be a configuration we can both get access to (e.g., a cloud VM).

mgbukov commented 4 years ago

1. Linux platform (error shows up)

Linux cori01 4.12.14-150.47-default #1 SMP Wed Dec 18 15:05:52 UTC 2019 (8162e25) x86_64 x86_64 x86_64 GNU/Linux

core file size (blocks, -c) 0
data seg size (kbytes, -d) unlimited
scheduling priority (-e) 0
file size (blocks, -f) unlimited
pending signals (-i) 2060260
max locked memory (kbytes, -l) unlimited
max memory size (kbytes, -m) unlimited
open files (-n) 4096
pipe size (512 bytes, -p) 8
POSIX message queues (bytes, -q) 819200
real-time priority (-r) 0
stack size (kbytes, -s) unlimited
cpu time (seconds, -t) unlimited
max user processes (-u) 2048
virtual memory (kbytes, -v) unlimited
file locks (-x) unlimited

2. Linux platform (error shows up)

Linux scc1 3.10.0-1062.12.1.el7.x86_64 #1 SMP Tue Feb 4 23:02:59 UTC 2020 x86_64 x86_64 x86_64 GNU/Linux

core file size (blocks, -c) 0
data seg size (kbytes, -d) 20000000
scheduling priority (-e) 0
file size (blocks, -f) unlimited
pending signals (-i) 1030652
max locked memory (kbytes, -l) unlimited
max memory size (kbytes, -m) unlimited
open files (-n) 1024
pipe size (512 bytes, -p) 8
POSIX message queues (bytes, -q) 819200
real-time priority (-r) 0
stack size (kbytes, -s) 8192
cpu time (seconds, -t) unlimited
max user processes (-u) 1024
virtual memory (kbytes, -v) unlimited
file locks (-x) unlimited

3. Linux platform (NO error)

Linux cuiw-ubuntu 4.15.0-96-generic #97-Ubuntu SMP Wed Apr 1 03:25:46 UTC 2020 x86_64 x86_64 x86_64 GNU/Linux

core file size (blocks, -c) 0
data seg size (kbytes, -d) unlimited
scheduling priority (-e) 0
file size (blocks, -f) unlimited
pending signals (-i) 126673
max locked memory (kbytes, -l) 16384
max memory size (kbytes, -m) unlimited
open files (-n) 1024
pipe size (512 bytes, -p) 8
POSIX message queues (bytes, -q) 819200
real-time priority (-r) 0
stack size (kbytes, -s) 8192
cpu time (seconds, -t) unlimited
max user processes (-u) 126673
virtual memory (kbytes, -v) unlimited
file locks (-x) unlimited

4. Linux platform (NO error)

Linux fe001 4.4.0-31-generic #50-Ubuntu SMP Wed Jul 13 00:07:12 UTC 2016 x86_64 x86_64 x86_64 GNU/Linux

core file size (blocks, -c) 0
data seg size (kbytes, -d) unlimited
scheduling priority (-e) 0
file size (blocks, -f) unlimited
pending signals (-i) 15665
max locked memory (kbytes, -l) 64
max memory size (kbytes, -m) unlimited
open files (-n) 1024
pipe size (512 bytes, -p) 8
POSIX message queues (bytes, -q) 819200
real-time priority (-r) 0
stack size (kbytes, -s) 8192
cpu time (seconds, -t) unlimited
max user processes (-u) 15665
virtual memory (kbytes, -v) unlimited
file locks (-x) unlimited

hawkinsp commented 4 years ago

I think the issue is the value of ulimit -u, the limit on the number of user processes. On Linux, "user processes" also includes threads, and JAX creates, for each process spawned by MPI, a threadpool with one thread per CPU core.

e.g., on my n1-standard-64, in an IPython session (PID 6947) that has run a trivial JAX computation, I see:

ps -AL | grep 6947 | wc -l
83
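
As an aside, one hypothetical way (my suggestion, not part of the original diagnosis) to compare the total thread count against the limit during a run is to count every thread owned by your user and put it next to the limit:

ps -eLf | awk -v u="$USER" '$1 == u' | wc -l   # total threads across all of your processes
ulimit -u                                      # the limit those threads are counted against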

If I set ulimit -u 1024 before calling mpirun, I also see the Resource exhausted error.

JAX uses threading internally in some fairly fundamental ways, so we're not going to be able to eliminate threading, but we can certainly provide an option to reduce the size of some of the internal threadpools. Or perhaps we can detect this kind of MPI configuration and choose a more appropriate thread pool size automatically.

You might be able to work around the problem for now by raising ulimit -u, but you might need superuser privileges on your machine to do this.
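
For instance (a sketch; whether you can raise the limit, and how far, depends on your site's configuration), you can inspect and raise the soft limit in the shell that launches mpirun:

ulimit -u        # current soft limit on processes/threads for this shell
ulimit -Hu       # hard limit; the soft limit can be raised up to this value
ulimit -u 8192   # raise the soft limit for this shell and its children

Raising the hard limit persistently typically requires an administrator (e.g., via /etc/security/limits.conf or the batch system's configuration).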

mgbukov commented 4 years ago

@hawkinsp I see, so what is the proper way to manually set the number of threads that JAX creates? I feel like I should be able to tell it to use a single thread if needed. This is particularly helpful when launching large simulations that need to fit within the available resources.

I currently have

os.environ['MKL_NUM_THREADS']='1' # limit the number of MKL threads
os.environ['OPENBLAS_NUM_THREADS']='1'
os.environ['OMP_NUM_THREADS']='1'

os.environ["NUM_INTER_THREADS"]="1"
os.environ["NUM_INTRA_THREADS"]="1"

# set XLA threads and parallelism
os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="0"

but it doesn't seem to help.

mgbukov commented 4 years ago

JAX uses threading internally in some fairly fundamental ways, so we're not going to be able to eliminate threading, but we can certainly provide an option to reduce the size of some of the internal threadpools. Or perhaps we can detect this kind of MPI configuration and choose a more appropriate thread pool size automatically.

I see, so it must be this internal threading in JAX that changed between jaxlib=0.1.37 and jaxlib=0.1.38: with jaxlib=0.1.37 the number of internal threads was smaller, so the calculation fit within the resource limits. I guess this makes sense.

hawkinsp commented 4 years ago

Happily, it turns out there's an easy solution: the libraries underlying JAX respect CPU affinity, and they size their threadpools according to the task's CPU affinity mask. So if you tell MPI to assign, say, one core per process via the appropriate options to mpirun, JAX will size its threadpool accordingly. You'll still have more than one thread per process, but this should work well enough to solve your problem.

e.g.,

mpirun --cpus-per-proc 1 -np 26 python3 t.py

worked fine for me. (I gather you should probably use --map-by instead of --cpus-per-proc these days but I don't know MPI and I leave that part as an exercise for the reader.)
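
(For reference, and purely as an assumption on my part about newer Open MPI syntax rather than something verified in this thread, the equivalent binding with --map-by would look roughly like:)

mpirun --map-by core --bind-to core -np 26 python3 t.py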

Does that help?

mgbukov commented 4 years ago

Unfortunately, none of

mpirun --cpus-per-proc 1 -np 26 python t.py
mpirun --map-by 15 -np 26 python t.py
mpirun --map-by ppr:15:core -np 26 python t.py
mpiexec --map-by ppr:15:core -np 26 python t.py

etc. works for me [I have mpi4py installed via Anaconda].

However, the idea is spot on, because the SLURM command srun has a similar flag which did the job:

srun --ntasks-per-core 1 -n 26 python t.py

@hawkinsp it might be helpful to document this somewhere in the JAX documentation.
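
For anyone who ends up here without a convenient launcher flag, a generic Linux alternative (an assumption based on the CPU-affinity behavior described above, not something tested in this thread) is to restrict a process's affinity mask directly with taskset, so the threadpools are sized to the restricted mask:

taskset -c 0 python t.py   # pin the process (and hence the threadpool sizing) to core 0

For multi-rank MPI jobs, the launcher's own binding options remain the better tool, since they give each rank its own core.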

MartinKocour commented 1 year ago

I have exactly the same issue!

I am training a Flax model, and after several steps (210-230) I get: LLVM ERROR: pthread_create failed: Resource temporarily unavailable.

I am not using mpirun or anything... I just run

CUDA_VISIBLE_DEVICES=0 python train.py

Python packages:

I am running the experiments on:

I don't even know where the error is coming from.