google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Test aborting due to thread limit exhaustion #20660

Open Flamefire opened 3 months ago

Flamefire commented 3 months ago

Description

We are compiling and testing JAX on an HPC system with a rather large number of cores (208 including Hyper-Threading).

When running the test suite (pytest tests), it fails with:

Fatal Python error: Aborted

Running the test file individually, e.g. pytest tests/random_lax_test.py, succeeds.

When excluding the failing test file with pytest tests --deselect tests/random_lax_test.py, it fails in another test file. I tried excluding the next failing file, and the one after that, but the suite keeps failing at some point.

With pytest -s tests I at least get the error:

F external/tsl/tsl/platform/default/env.cc:74] Check failed: ret == 0 (11 vs. 0)Thread tf_XLACpuCompile creation via pthread_create() failed.

I.e. it failed to create a new thread with EAGAIN, hinting at resource exhaustion, which is why I suspect this is related to the number of available cores.
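
For context, EAGAIN from pthread_create() typically means the per-user task limit (RLIMIT_NPROC) or the memory available for thread stacks is exhausted. A small diagnostic sketch (my own addition, not part of the original report) to compare that limit against the core count on the node:

import os
import resource

# Per-user limit on tasks (processes + threads); pthread_create() fails with
# EAGAIN once this is exhausted.
soft, hard = resource.getrlimit(resource.RLIMIT_NPROC)
print("RLIMIT_NPROC (soft/hard):", soft, hard)

# Logical cores visible to the process; each XLA ThreadPool spawns roughly
# this many worker threads, so a few leaked pools add up quickly on 208 cores.
print("logical cores:", os.cpu_count())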

Attaching GDB, I see the failure happening inside the constructor of OneDnnMatMulReorderVisitor, which creates (another) tsl::thread::ThreadPool with another 208 threads. The call originates in xla::PyClient::Compile and goes through xla::ifrt::PjRtLoadedExecutable::Create.

Counting the construction and destruction of ThreadPool instances up to the point of failure, I see 41565 ThreadPools created and 41559 destroyed. So there might be some leaks, but I couldn't tell how or where. It could also be that the printfs to stderr weren't entirely captured in my output.

I see that XLA's OneDnnMatMulReorder(Visitor) may reuse a thread pool passed in the build_options, and e.g. PjRtStreamExecutorClient does add its own ThreadPool to the build options if there isn't one yet, but the TfrtCpuClient used in this context doesn't do that. Maybe that would help.

Any ideas how to fix this?

Finally, I should say that this might already be a bug/issue in XLA rather than JAX specifically, but I'm reporting it here because it happens with the JAX test suite. In particular, JAX creates TfrtCpuClient instance(s) already at test collection time (pytest --collect-only), which in turn creates ThreadPool instances. So it is possible that the JAX tests are wasting or not properly freeing ThreadPools.

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.25.dev20240409
jaxlib: 0.4.25.dev20240408
numpy:  1.25.1
python: 3.11.3 (main, Mar 27 2024, 14:52:01) [GCC 12.3.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='n1004', release='4.18.0-425.19.2.el8_7.x86_64', version='#1 SMP Fri Mar 17 01:52:38 EDT 2023', machine='x86_64')

It looks like it is leaking PjRtClient instances created by make_cpu_client. Comparing the number of created and destroyed ThreadPool instances across single invocations of tests (e.g. pytest -s tests/clear_backends_test.py), for most tests there is a difference of 2, which corresponds to the 2 ThreadPools in TfrtCpuClient (named "XLATfrtCpuClient" and "XLAEigen").

Most notably, clear_backends_test creates 8 ThreadPools and destroys only 4. Checking the test code, it creates the client twice on purpose, which accounts for the 4 leaked thread pools; the other 4 are in e.g. OneDnnMatMulReorderVisitor and get properly cleaned up.

Flamefire commented 3 months ago

I found a small reproducer:

import jax
import gc

for i in range(1000):
    print(jax.jit(lambda x, y: x * y)(1, i))
    jax.clear_backends()
    jax.clear_caches()
    gc.collect()

On my machine (208 cores) this fails consistently after 73 iterations (72 is the last output).
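
When reproducing this, it helps to watch the OS thread count grow per iteration instead of waiting for the abort. An instrumented variant of the reproducer (my own addition; the /proc-based thread count is Linux-only):

import gc
import os
import jax

def os_threads():
    # Count the OS threads of this process via /proc (Linux-specific).
    return len(os.listdir('/proc/self/task'))

for i in range(1000):
    print(jax.jit(lambda x, y: x * y)(1, i))
    jax.clear_backends()
    jax.clear_caches()
    gc.collect()
    # If clear_backends() really released the TfrtCpuClient, this number
    # should stay roughly constant; on the affected setup it keeps growing.
    print("iteration", i, "OS threads:", os_threads())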

Digging (very) deep, I did not find any references to PyClient in Python, but the underlying shared_ptr still had 3 (strong) references remaining.

That is after fixing an issue in the code, demonstrated below:

import jax
from jax._src import xla_bridge as xb

jax._src.api.xb.get_backend()
print(len(xb._backends))                  # 1
print(len(jax.lib.xla_bridge._backends))  # 1

jax.clear_backends()
print(len(xb._backends))                  # 0
print(len(jax.lib.xla_bridge._backends))  # 1

Removing the import in either jax/lib/xla_bridge.py or jax/__init__.py doesn't immediately destroy the PyClient instance, but a gc.collect() after clear_backends then does.
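
This looks like the usual from-import pitfall: if jax/lib/xla_bridge.py copies the _backends binding at import time, then when clear_backends() rebinds the dict in jax._src.xla_bridge, the shim module keeps the old dict alive together with the clients in it. A toy illustration of that mechanism (plain Python, no JAX involved):

import types

# Simulate the real module holding the backend registry.
mod_a = types.ModuleType("mod_a")
mod_a.registry = {"cpu": object()}   # stands in for xla_bridge._backends

# "from mod_a import registry" copies the name binding, like the shim module does.
registry_copy = mod_a.registry

# A clear_backends()-style cleanup that rebinds instead of clearing in place.
mod_a.registry = {}

print(len(mod_a.registry))    # 0 -- what xb._backends shows
print(len(registry_copy))     # 1 -- what jax.lib.xla_bridge._backends shows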

mattjj commented 3 months ago

Thanks so much for raising this, and for the thorough investigation!

IIUC from your latest comment, it seems like something in our C++ / extension code is breaking reference counting. Does that sound right?

Flamefire commented 3 months ago

IIUC from your latest comment, it seems like something in our C++ / extension code is breaking reference counting. Does that sound right?

Not necessarily. I experimented with breakpoints on the relevant std::shared_ptr<PyClient> constructor and copy assignment and found it to be used in e.g. GSPMDSharding, ClientAndPtr (and, through that, device lists), and arrays (xla::PyArray).

Especially with the extensive use of functools.lru_cache (partially hidden in util.cache/memoize) and weakref_lru_cache, it is rather likely that some (Python) object sits in some cache while the underlying C++ object still holds a reference to the std::shared_ptr<PyClient>.
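
As a minimal illustration of how such a cache can keep a client alive even when no user code references it anymore (a generic Python sketch, not JAX's actual cache contents):

import functools
import gc
import weakref

class Client:
    """Stand-in for the object whose lifetime we want to track."""

class Sharding:
    def __init__(self, client):
        self.client = client   # like a sharding/array holding on to the client

@functools.lru_cache(maxsize=None)
def cached_sharding(key, client):
    return Sharding(client)

c = Client()
ref = weakref.ref(c)
cached_sharding("some-key", c)

del c
gc.collect()
print(ref() is not None)   # True: the lru_cache entry still holds the client

Calling cache_clear() on such a cache releases the reference, so any cache that clear_caches() doesn't reach could explain the remaining shared_ptr references.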

And as mentioned, jax/lib/xla_bridge.py is also problematic in that regard, as it even holds a reference to the (Python) Client/backend that should be freed by clear_backends.

I also found that it makes a difference whether it is run on a machine that has GPUs available. The test suite works if GPUs are available, although that could be due to that machine having fewer CPU cores (96 vs 208).

Also, the example code I posted above fails on GPU machines after 4 iterations with:

external/xla/xla/stream_executor/cuda/cuda_driver.cc:1544] could not allocate CUDA stream for context 0x1c97f60: CUDA_ERROR_OUT_OF_MEMORY: out of memory

Running with CUDA_VISIBLE_DEVICES='' it fails after 164 iterations, which roughly matches the relative difference in core counts (164/73 ≈ 2.2, and 208/96 ≈ 2.2).

hawkinsp commented 3 months ago

I'm not particularly surprised that clear_backends doesn't actually work, and indeed that's why it's deprecated from the public API. We can try to fix this, but it's not going to be easy to keep working: it's just too easy for references to creep in.

If you exclude clear_backends_test, is there still a problem running the test suite?

Flamefire commented 2 months ago

If you exclude clear_backends_test, is there still a problem running the test suite?

Yes, it still fails. That particular test was just one I investigated in more detail after running each test file separately and counting ThreadPool/backend creations vs. destructions; it stood out with a difference of 4, while the others have 2 or fewer.