jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Limit number of threads when running on CPU #1539

Open dionhaefner opened 4 years ago

dionhaefner commented 4 years ago

I am assembling a single-core CPU performance benchmark for various HPC libraries from the modern Python ecosystem. I would like to include JAX, and first results look very promising, but I'm failing to restrict it to a single thread. As far as I can tell, there is no corresponding setting in jax.config, and it doesn't respond to any of the usual environment variables (e.g. OMP_NUM_THREADS).

I installed JAX and jaxlib from PyPI on OSX.

Is there any way to pull this off?

samuela commented 4 years ago

Hey @dionhaefner ! I had a similar issue: https://github.com/google/jax/issues/743 HTH

mattjj commented 4 years ago

We should document the solution we figured out in #743.

dionhaefner commented 4 years ago

Hmm, unfortunately that doesn't seem to do it.

I am using this script for testing:

import jax

@jax.jit
def bench(sa, ct, p):
    return sa + ct * p

def run(sa, ct, p):
    return bench(sa, ct, p).block_until_ready()

if __name__ == '__main__':
    import numpy
    size = 10_000_000
    s = numpy.random.rand(size)
    t = numpy.random.rand(size)
    p = numpy.random.rand(size)

    for _ in range(100):
        run(s, t, p)

Running this gives:

XLA_FLAGS="--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1 inter_op_parallelism_threads=1" time python bench_jax.py
<snip>/lib/python3.7/site-packages/jax/lib/xla_bridge.py:115: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
        6.64 real        10.59 user         1.98 sys

Note the lower real than user time. Looking at the resource monitor, I can see ~180% CPU usage.

mgbukov commented 4 years ago

I'm having the same issue trying to run jax on an HPC with multiple CPUs. The solution from #743 doesn't work for me either.

I tried setting

os.environ['MKL_NUM_THREADS']='1' 
os.environ['OPENBLAS_NUM_THREADS']='1'

os.environ["NUM_INTER_THREADS"]="1"
os.environ["NUM_INTRA_THREADS"]="1"

os.environ["XLA_FLAGS"] = ("--xla_cpu_multi_thread_eigen=false "
                           "intra_op_parallelism_threads=1")

I'm getting RuntimeError: Resource temporarily unavailable on the HPC, and on my laptop I can see that the code uses more than one CPU.

@dionhaefner did you manage to get around this problem?

mgbukov commented 4 years ago

And here's the JAX MNIST example, where I tried to disable multithreading by putting the above flags at the beginning of the Python script, but it doesn't work :(

mnist_single_thread.txt

clemisch commented 4 years ago

Maybe you could try the threadpoolctl library to set the number of threads for all backends (BLAS, OpenMP, MKL, and so on; not JAX backends):


from threadpoolctl import threadpool_limits

[...]

with threadpool_limits(limits=1):
    for _ in range(10000):
        run(s, t, p)

dionhaefner commented 4 years ago

I ended up setting processor affinity (only works on Unix systems though, and might require special permissions):

$ taskset -c 0 python myscript.py
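
For cases where wrapping the launch command isn't possible, a programmatic equivalent is to shrink the process's affinity mask from Python before JAX starts. A minimal sketch, assuming Linux (os.sched_setaffinity is not available on OSX):

import os

# Pin the current process (pid 0) to core 0 before importing jax, so that the
# XLA runtime should size its threadpools from the restricted affinity mask
# (see the sizing logic linked further down in this thread).
os.sched_setaffinity(0, {0})

import jax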

clemisch commented 4 years ago

But won't that limit the script to one specific CPU? And you have to manage manually which instance runs on which CPU.

dionhaefner commented 4 years ago

Yes. It's just a workaround, not really a solution.

mgbukov commented 4 years ago

I wonder if it's jax or xla (or some other internal package) which spawns the extra threads.

In any case, it would be very useful to have a clear example (maybe even as part of the official JAX examples) of how to get this under control; otherwise, allocating resources on HPC clusters becomes a tedious trial-and-error task.

clemisch commented 4 years ago

AFAIK the threads are spawned by the CPU backends, namely OpenMP, BLAS, MKL, and the like. Getting that under control is specifically the purpose of threadpoolctl. If I understand your problem/use case correctly, it's exactly what you need!

mgbukov commented 4 years ago

I guess I'm doing something wrong, but threadpoolctl doesn't seem to work for me on OSX. Could you take a quick look at the updated version?

mnist_single_thread.txt

This code uses 22 threads and up to 500% CPU (as reported by top).

jekbradbury commented 4 years ago

I believe the code XLA generates on CPU doesn't use MKL, OpenBLAS, or the system BLAS, so environment variables related to those libraries are unlikely to have an effect; for BLAS and related operations (e.g. convolutions), it uses an embedded copy of Eigen (really Eigen's Tensor sub-library) and for everything else it generates its own code with LLVM.

--xla_cpu_multi_thread_eigen=false should ensure that the BLAS library calls use Eigen in single-threaded mode, but non-BLAS code seems to be multi-threaded based on the XLA configuration option intra_op_parallelism_threads (which might not be wired through as a flag, or at least I can't find it?)

CC @hawkinsp

clemisch commented 4 years ago

Oh no, I think I mixed up my anecdotal evidence from numba with this issue. I'm sorry, @mgbukov, that I led you down the wrong track.

mgbukov commented 4 years ago

@clemisch no worries, ideas are always welcome!

@jekbradbury I did set both --xla_cpu_multi_thread_eigen=false and intra_op_parallelism_threads in mnist_single_thread.txt above, but there's still some residual multithreading. Are you suggesting that there's something wrong with the flags or the way I set them in the script, or that there are more/other flags I need to set?

hawkinsp commented 4 years ago

It's going to be almost impossible to completely eliminate threading from JAX; the runtime internally uses multiple threads, e.g., to overlap the Python interpreter with XLA compilation and interpretation.

That said, we can probably avoid threading in the compute-intensive XLA-generated code; XLA itself has support for this, although I'm unsure if it's plumbed through to the API surface in a usable way.

hawkinsp commented 4 years ago

A quick note: setting the task affinity map is the correct way to limit JAX's CPU usage at the moment. JAX sizes its main threadpool using this logic: https://github.com/tensorflow/tensorflow/blob/4b2cb67756009dda843c6b56a8b320c8a54373e0/tensorflow/core/platform/default/port.cc#L67

If launching from mpirun, mpirun knows how to set task affinities correctly.
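
A quick way to check what that sizing logic will end up seeing, assuming Linux (a sketch; the linked code roughly counts the schedulable CPUs, which corresponds to the process's affinity mask):

import os

# Number of CPUs in this process's affinity mask, e.g. 1 when launched
# under `taskset -c 0`; this is roughly what the threadpool sizing sees.
print(len(os.sched_getaffinity(0)))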

shoyer commented 3 years ago

I have a use-case where I want to run a JIT compiled function with only a single CPU, using separate threading to handle parallelism (with Dask). At the moment, it looks like there's no way to do this?

Joshuaalbert commented 3 years ago

Similar to @shoyer, I wish to orchestrate many multithreaded XLA ops on CPU, except using pmap instead of Dask. Each XLA op is a jitted function containing some BLAS and non-BLAS code. I'd like the BLAS code to be limited to 1 or 2 threads per job.

Joshuaalbert commented 3 years ago

I've noticed this behaviour, which doesn't exactly make sense to me:

    ncpu=2
    os.environ['XLA_FLAGS'] = f"--xla_force_host_platform_device_count={ncpu}"
    os.environ["XLA_FLAGS"] = ("--xla_cpu_multi_thread_eigen=false "
                               "intra_op_parallelism_threads=1")
    from jax import local_device_count
    print(local_device_count()) #1 instead of 2

The desire is that one can create ncpu virtual devices and turn off the intra_op parallelism. However, the suggested method in this issue seems to do neither of these things.

dionhaefner commented 3 years ago

In your example you immediately overwrite the XLA_FLAGS env var. Did you mean to append instead?

Joshuaalbert commented 3 years ago

Ah that's right! Thanks @dionhaefner
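
For reference, a corrected sketch of that snippet which merges both sets of options into a single XLA_FLAGS value (it still has to be set before jax is imported):

import os

ncpu = 2
# Build one combined XLA_FLAGS string instead of overwriting it.
os.environ["XLA_FLAGS"] = (
    f"--xla_force_host_platform_device_count={ncpu} "
    "--xla_cpu_multi_thread_eigen=false "
    "intra_op_parallelism_threads=1"
)

from jax import local_device_count
print(local_device_count())  # should now report ncpu devices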

SimonBiggs commented 3 years ago

For others struggling with this, the following appears to work well for me:

taskset -c 0 mpiexec -np 1 python /path/to/your/python/file.py

From https://stackoverflow.com/a/41143396/3912576

h-larsson commented 2 years ago

For me, these solutions do not work when I want to execute several Python jobs on the same machine. I finally found that likwid-pin (https://github.com/RRZE-HPC/likwid) works as expected:

likwid-pin -c 0 python /path/to/your/python/file.py
likwid-pin -c 1 python /path/to/your/python/file.py
...

samlobel commented 1 year ago

Is there any progress on this issue? It would be really great if JAX offered a way to limit thread spawning besides taskset. HPC clusters frequently limit you to a small number of concurrent threads, which makes multi-process JAX a non-starter; it's a blocking issue in my group. taskset works when you have access to the place where the processes are spawned, but often that's deep within a library. My experience has been that none of the XLA_FLAGS do anything to remediate this.

For example, if the code linked above (here) is current, could this default to a flag if provided?

alucantonio commented 1 year ago

Hi, I would also be interested in follow-ups on this issue. It is blocking in my group, too, and it is preventing me from using more than 20 CPUs on a single cluster node. None of the suggested flags work, and I cannot use taskset either.

akloss-cibo commented 1 year ago

Us too... we have CPU affinity set via static scheduling in Kubernetes and see a pretty drastic reduction in performance as core count goes up. A single-CPU unit of work on an m5zn.large (4 CPUs) takes 28 minutes, and running the same work on an m5zn.xlarge (8 CPUs) takes 45 minutes.

In addition to the CPU affinity, we have these environment variables set:

        env:
        - name: OMP_NUM_THREADS
          value: "1"
        - name: XLA_FLAGS
          value: --xla_force_host_platform_device_count=1 --xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1
        - name: OPENBLAS_NUM_THREADS
          value: "1"
        - name: MKL_NUM_THREADS
          value: "1"
        - name: NUM_INTER_THREADS
          value: "1"
        - name: NUM_INTRA_THREADS
          value: "1"

In table form (all m5zn instances):

CPUs | concurrency (number of 1-CPU pods) | duration for a pod to finish (minutes)
---- | ---------------------------------- | --------------------------------------
4    | 1                                  | 23
4    | 3                                  | 28
8    | 6                                  | 45

Maybe this is more a question for AWS, but why is one CPU on an 8 CPU instance taking ~50% longer to do the work than one CPU on a 4 CPU instance? Is having the other CPUs busy really that damaging to performance? (The difference between running one pod and three is also 22%, so maybe that is indeed the case.)

maedoc commented 7 months ago

I'm also interested in a straightforward API for this, both for HPC settings and for settings where the automatic use of multiple cores is a poor choice for not-so-big matrix sizes.

why is one CPU on an 8 CPU instance taking ~50% longer to do the work than one CPU on a 4 CPU instance

@akloss-cibo off-topic response, but memory bandwidth is a roughly constant resource per socket (under some assumptions) regardless of core count, so what you observe is expected as soon as the memory involved spills out of L3 cache; that's the first culprit to check.

akloss-cibo commented 7 months ago

so what you observe is expected as soon as the memory involved spills out of L3 cache; that's the first culprit to check.

A few quotes from AWS:

The Nitro Hypervisor allows M5zn instances to deliver performance that is just about indistinguishable from bare metal.

Placement Groups – M5zn instances can be used in Cluster (for low latency and high network throughput), Spread (to keep critical instances separate from each other), and Partition (to reduce correlated failures) placement groups.

and Intel:

Up to 28 CPU cores
Multi-socket support (2, 4, 8 CPU)

m5zn instances are available in sizes with 2, 4, 8, 12, 24, or 48 CPUs... 12, 24, and 48 seem like unusual sizes to me... so what I think I'm gathering is that you think both the 4- and 8-CPU instances have the same number of sockets (my WAG is one) and the socket bandwidth is maxed out by the 6 pods; assuming a 12-CPU instance is also a single socket, it presumably would take 90 minutes to complete, but if a 12-CPU instance is actually two sockets, it could do 6 pods of work in 45 minutes. I infer from AWS's "placement groups" that there's potential for a noisy neighbor using CPUs on a socket they share with you to reduce your throughput too (and this noisy neighbor could be another one of your own instances)...

We've moved this workload to Nvidia where we're happier with its performance, but it's good to have some idea of what's going on. Thanks for the insight.