Flamefire opened 3 months ago
I found a small reproducer:
```python
import jax
import gc

for i in range(1000):
    print(jax.jit(lambda x, y: x * y)(1, i))
    jax.clear_backends()
    jax.clear_caches()
    gc.collect()
```
On my machine (208 cores) this fails consistently after 73 iterations (72 is the last output)
Digging (very) deep I did not find any references to `PyClient` in Python, but the underlying `shared_ptr` still had 3 (strong) references remaining. That is after fixing an issue in the code: `jax` imports `jax.lib` for some unknown reason, which seems very superfluous: https://github.com/google/jax/blob/2be72052ae5e04f11bff61b0133f54997f825a17/jax/__init__.py#L246. `jax.lib` imports `xla_bridge` into `jax.lib`, and `default_backend` and `_backends` from `jax._src.xla_bridge`, which results in an extra reference to both that isn't reset by `jax.clear_backends()`:
```python
import jax
from jax._src import xla_bridge as xb

jax._src.api.xb.get_backend()
print(len(xb._backends))                  # 1
print(len(jax.lib.xla_bridge._backends))  # 1
jax.clear_backends()
print(len(xb._backends))                  # 0
print(len(jax.lib.xla_bridge._backends))  # 1
```
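The discrepancy above can be reproduced without jax at all: a `from module import name` (or a module-level alias) copies the binding, so rebinding `_backends` in the source module leaves the old dict alive behind the alias. A minimal sketch, where the `src` module and its attribute are stand-ins I made up, not the real jax modules:

```python
import types

# Hypothetical stand-in for jax._src.xla_bridge.
src = types.ModuleType("src")
src._backends = {"cpu": object()}

# Equivalent of `from src import _backends` in a second module:
alias = src._backends

# Rebinding the name in the source module (as a clear_backends-style
# reset might do) does not affect the alias, which still pins the old dict.
src._backends = {}
print(len(src._backends))  # 0
print(len(alias))          # 1
```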
Removing the import in either `jax/lib/xla_bridge.py` or `jax/__init__.py` doesn't immediately destroy the `PyClient` instance, but a `gc.collect()` after `clear_backends` does.
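That `gc.collect()` is needed suggests the remaining references sit in a cycle: CPython frees an object the moment its refcount hits zero, but objects caught in a reference cycle only die when the cyclic collector runs. A minimal illustration with a made-up stand-in class (not the real `PyClient`):

```python
import gc
import weakref

class Client:
    """Hypothetical stand-in for the Python-side client wrapper."""

c = Client()
c.self_ref = c              # a reference cycle keeps the object alive
alive = weakref.ref(c)

del c                       # drop the last external reference
assert alive() is not None  # still alive: the cycle holds it

gc.collect()                # the cyclic collector breaks the cycle
assert alive() is None      # now it is really gone
```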
Thanks so much for raising this, and for the thorough investigation!
IIUC from your latest comment, it seems like something in our C++ / extension code is breaking reference counting. Does that sound right?
> IIUC from your latest comment, it seems like something in our C++ / extension code is breaking reference counting. Does that sound right?
Not necessarily. I experimented with breakpoints on the relevant `std::shared_ptr<PyClient>` constructor and copy assignment and found it to be used in e.g. `GSPMDSharding`, `ClientAndPtr` (and through that in device lists) and arrays (`xla::PyArray`).
Especially with the extensive use of `functools.lru_cache` (partially hidden in `util.cache`/`memoize`) and `weakref_lru_cache`, it is rather likely that some (Python) object is in some cache where the underlying C++ object still has a reference to the `std::shared_ptr<PyClient>`.
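The caching concern is easy to demonstrate in isolation: an unbounded `functools.lru_cache` keeps strong references to everything reachable from its cached return values, so an object pinned by a cache entry cannot die until the cache is cleared. A sketch with a made-up `Backend` class (the real caches in jax are more involved):

```python
import functools
import gc
import weakref

class Backend:
    """Hypothetical stand-in for an object holding the C++ client."""

@functools.lru_cache(maxsize=None)
def compile_fn(key):
    # The cache stores the returned tuple, including the Backend inside it.
    return (key, Backend())

_, backend = compile_fn(0)
ref = weakref.ref(backend)
del backend
gc.collect()
assert ref() is not None   # the cache entry still pins the Backend

compile_fn.cache_clear()   # analogous in spirit to jax.clear_caches()
gc.collect()
assert ref() is None       # only now can it be collected
```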
And as mentioned, `jax/lib/xla_bridge.py` is also problematic in that regard, as it even holds a reference to the (Python) `Client`/backend that should be freed by `clear_backends`.
I also found that it makes a difference whether it is run on a machine that has GPUs available or not. The test suite works if GPUs are available, although that could be due to having fewer CPU cores on that machine (96 vs 208).
Also, the example code I posted above fails on GPU machines after 4 iterations with:

```
external/xla/xla/stream_executor/cuda/cuda_driver.cc:1544] could not allocate CUDA stream for context 0x1c97f60: CUDA_ERROR_OUT_OF_MEMORY: out of memory
```

Running with `CUDA_VISIBLE_DEVICES=''` it fails after 164 iterations, which roughly matches the (relative) difference in cores.
I'm not particularly surprised that `clear_backends` doesn't actually work, and indeed that's why it's deprecated from the public API. We can try to fix this, but it's not going to be easy to keep working: it's just too easy for references to creep in.
If you exclude `clear_backends_test`, is there still a problem running the test suite?
> If you exclude `clear_backends_test`, is there still a problem running the test suite?
Yes, it still fails. This particular test was just one I investigated in more detail after running each test separately and counting the number of ThreadPool/backend creations vs. destructions; this one stuck out with a difference of 4 while others have 2 or less.
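The create/destroy bookkeeping described here can be mimicked in pure Python with `weakref.finalize`, which fires when an object is actually destroyed; a nonzero difference after a full `gc.collect()` points at a leak. A toy sketch, where the `Pool` class is a made-up stand-in for `tsl::thread::ThreadPool`:

```python
import gc
import weakref

counts = {"created": 0, "destroyed": 0}

def _note_destroyed():
    counts["destroyed"] += 1

class Pool:
    """Hypothetical stand-in for tsl::thread::ThreadPool."""
    def __init__(self):
        counts["created"] += 1
        weakref.finalize(self, _note_destroyed)

pools = [Pool() for _ in range(4)]
leak = pools.pop()      # simulate one pool kept alive by a stray reference
pools.clear()           # the other three are destroyed here
gc.collect()
print(counts["created"] - counts["destroyed"])  # 1 pool unaccounted for
```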
Description

We are compiling and testing jax on an HPC system with a rather large number of cores (208 including Hyper-Threading).

When running the test suite (`pytest tests`) it fails. Running the tests individually, e.g. `pytest tests/random_lax_test.py`, they succeed. When excluding the failing test file with `pytest tests --deselect tests/random_lax_test.py`, it fails in another test file. I tried excluding the next one, and the next after that, but it keeps failing at some point.

With `pytest -s tests` I at least get the error: it failed to create a new thread with EAGAIN, hinting at resource exhaustion, which is why I suspect this to be related to the number of available cores.
Attaching GDB, I see the failure happening inside the constructor of `OneDnnMatMulReorderVisitor`, which creates (another) `tsl::thread::ThreadPool` with another 208 threads. It originates in `xla::PyClient::Compile` and then goes through `xla::ifrt::PjRtLoadedExecutable::Create`.

Checking the construction and destruction of `ThreadPool` instances until the point of failure, I see 41565 ThreadPools created and 41559 destroyed. So there might be some leaks, but I couldn't tell how or where. It could as well be that the printfs to stderr weren't entirely caught in my output.

I see that XLA's `OneDnnMatMulReorder(Visitor)` may reuse a passed thread pool if set in the `build_options`, and e.g. `PjRtStreamExecutorClient` does add its own ThreadPool to the build options if there isn't one yet, but the `TfrtCpuClient` used in this context doesn't do that. Maybe that would help. Any ideas how to fix this?
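The suggested fix, reusing a caller-provided pool instead of constructing a fresh one per compile, can be sketched in Python with `concurrent.futures` (the function name, worker counts, and workload are all made up for illustration; the real change would be in the C++ `TfrtCpuClient`):

```python
from concurrent.futures import ThreadPoolExecutor

def compile_step(work, pool=None):
    """Run `work` items, reusing `pool` if the caller supplies one."""
    owned = pool is None
    if owned:
        # Fallback: create a private pool only when none was passed in.
        pool = ThreadPoolExecutor(max_workers=4)
    try:
        return list(pool.map(lambda x: x * x, work))
    finally:
        if owned:
            pool.shutdown()  # only tear down pools we created ourselves

# Repeated calls share one pool instead of spawning threads each time.
shared = ThreadPoolExecutor(max_workers=4)
print(compile_step(range(4), shared))  # [0, 1, 4, 9]
shared.shutdown()
```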
Finally, I have to say that it might as well be a bug/issue in XLA rather than Jax specifically, but I'm reporting it here as it happens with the Jax test suite. Especially as Jax creates `TfrtCpuClient` instance(s) already at test collection (`pytest --collect-only`), which already creates ThreadPool instances. So it is possible that Jax tests are wasting or not properly freeing ThreadPools.

System info (python version, jaxlib version, accelerator, etc.)
It looks like it is leaking `PjRtClient` instances created by `make_cpu_client`. Comparing the number of created and destroyed ThreadPool instances across single invocations of tests (e.g. `pytest -s tests/clear_backends_test.py`), for most tests there is a difference of 2, which matches the 2 ThreadPools in `TfrtCpuClient` (named "XLATfrtCpuClient" and "XLAEigen").

Most notably, `clear_backends_test` creates 8 ThreadPools and destroys only 4. Checking the test code, it creates the Client twice on purpose, which accounts for the 4 leaked thread pools. The other 4 are in e.g. `OneDnnMatMulReorderVisitor` and get properly cleaned up.