google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Many threads despite setting task affinities #16215

Open cooijmanstim opened 1 year ago

cooijmanstim commented 1 year ago

Description

There have been several issues in the past about limiting the number of threads, and the recommended way to do it is by setting task affinities (see https://github.com/google/jax/issues/1539#issuecomment-613743209). I'm running under Slurm, which does set task affinities; nevertheless, JAX allocates many (~25) threads on the first call. I'm testing like this:

import os, subprocess as sp

# Pin the usual BLAS/OpenMP/XLA thread knobs to 1 before importing jax.
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["NUM_INTER_THREADS"] = "1"
os.environ["NUM_INTRA_THREADS"] = "1"
os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1 --xla_force_host_platform_device_count=1"
print(os.sched_getaffinity(0))

import jax

# Count the threads in this process before and after the first JAX call.
print("pre:", int(sp.check_output(f"ls /proc/{os.getpid()}/task | wc -l", shell=True)))
jax.numpy.zeros([])
print("post:", int(sp.check_output(f"ls /proc/{os.getpid()}/task | wc -l", shell=True)))

and getting outputs like

{13, 5}
pre: 2
post: 28

So I have two cores available, but 26 threads are spawned. The machine I'm running on has 40 cores in total, so it's not just spawning one thread per core, but it's still a lot of threads. I'm also using a GPU, so I wasn't expecting any CPU worker threads to be allocated at all.
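To see which pools these threads belong to, it can help to read each thread's name from /proc after the first JAX call. This is only a diagnostic sketch (Linux-specific; the thread names you see, and how clearly they indicate the owning pool, will depend on the jaxlib build):

import os
import jax

jax.numpy.zeros([])  # trigger backend initialization

# Each entry under /proc/<pid>/task is a thread; its "comm" file holds the
# thread name, which usually hints at which pool created it.
task_dir = f"/proc/{os.getpid()}/task"
for tid in sorted(os.listdir(task_dir), key=int):
    with open(f"{task_dir}/{tid}/comm") as f:
        print(tid, f.read().strip())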

What jax/jaxlib version are you using?

jax 0.4.2, jaxlib 0.4.1

Which accelerator(s) are you using?

GPU

Additional system info

No response

NVIDIA GPU info

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.65.01    Driver Version: 515.65.01    CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla V100-SXM2...  On   | 00000000:3A:00.0 Off |                    0 |
| N/A   28C    P0    41W / 300W |      0MiB / 32768MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
hawkinsp commented 1 year ago

Note that we recently fixed a thread leak in #16272. There may be more going on here; I will check.