Open dionhaefner opened 4 years ago
Hey @dionhaefner ! I had a similar issue: https://github.com/google/jax/issues/743 HTH
We should document the solution we figured out in #743.
Hmm, unfortunately that doesn't seem to do it.
I am using this script for testing:
import jax

@jax.jit
def bench(sa, ct, p):
    return sa + ct * p

def run(sa, ct, p):
    return bench(sa, ct, p).block_until_ready()

if __name__ == '__main__':
    import numpy
    size = 10_000_000
    s = numpy.random.rand(size)
    t = numpy.random.rand(size)
    p = numpy.random.rand(size)
    for _ in range(100):
        run(s, t, p)
Running this gives:
XLA_FLAGS="--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1 inter_op_parallelism_threads=1" time python bench_jax.py
<snip>/lib/python3.7/site-packages/jax/lib/xla_bridge.py:115: UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
6.64 real 10.59 user 1.98 sys
Note the lower real than user time. Looking at the resource monitor, I can see ~180% CPU usage.
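The same real-versus-user comparison can be made from inside Python. The sketch below is only an illustration (the helper name `cpu_to_wall_ratio` is made up here): it divides the CPU time consumed by the whole process by the elapsed wall-clock time. A ratio clearly above 1 means more than one thread was busy; the numbers above give (10.59 + 1.98) / 6.64 ≈ 1.9, in line with the ~180% CPU usage.

```python
import time

def cpu_to_wall_ratio(fn, *args, repeats=10):
    """Run fn(*args) `repeats` times and return CPU time / wall-clock time.

    time.process_time() counts user + system CPU time of the whole process,
    so work done by extra threads pushes the ratio above 1.
    """
    wall_start = time.perf_counter()
    cpu_start = time.process_time()
    for _ in range(repeats):
        fn(*args)
    cpu = time.process_time() - cpu_start
    wall = time.perf_counter() - wall_start
    return cpu / wall

# Sleeping burns almost no CPU, so this prints a value close to 0.
print(cpu_to_wall_ratio(time.sleep, 0.05, repeats=2))
```

With the script above one would call `cpu_to_wall_ratio(run, s, t, p)`; `run` already blocks until the result is ready, which is needed for the timing to be meaningful.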
I'm having the same issue trying to run jax on an HPC with multiple CPUs. The solution from #743 doesn't work for me either.
I tried setting
import os
os.environ['MKL_NUM_THREADS']='1'
os.environ['OPENBLAS_NUM_THREADS']='1'
os.environ["NUM_INTER_THREADS"]="1"
os.environ["NUM_INTRA_THREADS"]="1"
os.environ["XLA_FLAGS"] = ("--xla_cpu_multi_thread_eigen=false "
                           "intra_op_parallelism_threads=1")
I'm getting RuntimeError: Resource temporarily unavailable on the HPC, and on my laptop I can see that the code uses more than one CPU.
@dionhaefner did you manage to get around this problem?
And here's the jax mnist example, where I tried to disable multithreading putting the above flags in the beginning of the python script, but it doesn't work :(
Maybe you could try the threadpoolctl library to set the number of threads for all backends. (BLAS, openmp, MKL, ... that is. I don't mean JAX backends.)
from threadpoolctl import threadpool_limits
[...]
with threadpool_limits(limits=1):
    for _ in range(10000):
        run(s, t, p)
I ended up setting processor affinity (only works on Unix systems though, and might require special permissions):
$ taskset -c 0 python myscript.py
But won't that limit the script to one specific CPU? And you have to manage manually which instance runs on which CPU.
Yes. It's just a workaround, not really a solution.
I wonder if it's jax or xla (or some other internal package) which spawns the extra threads.
In any case it'd be very useful to have a clear example (maybe even as part of the official jax list of examples) of how to get this under control; otherwise allocating resources on HPC clusters becomes a tedious trial-and-error task.
Afaik the threads are spawned by the CPU backend(s), namely openmp, BLAS, MKL and the like. Getting that under control is specifically the purpose of threadpoolctl. If I understand your problem/usecase correctly, it's exactly what you need!
I guess I'm doing something wrong, but threadpoolctl doesn't seem to work for me on osx. Could you take a quick look at the updated
This code uses 22 threads and up to 500 %CPU (as given in top).
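For anyone who wants to check the thread count from within the script rather than from top, here is a rough sketch. It is an assumption-laden illustration, not something from this thread: the helper name `native_thread_count` is made up, and it relies on the Linux-only `/proc/self/task` directory, so on macOS it simply returns None. Python's own `threading.active_count()` would not help here because it only sees threads started from Python.

```python
import os

def native_thread_count():
    """Return the number of OS-level threads in this process, or None.

    Each entry in /proc/self/task is one thread of the current process.
    The directory only exists on Linux; elsewhere the count is unknown.
    """
    task_dir = "/proc/self/task"
    if not os.path.isdir(task_dir):
        return None
    return len(os.listdir(task_dir))

# Call once before and once after running the jitted function to see
# how many threads were spawned in between.
print("threads right now:", native_thread_count())
```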
I believe the code XLA generates on CPU doesn't use MKL, OpenBLAS, or the system BLAS, so environment variables related to those libraries are unlikely to have an effect; for BLAS and related operations (e.g. convolutions), it uses an embedded copy of Eigen (really Eigen's Tensor sub-library) and for everything else it generates its own code with LLVM.
--xla_cpu_multi_thread_eigen=false should ensure that the BLAS library calls use Eigen in single-threaded mode, but non-BLAS code seems to be multi-threaded based on the XLA configuration option intra_op_parallelism_threads (which might not be wired through as a flag, or at least I can't find it?)
CC @hawkinsp
Oh no, I think I mixed up my anecdotal evidence from numba with this issue. I'm sorry mgbukov that I led you down the wrong track.
@clemisch no worries, ideas are always welcome!
@jekbradbury I did set both --xla_cpu_multi_thread_eigen=false and intra_op_parallelism_threads in mnist_single_thread.txt above, but there's still some residual multithreading. Are you suggesting that there's something wrong with the flags/the way I set them in the script, or that there are more/other flags I need to set?
It's going to be almost impossible to completely eliminate threading from JAX; the runtime internally uses multiple threads, e.g., to overlap the Python interpreter with XLA compilation and interpretation.
That said, we can probably avoid threading in the compute-intensive XLA-generated code; XLA itself has support for this, although I'm unsure if it's plumbed through the API surface in a usable way.
A quick note: setting the task affinity map is the correct way to limit JAX's CPU usage at the moment. JAX sizes its main threadpool using this logic: https://github.com/tensorflow/tensorflow/blob/4b2cb67756009dda843c6b56a8b320c8a54373e0/tensorflow/core/platform/default/port.cc#L67
If launching from mpirun, mpirun knows how to set task affinities correctly.
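When the launch command cannot be wrapped, the affinity mask can also be set from inside the Python process. The following is a minimal sketch, assuming Linux (where `os.sched_setaffinity` is available) and assuming the threadpool is sized from the affinity mask as in the linked code; the helper name `pin_to_cpus` is made up for illustration.

```python
import os

def pin_to_cpus(cpus):
    """Restrict the calling process to the given CPU ids.

    Returns the resulting affinity set, or None on platforms without
    os.sched_setaffinity (e.g. macOS and Windows).
    """
    if not hasattr(os, "sched_setaffinity"):
        return None
    os.sched_setaffinity(0, set(cpus))  # pid 0 refers to the caller
    return os.sched_getaffinity(0)

# Intended usage, before `import jax`, so that threads and pools created
# afterwards inherit the restricted mask:
#
#     pin_to_cpus([0])
#     import jax
```

Like taskset, this ties the process to specific CPUs, so running several instances still means choosing a different CPU id for each one.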
I have a use-case where I want to run a JIT compiled function with only a single CPU, using separate threading to handle parallelism (with Dask). At the moment, it looks like there's no way to do this?
Similar to @shoyer I wish to orchestrate many multithreading XLA ops on CPU, except using pmap instead of dask. Each XLA op is a jitted function containing some BLAS and non-BLAS code. I'd like the BLAS code to be limited to 1 or 2 threads per job.
I've noted this behaviour that doesn't exactly make sense to me.
import os
ncpu = 2
os.environ['XLA_FLAGS'] = f"--xla_force_host_platform_device_count={ncpu}"
os.environ["XLA_FLAGS"] = ("--xla_cpu_multi_thread_eigen=false "
                           "intra_op_parallelism_threads=1")
from jax import local_device_count
print(local_device_count())  # 1 instead of 2
The desire is that one can create ncpu virtual devices and turn off the intra_op parallelism. However, the suggested method in this issue seems to do neither of these things.
In your example you immediately overwrite the XLA_FLAGS env var. Did you mean to append instead?
Ah that's right! Thanks @dionhaefner
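A minimal sketch of appending to XLA_FLAGS instead of overwriting it. The helper `append_xla_flags` is hypothetical and not part of JAX. Only the two flags that appear with a leading `--` in this thread are used; `intra_op_parallelism_threads` is left out because it is unclear from the discussion whether it is wired through as a flag at all.

```python
import os

def append_xla_flags(*flags, env=None):
    """Append flags to XLA_FLAGS, keeping whatever is already set.

    `env` defaults to os.environ; a plain dict can be passed for testing.
    Flags that are already present are not added a second time.
    """
    env = os.environ if env is None else env
    existing = env.get("XLA_FLAGS", "").split()
    merged = existing + [f for f in flags if f not in existing]
    env["XLA_FLAGS"] = " ".join(merged)
    return env["XLA_FLAGS"]

# Demonstration on a plain dict so the real environment stays untouched.
demo_env = {}
ncpu = 2
append_xla_flags(f"--xla_force_host_platform_device_count={ncpu}", env=demo_env)
append_xla_flags("--xla_cpu_multi_thread_eigen=false", env=demo_env)
print(demo_env["XLA_FLAGS"])
```

In a real script the calls would omit `env=` and run before jax is imported, since the environment variable has to be in place by then.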
For others struggling with this the following appears to be working well for me:
taskset -c 0 mpiexec -np 1 python /path/to/your/python/file.py
For me, these solutions do not work when I want to execute several python jobs on the same machine. I finally found likwid-pin (https://github.com/RRZE-HPC/likwid) to be working as expected:
likwid-pin -c 0 python /path/to/your/python/file.py
likwid-pin -c 1 python /path/to/your/python/file.py
...
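Starting one pinned job per CPU by hand gets tedious, so here is a rough sketch of a launcher that does the bookkeeping. The function names are made up for illustration, and the pinning tool is a parameter: the default reuses the `taskset -c <cpu>` form from earlier in this thread, and `("likwid-pin", "-c")` gives the form shown above.

```python
import subprocess
import sys

def pinned_commands(script, cpus, pinner=("taskset", "-c"), python=sys.executable):
    """Build one `<pinner> <cpu> <python> <script>` argv list per CPU id."""
    return [[*pinner, str(cpu), python, script] for cpu in cpus]

def launch_pinned(script, cpus, pinner=("taskset", "-c")):
    """Start one job per CPU id, wait for all of them, return exit codes."""
    procs = [subprocess.Popen(cmd) for cmd in pinned_commands(script, cpus, pinner)]
    return [p.wait() for p in procs]

# Only builds and prints the commands; nothing is launched here.
for cmd in pinned_commands("/path/to/your/python/file.py", [0, 1], python="python"):
    print(" ".join(cmd))
```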
Is there any progress on this issue? It would be really great if Jax offered a way to limit thread spawning besides taskset. On HPC clusters they frequently limit you to a small number of concurrent threads, which makes multiprocessed jax a non-starter. It's a blocking issue in my group. taskset works when you have access to the place the processes are spawning, but often that's deep within a library. My experience has been that none of the XLA_FLAGS do anything to remediate this.
For example, if the code linked above (here) is current, could this default to a flag if provided?
Hi, I would also be interested in follow-ups on this issue. It is blocking in my group, too, and it is preventing me from using more than 20 CPUs on a single cluster node. None of the suggested flags work, and I cannot use taskset either.
Us too... we have CPU affinity set via the static scheduling in Kubernetes and see a pretty drastic reduction in performance as core count goes up. A single-CPU unit of work on an m5zn.large (4 CPUs) takes 28 minutes and running the same work on an m5zn.xlarge (8 CPUs) takes 45 minutes.
In addition to the CPU affinity, we have these environment variables set:
env:
  - name: OMP_NUM_THREADS
    value: "1"
  - name: XLA_FLAGS
    value: --xla_force_host_platform_device_count=1 --xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1
  - name: OPENBLAS_NUM_THREADS
    value: "1"
  - name: MKL_NUM_THREADS
    value: "1"
  - name: NUM_INTER_THREADS
    value: "1"
  - name: NUM_INTRA_THREADS
    value: "1"
In table form (all m5zn instances):

CPUs | concurrency (number of 1 CPU pods) | duration for a pod to finish in minutes |
---|---|---|
4 | 1 | 23 |
4 | 3 | 28 |
8 | 6 | 45 |
Maybe this is more a question for AWS, but why is one CPU on an 8 CPU instance taking ~50% longer to do the work than one CPU on a 4 CPU instance? Are the other CPUs being busy that damaging to performance? (The difference between running one pod and three is also 22%, so maybe indeed that is the case.)
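To put exact numbers on those comparisons, the arithmetic can be written out from the three table rows. This is only a worked calculation on the figures reported above; the helper names are made up for illustration.

```python
# Minutes for one pod to finish, keyed by (CPUs on the instance, concurrent pods).
durations = {(4, 1): 23, (4, 3): 28, (8, 6): 45}

def percent_longer(slow, fast):
    """How much longer `slow` takes than `fast`, in percent."""
    return 100.0 * (slow - fast) / fast

def pods_per_hour(pods, minutes):
    """Aggregate throughput when `pods` pods each finish in `minutes`."""
    return pods * 60.0 / minutes

# 3 pods vs 1 pod on the 4 CPU instance
print(round(percent_longer(durations[(4, 3)], durations[(4, 1)]), 1))  # 21.7
# 6 pods on 8 CPUs vs 3 pods on 4 CPUs
print(round(percent_longer(durations[(8, 6)], durations[(4, 3)]), 1))  # 60.7
# aggregate throughput of the two loaded configurations
print(round(pods_per_hour(3, 28), 1), round(pods_per_hour(6, 45), 1))  # 6.4 8.0
```

So each pod is slower on the fuller instance, but the instance as a whole still completes more pods per hour.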
I'm also interested in a straightforward API for this, both for HPC settings and for settings where the automatic choice of multiple cores is a poor one, e.g. for not-so-big matrix sizes.
> why is one CPU on an 8 CPU instance taking ~50% longer to do the work than one CPU on a 4 CPU instance

@akloss-cibo off-topic response, but memory bandwidth is a constant resource for a socket (under some assumptions) regardless of core count, so your comment is expected as soon as the memory involved spills out of L3 cache; that's the first culprit to check for
> so your comment is expected as soon as the memory involved spills out of L3 cache; that's the first culprit to check for

A few quotes from AWS:
> The Nitro Hypervisor allows M5zn instances to deliver performance that is just about indistinguishable from bare metal.

> Placement Groups – M5zn instances can be used in Cluster (for low latency and high network throughput), Spread (to keep critical instances separate from each other), and Partition (to reduce correlated failures) placement groups.

and Intel:
> Up to 28 CPU cores Multi-socket support (2, 4, 8 CPU)
m5zn are available in sizes with 2, 4, 8, 12, 24, or 48 CPUs... 12, 24, and 48 seem like unusual sizes to me... so what I think I'm gathering is that you think both the 4 and 8 CPU instances have the same number of sockets (my WAG is one) and the socket bandwidth is maxed out by the 6 pods; assuming a 12 CPU instance is also a single socket, it presumably would take 90 minutes to complete, but if a 12 CPU instance is actually two sockets, it could do 6 pods of work in 45 minutes. I infer from AWS's "placement groups" that a noisy neighbor using CPUs on a socket they share with you could reduce your throughput too (and this noisy neighbor could be another one of your own instances)...
We've moved this workload to Nvidia where we're happier with its performance, but it's good to have some idea of what's going on. Thanks for the insight.
I am assembling a single-core CPU performance benchmark for various HPC libraries from the modern Python ecosystem. I would like to include JAX, and first results seem very promising, but I'm failing to restrict it to a single thread. As far as I can tell, there is no corresponding setting in jax.config, and it doesn't listen to any of the popular flags (e.g. OMP_NUM_THREADS). I installed JAX and jaxlib from PyPI on OSX.
Is there any way to pull this off?