kach opened this issue 2 years ago
Whilst we're at it -- JAX caches a lot of things (e.g. jaxprs) and I've found that this can also contribute to OOM on limited memory machines. (Such as GitHub Actions runners, for running tests.) Ways to clean these up would also be desirable.
FWIW my hugely hacky approach so far has been
```python
import gc
import sys

import psutil


def clear_caches():
    process = psutil.Process()
    if process.memory_info().vms > 4 * 2**30:  # >4GB memory usage
        # Snapshot the module list: getattr can trigger imports that
        # mutate sys.modules mid-iteration.
        for module_name, module in list(sys.modules.items()):
            if module_name.startswith("jax"):
                for obj_name in dir(module):
                    obj = getattr(module, obj_name)
                    if hasattr(obj, "cache_clear"):
                        obj.cache_clear()
        gc.collect()
```

(which, in the context of tests, is wrapped in a `@pytest.fixture(autouse=True)`.)
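The module-scanning trick above relies only on caches exposing a `cache_clear()` method, as `functools.lru_cache` wrappers do. A minimal stdlib-only illustration of the same idea, using a made-up `jaxdemo` module as a stand-in for JAX:

```python
import functools
import sys
import types

# Stand-in for a JAX submodule holding an lru_cache-decorated function.
fake = types.ModuleType("jaxdemo.internal")

@functools.lru_cache(maxsize=None)
def cached_square(x):
    return x * x

fake.cached_square = cached_square
sys.modules["jaxdemo.internal"] = fake

cached_square(3)  # populate the cache
assert cached_square.cache_info().currsize == 1

# The same scan as in the snippet above, keyed on the "jaxdemo" prefix.
for module_name, module in list(sys.modules.items()):
    if module_name.startswith("jaxdemo"):
        for obj_name in dir(module):
            obj = getattr(module, obj_name)
            if hasattr(obj, "cache_clear"):
                obj.cache_clear()

assert cached_square.cache_info().currsize == 0  # cache was found and cleared
```

The only contract assumed is the `cache_clear()` attribute; anything in a matching module that happens to have one gets cleared, which is exactly why the approach is fragile across JAX versions.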
I just saw `clear_backend()` here and it turns out to be quite useful:
https://github.com/google/jax/blob/7721579700ee5c4a951e72156a0bdac4c9768f43/jax/_src/api.py#L3250-L3264
As @pipme mentioned, https://github.com/google/jax/pull/11462 adds a `jax.clear_backends()`, which also clears the compilation cache for all jitted functions. Would that work for your case? @kach
@sokrypton I guess you need to call `jax.clear_backends()` periodically.
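"Periodically" might look like the following pattern, where `cleanup` would be `jax.clear_backends` in a real training loop (this is only a sketch; `run_with_periodic_cleanup`, `step`, and `cleanup` are placeholder names, not JAX APIs):

```python
def run_with_periodic_cleanup(n_steps, period, step, cleanup):
    """Run `step` n_steps times, calling `cleanup` every `period` steps."""
    for i in range(n_steps):
        step(i)
        if (i + 1) % period == 0:
            cleanup()  # e.g. jax.clear_backends() in practice

# Exercised with stand-in callables to show the cadence.
calls = []
run_with_periodic_cleanup(10, 4, step=lambda i: None,
                          cleanup=lambda: calls.append(1))
assert len(calls) == 2  # cleanup fired after steps 4 and 8
```

Note that clearing the backend forces recompilation of everything afterwards, so the period is a trade-off between memory headroom and compile time.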
@pipme it doesn't seem to be an issue anymore; deleting my comment. Not sure what changed from last night :D (maybe I got a better GPU on Colab).
https://github.com/google/jax/pull/12048 will add an API for clearing caches.
@patrick-kidger with the latest JAX (0.3.21) your `clear_caches()` function no longer works properly. First it crashes at `jaxlib.xla_extension.WeakrefLRUCache`. Next, when you run `clear_caches()` a second time (after `jax.jit`-compiling a model), it kills Google Colab (a restart happens). I tracked the problem down to `jax.interpreters.partial_eval`.
Here is my current working code:
```python
import gc
import sys


def clear_caches():
    for module_name, module in list(sys.modules.items()):
        if module_name.startswith("jax"):
            # Skip the module whose caches crash Colab when cleared twice.
            if module_name not in ["jax.interpreters.partial_eval"]:
                for obj_name in dir(module):
                    obj = getattr(module, obj_name)
                    if hasattr(obj, "cache_clear"):
                        try:
                            obj.cache_clear()
                        except Exception:
                            pass
    gc.collect()
```
I think the need for manually cleaning the cache once in a while is not natural. It would be nice to have an internal garbage collector inside Jax that cleans the least recently used objects in the cache (depending on the memory usage by the cache at the time).
I am currently facing a very similar issue. Am I supposed to use `jax.clear_backends()` or the solution suggested by @sokrypton? If it's the latter, I believe it should be called periodically, similarly to `jax.clear_backends()`?
I have a similar issue, where a pmapped function is run over multiple GPUs. In pseudocode:
```python
@partial(jax.pmap, in_axes=(0, 0, None))
@jax.value_and_grad
def loss(x: Array, y: Array, n: int):
    y_hat, _ = jax.lax.scan(model, init=x, xs=jax.numpy.arange(n))
    l = MSE(y_hat, y)
    return l

loss, grad = loss(x_batched, y_batched, n)
```
The number of iterations in the scan loop, `n`, doesn't change much, but it does change. `n` is monotonically increasing, so I would like to throw away the old cache to make space for the new one. Here is an example of what the memory usage looks like:

`n` was increased after ~5h, and if it is increased again the program will crash while trying to allocate CUDA memory. The only thing that changed in this example is the value of `n`. The shapes of `x` and `y` are constant, though I would also like them to change. Some sort of `clear_cache` function would solve this, I believe (and would also be very useful in the case of `x` and `y` changing shape). `jax.clear_backends()` doesn't seem to work, and `loss.clear_cache()` runs into an error because it's a pmapped function, not a jitted one.
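One possible mitigation for a monotonically growing `n` (an assumption on my part, not something tried in this thread) is to round `n` up to the next power of two before tracing, and mask out the padded tail of the scan when computing the loss. The cache then holds at most O(log n) compiled variants instead of one per distinct `n`:

```python
def next_power_of_two(n: int) -> int:
    """Round a positive integer up to the nearest power of two."""
    return 1 << (n - 1).bit_length()

# Each distinct value of a shape-determining argument forces a recompilation,
# so bucketing n keeps the number of cached executables logarithmic in n.
assert [next_power_of_two(n) for n in (1, 3, 5, 8, 100)] == [1, 4, 8, 8, 128]
```

This trades some wasted compute on the padded iterations for a bounded compilation cache; the padded steps must be excluded from `MSE` via a mask for the loss to stay correct.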
Does anyone have a newer, updated `clear_caches()` function? The solutions posted are not able to resolve memory leaks with the latest JAX... :(
@sokrypton does `jax.clear_backends()` work in your case?
> I think the need for manually cleaning the cache once in a while is not natural. It would be nice to have an internal garbage collector inside Jax that cleans the least recently used objects in the cache (depending on the memory usage by the cache at the time).
@simitii you are absolutely right. The challenge is that we haven't plumbed enough info out of e.g. XLA executables to be able to determine their size, so we can't tell the size of a cache entry, nor therefore the total size of the cache itself. There is an LRU eviction policy for most caches, but it's based on the total number of entries rather than size, and there are situations where the number of entries remains small while the memory usage does not.
Clearly that's a surmountable hurdle, but it hasn't yet bubbled to the top of the priority list...
Maybe closing the issue was a little premature until we've heard any success stories...
To any folks who are running into such issues: can you try using `jax.clear_caches()`?
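For code that has to run across JAX versions, one might guard on availability rather than hard-coding either call (a sketch; `best_effort_clear` is a hypothetical helper, and `clear_caches` only exists in newer releases):

```python
def best_effort_clear(jax_module):
    """Prefer jax.clear_caches() when present; fall back to clear_backends()."""
    if hasattr(jax_module, "clear_caches"):
        jax_module.clear_caches()
        return "clear_caches"
    if hasattr(jax_module, "clear_backends"):
        jax_module.clear_backends()
        return "clear_backends"
    return None

# Exercised with stand-in namespaces instead of the real jax module.
import types
old = types.SimpleNamespace(clear_backends=lambda: None)
new = types.SimpleNamespace(clear_caches=lambda: None)
assert best_effort_clear(old) == "clear_backends"
assert best_effort_clear(new) == "clear_caches"
```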
Should we try this on the main branch, @mattjj? Or will the latest released version suffice?
I think it's only in the latest github HEAD at the moment, not the latest pypi release. Let me know if that's a pain and I should work on updating the pypi release.
I'm waiting for official release to try. As it appears a few things need to be compiled?
In my case, the memory grows from alternating 8GB/1GB to alternating 11GB/3GB. Using @patrick-kidger and @sokrypton's method worked well and kept the speed fast, but `jax.clear_caches()` didn't. Another method, `jax.clear_backends()`, worked, but the speed became very slow.
FWIW, I've since given up on trying to make something like this work; at some point my hack above started breaking. Now I just fully restart the Python process every now and again. (E.g. when running a test suite, I run each file with a separate call to `pytest`.) Moreover, this seems to have actually produced a speed improvement for me, for some reason.
The `jax.clear_caches()` method seems to work for me with JAX v0.4.13, though not quite as well as I'd expect. Using it reduces CPU memory usage by quite a bit, but there's still more in host memory than I'd expect when pretty much all the work is happening on the GPU.
`jax.clear_backends()` is not shown in the public API documentation; is this expected?
Something mysterious: when I called `jax._src.api.clear_backends()`, the program ran into an out-of-memory issue after a while. If I instead commented `clear_backends` out, it ran without a problem.

Also, it seems that after calling `clear_backends()`, `jax.lib.xla_bridge.get_backend().live_arrays()` gives wrong information about the arrays on the GPU. The arrays do exist (the corresponding Python objects are not deleted) but are not captured by `live_arrays()`, i.e., `len(live_arrays())` becomes zero.

My JAX version is 0.4.20. Is it a bug, or am I missing something?
There is currently no way to clear the full compilation cache in JAX: something like `f._clear_cache()`, but which applies to all jitted functions that have been compiled so far. This would be useful in situations where multiple high-memory-usage tasks need to be done sequentially. See discussion in https://github.com/google/jax/discussions/10826