mgbukov opened this issue 4 years ago
We're going to have a hard time debugging this without a reproduction, and reproducing it may itself be difficult. Does it need a particular MPI setup, or does it reproduce on, say, a cloud VM?
I'm not aware of any change in jaxlib that would have intentionally triggered this problem, so we probably need to debug this the hard way.
I think the first step would be to debug what resource we are running out of. Some guesses are
If I were debugging this myself, I would try to reproduce the behavior under strace and look for the system call that failed.
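One way to act on the strace suggestion, sketched under assumptions: record a trace with something like `strace -f -o trace.txt python bug.py` (`-f` follows child processes and threads, `-o` writes the trace to a file), then scan it for calls that returned -1. Thread creation typically shows up as `clone` (or `clone3` with newer glibc), and hitting the `ulimit -u` cap makes it fail with `EAGAIN`. The helper name `failed_syscalls` is hypothetical, and the regex assumes strace's default line format, where each line starts with a PID when `-f -o` is used.

```python
import re
from typing import Iterable, List, Tuple

# Matches strace lines such as:
#   12345 clone(child_stack=..., flags=...) = -1 EAGAIN (Resource temporarily unavailable)
# The leading PID is present when strace runs with -f and -o.
_FAILED = re.compile(r"^(?:\d+\s+)?(\w+)\(.*\)\s+=\s+-1\s+(E[A-Z0-9]+)")


def failed_syscalls(lines: Iterable[str]) -> List[Tuple[str, str]]:
    """Return (syscall name, errno name) for every traced call that returned -1."""
    failures = []
    for line in lines:
        match = _FAILED.match(line)
        if match:
            failures.append((match.group(1), match.group(2)))
    return failures


if __name__ == "__main__":
    sample = [
        "12345 mmap(NULL, 8392704, PROT_NONE, MAP_PRIVATE|MAP_ANONYMOUS|MAP_STACK, -1, 0) = 0x7f0000000000",
        "12345 clone(child_stack=0x7f0000800000, flags=CLONE_VM|CLONE_FS) = -1 EAGAIN (Resource temporarily unavailable)",
    ]
    print(failed_syscalls(sample))  # [('clone', 'EAGAIN')]
```

On a real trace you would pass the lines of `trace.txt`, e.g. `failed_syscalls(open("trace.txt"))`.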
@hawkinsp I'd be happy to try out different suggestions. I can try to reproduce this behavior on a different cluster and see if I get it there. It might also be worth understanding why the error does not appear with jaxlib=0.1.37; the old version can serve as a baseline for comparison.
How do I reproduce the behavior under strace? I'm not familiar with it.
I checked two more Linux platforms, plus macOS on my Mac. The upshot: 2 of the 3 Linux platforms give the same error, while the other Linux platform and macOS do not. How can I get the details of the failing Linux platforms to share them here?
Is it expected that the pip version of JAX has problems on some Linux platforms?
The good news is that to reproduce the error it seems like we don't need a node with a large number of cores. All one needs is to run the minimal code snippet above with more than 24 MPI processes:
mpiexec -np 26 python -W ignore bug.py
Could you share the output of uname -a and ulimit -a for the Linux platforms, both those that failed and those that succeeded?
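A minimal sketch that gathers roughly the same information from inside Python, assuming a Linux or other Unix host where the standard `resource` module is available. Running it as part of the actual MPI job, rather than in a login shell, shows the limits each rank really sees, which can differ from the interactive shell. The function name `platform_report` is hypothetical; note that `getrlimit` reports the stack size in bytes, whereas `ulimit -s` reports kbytes.

```python
import os
import resource

# Limits discussed in this thread: -u (processes, which on Linux includes threads),
# -n (open files) and -s (stack size).
_LIMITS = {
    "max user processes (-u)": resource.RLIMIT_NPROC,
    "open files (-n)": resource.RLIMIT_NOFILE,
    "stack size (bytes, -s)": resource.RLIMIT_STACK,
}


def _fmt(value: int) -> str:
    """Render RLIM_INFINITY the way ulimit does."""
    return "unlimited" if value == resource.RLIM_INFINITY else str(value)


def platform_report() -> dict:
    """Return kernel identification plus soft/hard values of a few ulimit settings."""
    u = os.uname()
    report = {"uname": f"{u.sysname} {u.nodename} {u.release} {u.version} {u.machine}"}
    for label, which in _LIMITS.items():
        soft, hard = resource.getrlimit(which)
        report[label] = f"soft={_fmt(soft)} hard={_fmt(hard)}"
    return report


if __name__ == "__main__":
    for key, value in platform_report().items():
        print(f"{key}: {value}")
```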
I'm unable to reproduce this problem on a Google Cloud Platform VM under Debian 10. I created an n1-standard-64 instance (64 vCPUs) and installed Debian 10. I installed Python 3.7 and mpi4py using apt, and everything seems to work fine with your repro.
If you were able to reproduce this on a reasonably standard Linux platform, it would be helpful if you could share details. I need to be able to reproduce this to debug it, and the best case would be a configuration we can both access (e.g., a cloud VM).
Linux cori01 4.12.14-150.47-default #1 SMP Wed Dec 18 15:05:52 UTC 2019 (8162e25) x86_64 x86_64 x86_64 GNU/Linux
core file size (blocks, -c) 0
data seg size (kbytes, -d) unlimited
scheduling priority (-e) 0
file size (blocks, -f) unlimited
pending signals (-i) 2060260
max locked memory (kbytes, -l) unlimited
max memory size (kbytes, -m) unlimited
open files (-n) 4096
pipe size (512 bytes, -p) 8
POSIX message queues (bytes, -q) 819200
real-time priority (-r) 0
stack size (kbytes, -s) unlimited
cpu time (seconds, -t) unlimited
max user processes (-u) 2048
virtual memory (kbytes, -v) unlimited
file locks (-x) unlimited
Linux scc1 3.10.0-1062.12.1.el7.x86_64 #1 SMP Tue Feb 4 23:02:59 UTC 2020 x86_64 x86_64 x86_64 GNU/Linux
core file size (blocks, -c) 0
data seg size (kbytes, -d) 20000000
scheduling priority (-e) 0
file size (blocks, -f) unlimited
pending signals (-i) 1030652
max locked memory (kbytes, -l) unlimited
max memory size (kbytes, -m) unlimited
open files (-n) 1024
pipe size (512 bytes, -p) 8
POSIX message queues (bytes, -q) 819200
real-time priority (-r) 0
stack size (kbytes, -s) 8192
cpu time (seconds, -t) unlimited
max user processes (-u) 1024
virtual memory (kbytes, -v) unlimited
file locks (-x) unlimited
Linux cuiw-ubuntu 4.15.0-96-generic #97-Ubuntu SMP Wed Apr 1 03:25:46 UTC 2020 x86_64 x86_64 x86_64 GNU/Linux
core file size (blocks, -c) 0
data seg size (kbytes, -d) unlimited
scheduling priority (-e) 0
file size (blocks, -f) unlimited
pending signals (-i) 126673
max locked memory (kbytes, -l) 16384
max memory size (kbytes, -m) unlimited
open files (-n) 1024
pipe size (512 bytes, -p) 8
POSIX message queues (bytes, -q) 819200
real-time priority (-r) 0
stack size (kbytes, -s) 8192
cpu time (seconds, -t) unlimited
max user processes (-u) 126673
virtual memory (kbytes, -v) unlimited
file locks (-x) unlimited
Linux fe001 4.4.0-31-generic #50-Ubuntu SMP Wed Jul 13 00:07:12 UTC 2016 x86_64 x86_64 x86_64 GNU/Linux
core file size (blocks, -c) 0
data seg size (kbytes, -d) unlimited
scheduling priority (-e) 0
file size (blocks, -f) unlimited
pending signals (-i) 15665
max locked memory (kbytes, -l) 64
max memory size (kbytes, -m) unlimited
open files (-n) 1024
pipe size (512 bytes, -p) 8
POSIX message queues (bytes, -q) 819200
real-time priority (-r) 0
stack size (kbytes, -s) 8192
cpu time (seconds, -t) unlimited
max user processes (-u) 15665
virtual memory (kbytes, -v) unlimited
file locks (-x) unlimited
I think the issue is the value of ulimit -u, the limit on the number of user processes. Here "user processes" also includes threads, and JAX creates, for each process spawned by MPI, a threadpool with one thread per CPU core.
For example, on my n1-standard-64 with an IPython session that has run a trivial JAX computation, I see:
ps -AL | grep 6947 | wc -l
83
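A Linux-only sketch of the same measurement without ps and grep (6947 in the command above is presumably the PID of the IPython process). Each thread of a process appears as an entry under `/proc/<pid>/task`, so counting those entries gives the OS-level thread count, including native thread pools that Python's `threading` module never sees. The function name `count_threads` is hypothetical.

```python
import os
import threading


def count_threads(pid: int) -> int:
    """Number of OS threads in process `pid` (Linux only, reads /proc)."""
    # Every thread of the process has its own directory under /proc/<pid>/task.
    return len(os.listdir(f"/proc/{pid}/task"))


if __name__ == "__main__":
    print("OS threads in this process:", count_threads(os.getpid()))
    # threading.active_count() only counts threads started from Python,
    # so native pools created by compiled libraries do not show up here.
    print("Python-level threads:", threading.active_count())
```

Calling `count_threads(os.getpid())` after a first JAX computation in each MPI rank would show how many threads that rank contributes toward the `ulimit -u` budget.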
If I set ulimit -u 1024 before calling mpirun, I also see the Resource exhausted error.
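To make the arithmetic concrete, a small sketch: if each of N ranks creates T threads, the job needs roughly N * T thread slots, plus whatever the same user already has running, because RLIMIT_NPROC (ulimit -u) counts every process and thread owned by that real user ID, not only this job. The function names are hypothetical, and the 83 threads per process is the n1-standard-64 measurement above, used only for illustration; the clusters in this thread have different core counts and therefore different per-process thread counts.

```python
def threads_needed(n_procs: int, threads_per_proc: int, already_running: int = 0) -> int:
    """Threads the user would own if every rank creates `threads_per_proc` threads."""
    return n_procs * threads_per_proc + already_running


def fits_nproc_limit(n_procs: int, threads_per_proc: int, limit: int,
                     already_running: int = 0) -> bool:
    """True if the job stays within a ulimit -u style cap on the user's threads."""
    return threads_needed(n_procs, threads_per_proc, already_running) <= limit


if __name__ == "__main__":
    print(threads_needed(26, 83))            # 2158
    print(fits_nproc_limit(26, 83, 2048))    # False
    print(fits_nproc_limit(26, 83, 126673))  # True
```

With these illustrative numbers, 26 ranks would exceed both the 2048 and 1024 limits reported above, while the 126673 limit leaves plenty of headroom.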
JAX uses threading internally in some fairly fundamental ways, so we're not going to be able to eliminate threading, but we can certainly provide an option to reduce the size of some of the internal threadpools. Or perhaps we can detect this kind of MPI configuration and choose a more appropriate thread pool size automatically.
You might be able to work around the problem for now by raising ulimit -u, but you might need superuser privileges on your machine to do this.
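A sketch of the part that does not need superuser privileges: an unprivileged process may raise its own soft limit up to the hard limit, while raising the hard limit itself requires privileges. This assumes it runs at the very top of the script, before JAX starts any thread pools; the function name `raise_nproc_soft_limit` is hypothetical. The explicit `RLIM_INFINITY` checks matter because Python reports "unlimited" as a sentinel value rather than a large number.

```python
import resource


def raise_nproc_soft_limit(desired: int) -> int:
    """Raise the soft RLIMIT_NPROC toward `desired`, capped at the hard limit.

    Never lowers the current soft limit. Returns the soft limit now in effect.
    """
    soft, hard = resource.getrlimit(resource.RLIMIT_NPROC)
    if soft == resource.RLIM_INFINITY or soft >= desired:
        return soft  # already high enough, leave it alone
    target = desired if hard == resource.RLIM_INFINITY else min(desired, hard)
    if target > soft:
        # Unprivileged processes may move the soft limit anywhere up to the hard limit.
        resource.setrlimit(resource.RLIMIT_NPROC, (target, hard))
    return resource.getrlimit(resource.RLIMIT_NPROC)[0]


if __name__ == "__main__":
    print("soft RLIMIT_NPROC now:", raise_nproc_soft_limit(4096))
```

If the hard limit itself is the bottleneck (e.g. 1024), this cannot help, and the administrator or the launcher configuration has to change it.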
@hawkinsp I see, so what is the proper way to manually set the number of threads that JAX creates? I feel like I should be able to tell it to use a single thread if needed. This is particularly helpful when large simulations are launched and one needs to fit within the available resources.
I currently have:
import os

os.environ["MKL_NUM_THREADS"] = "1"  # set number of MKL threads to run in parallel
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["NUM_INTER_THREADS"] = "1"
os.environ["NUM_INTRA_THREADS"] = "1"
# set XLA threads and parallelism
os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "0"
but it doesn't seem to help.
I see, so it must be this intrinsic threading in JAX that changed between jaxlib=0.1.37 and jaxlib=0.1.38. When I ran my code with jaxlib=0.1.37, the required number of threads was smaller, so the calculation was able to fit within the available resources. I guess this makes sense.
Happily, it turns out there's an easy solution: the libraries underlying JAX respect CPU affinity, and when they create a thread pool they size it according to the task's CPU affinity mask. So if you tell MPI to assign, say, one core per process via the appropriate options to mpirun, JAX will size its thread pool accordingly. You'll still have more than one thread, but this should work well enough to solve your problem.
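A Linux-only sketch of the mechanism described above, using the standard `os.sched_getaffinity` and `os.sched_setaffinity` calls: `usable_cpus` reports the size of the affinity mask (which can be smaller than `os.cpu_count()` when a launcher restricts it), and `pin_to_one_cpu` narrows the mask from inside the process. Pinning manually before importing jax is an alternative to launcher flags that rests on an assumption: that the thread pools are sized from the mask at the moment they are created, as described above. Both function names are hypothetical.

```python
import os


def usable_cpus() -> int:
    """Number of CPUs this process may run on, i.e. the size of its affinity mask."""
    return len(os.sched_getaffinity(0))


def pin_to_one_cpu(rank: int) -> int:
    """Restrict the calling process to one CPU, chosen round-robin by `rank`.

    Intended to run before jax is imported, so that pools sized from the
    affinity mask see a single CPU. Returns the CPU index that was chosen.
    """
    cpus = sorted(os.sched_getaffinity(0))
    cpu = cpus[rank % len(cpus)]
    os.sched_setaffinity(0, {cpu})
    return cpu


if __name__ == "__main__":
    print("os.cpu_count():", os.cpu_count(), "usable:", usable_cpus())
```

`rank` should be the node-local rank so that ranks on the same node land on different cores; with mpi4py, one way to obtain it is `MPI.COMM_WORLD.Split_type(MPI.COMM_TYPE_SHARED).Get_rank()`.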
For example,
mpirun --cpus-per-proc 1 -np 26 python3 t.py
worked fine for me. (I gather you should probably use --map-by instead of --cpus-per-proc these days, but I don't know MPI and I leave that part as an exercise for the reader.)
Does that help?
Unfortunately, none of the following work for me (I have mpi4py installed via Anaconda):
mpirun --cpus-per-proc 1 -np 26 python t.py
mpirun --map-by 15 -np 26 python t.py
mpirun --map-by ppr:15:core -np 26 python t.py
mpiexec --map-by ppr:15:core -np 26 python t.py
and neither do similar variants.
However, the idea is spot on: the SLURM command srun has a similar flag, which did the job:
srun --ntasks-per-core 1 -n 26 python t.py
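Since the fix works by narrowing each rank's affinity mask, a quick per-rank check shows whether a given launcher option actually had that effect. This sketch assumes the global rank is exported as `SLURM_PROCID` (set by srun) or `OMPI_COMM_WORLD_RANK` (set by Open MPI's mpirun/mpiexec); other launchers use other variable names. The function names are hypothetical, and the file name in the usage line below is illustrative.

```python
import os
import socket
from typing import Mapping, Optional

# Global-rank variables exported by two common launchers:
# SLURM's srun sets SLURM_PROCID, Open MPI's mpirun/mpiexec sets OMPI_COMM_WORLD_RANK.
_RANK_VARS = ("SLURM_PROCID", "OMPI_COMM_WORLD_RANK")


def launcher_rank(env: Optional[Mapping[str, str]] = None) -> int:
    """Best-effort global rank from launcher environment variables (0 if none found)."""
    env = os.environ if env is None else env
    for name in _RANK_VARS:
        if name in env:
            return int(env[name])
    return 0


def affinity_report(env: Optional[Mapping[str, str]] = None) -> str:
    """One line per rank: host, number of CPUs in the affinity mask, and which ones."""
    cpus = sorted(os.sched_getaffinity(0))
    return f"rank {launcher_rank(env)} on {socket.gethostname()}: {len(cpus)} CPU(s) {cpus}"


if __name__ == "__main__":
    print(affinity_report())
```

Running it as `srun --ntasks-per-core 1 -n 26 python check_affinity.py` (or with the mpirun variants above) prints one line per rank; if every rank still reports all of the node's CPUs, that launcher option did not restrict affinity.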
@hawkinsp it might be helpful to mention this somewhere in the JAX documentation.
I have exactly the same issue!
I am training a Flax model, and after several steps (210 to 230) I get: LLVM ERROR: pthread_create failed: Resource temporarily unavailable.
I am not using mpirun or anything... I just run
CUDA_VISIBLE_DEVICES=0 python train.py
Python packages:
I am running the experiments on:
I don't even know where the error comes from.
jax=0.1.62 jaxlib=0.1.43
I bumped into an issue running multiple processes which call jax in parallel using Open MPI. I was able to distill the error as follows:
I'm running the following code snippet with the CPU backend:
in parallel over 26 (independent) processes using mpi4py by executing the command
This mini-program runs on a large 128 GB compute node with 30 Haswell processors [more on the Haswell node specifics here]. All processors are reserved exclusively for this calculation.
Every time I run this job, a number of random processes terminate with the jax RuntimeError
I also tried distributing the same 26 MPI processes over two and more identical compute nodes and see the same behavior.
Occasionally, I see the same error occur in the PRNGKey function.
I also found out that this error does NOT occur with jax=0.1.58, jaxlib=0.1.37. Does anyone have an idea what might be causing this? It must be something in the change from jaxlib=0.1.37 to jaxlib=0.1.38.