google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Add mechanism to clear full compilation cache #10828

Open kach opened 2 years ago

kach commented 2 years ago

There is currently no way to clear the full compilation cache in JAX — something like f._clear_cache(), but applying to all jitted functions that have been compiled so far. This would be useful in situations where multiple high-memory-usage tasks need to be run sequentially. See the discussion in https://github.com/google/jax/discussions/10826
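
For reference, the per-function escape hatch mentioned above looks roughly like this; a minimal sketch, assuming your JAX version exposes the private _clear_cache attribute on jit-wrapped functions (the hasattr guard covers versions that don't):

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return x * 2.0

f(jnp.ones(1000))  # first call triggers compilation; the executable stays cached

# Per-function clearing as referenced above; the attribute is private and version-dependent.
if hasattr(f, "_clear_cache"):
    f._clear_cache()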

patrick-kidger commented 2 years ago

Whilst we're at it -- JAX caches a lot of things (e.g. jaxprs), and I've found that this can also contribute to OOM on limited-memory machines, such as GitHub Actions runners used for running tests. Ways to clean these up would also be desirable.

FWIW, my hugely hacky approach so far has been:

import gc
import sys

import psutil

def clear_caches():
    process = psutil.Process()
    if process.memory_info().vms > 4 * 2**30:  # only act above ~4 GB of virtual memory
        # Call cache_clear() on every cached callable exposed by any jax module.
        for module_name, module in sys.modules.items():
            if module_name.startswith("jax"):
                for obj_name in dir(module):
                    obj = getattr(module, obj_name)
                    if hasattr(obj, "cache_clear"):
                        obj.cache_clear()
        gc.collect()

(which in the context of tests is wrapped in a @pytest.fixture(autouse=True).)
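
For completeness, a minimal sketch of that fixture arrangement, assuming the clear_caches helper above is importable in the test module:

import pytest

@pytest.fixture(autouse=True)
def _clear_jax_caches():
    # Run the test first, then clear JAX's function-level caches afterwards.
    yield
    clear_caches()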

pipme commented 2 years ago

I just saw clear_backends() here and it turns out to be quite useful: https://github.com/google/jax/blob/7721579700ee5c4a951e72156a0bdac4c9768f43/jax/_src/api.py#L3250-L3264

cky9301 commented 2 years ago

As @pipme mentioned, https://github.com/google/jax/pull/11462 adds jax.clear_backends(), which also clears the compilation caches for all jitted functions. Would that work for your case? @kach

pipme commented 2 years ago

@sokrypton I guess you need to call jax.clear_backends() periodically.
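
As a rough illustration of "periodically" (the tasks sequence, run_task helper, and interval below are made up for the sketch):

import jax

for step, task in enumerate(tasks):   # hypothetical sequence of memory-heavy tasks
    run_task(task)                     # each task jit-compiles and caches its own functions
    if step % 10 == 0:                 # arbitrary interval
        jax.clear_backends()           # drop backends and their compiled executables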

sokrypton commented 2 years ago

@pipme it doesn't seem to be an issue anymore; deleting my comment. Not sure what changed since last night :D (maybe I got a better GPU on Colab)

mattjj commented 2 years ago

https://github.com/google/jax/pull/12048 will add an API for clearing caches.

sokrypton commented 1 year ago

@patrick-kidger with the latest jax (0.3.21) your clear_caches() function no longer works properly. First, it crashes at jaxlib.xla_extension.WeakrefLRUCache.

Next, when you run clear_caches() a second time (after jax.jit has compiled the model), it kills Google Colab (a restart happens). I tracked the problem down to jax.interpreters.partial_eval.

Here is my current working code:

import gc
import sys

def clear_caches():
  for module_name, module in sys.modules.items():
    if module_name.startswith("jax"):
      # Skip partial_eval: clearing its caches kills the Colab runtime.
      if module_name not in ["jax.interpreters.partial_eval"]:
        for obj_name in dir(module):
          obj = getattr(module, obj_name)
          if hasattr(obj, "cache_clear"):
            try:
              obj.cache_clear()
            except Exception:  # some entries (e.g. WeakrefLRUCache) raise here
              pass
  gc.collect()

simitii commented 1 year ago

I think the need to manually clean the cache once in a while is not natural. It would be nice to have an internal garbage collector inside JAX that cleans the least recently used objects in the cache (depending on the cache's memory usage at the time).

LukasMut commented 1 year ago

I am currently facing a very similar issue. Am I supposed to use jax.clear_backends() or the solution suggested by @sokrypton? If it's the latter, I believe this should be called periodically similarly to jax.clear_backends()?

OmerRochman commented 1 year ago

I have a similar issue, where a pmapped function is run over multiple GPUs. In pseudocode:

@partial(jax.pmap, in_axes=(0, 0, None))
@jax.value_and_grad
def loss(x: Array, y: Array, n: int):
    y_hat, _ = jax.lax.scan(model, init=x, xs=jax.numpy.arange(n))
    l = MSE(y_hat, y)
    return l

loss_value, grad = loss(x_batched, y_batched, n)

The number of iterations in the scan loop, n, doesn't change much, but it does change. n is monotonically increasing, so I would like to throw away the old cache to make space for the new one. Here is an example of what the memory usage looks like: [W&B chart of GPU memory usage over time]

n was increased after about 5 hours, and if it is increased again the run will crash while trying to allocate CUDA memory. The only thing that changed in this example is the value of n. The shapes of x and y are constant, though I would also like to be able to change them. Some sort of clear_cache function would solve this, I believe (and would also be very useful in the case where x and y change shape).

jax.clear_backends() doesn't seem to work, and loss.clear_cache() runs into an error because it's a pmapped function, not a jitted one.

sokrypton commented 1 year ago

Does anyone have an updated clear_cache() function? With the latest jax, the solutions posted above are not able to resolve the memory leaks... :(

mattjj commented 1 year ago

@sokrypton does jax.clear_backends() work in your case?

I think the need to manually clean the cache once in a while is not natural. It would be nice to have an internal garbage collector inside JAX that cleans the least recently used objects in the cache (depending on the cache's memory usage at the time).

@simitii you are absolutely right. The challenge is that we haven't plumbed enough info out of e.g. XLA executables to be able to determine their size, hence we can't tell the size of a cache entry, nor the total size of the cache itself. There is an LRU eviction policy for most caches, but it's based on the total number of entries rather than their size, and there are situations where the number of entries remains small while the memory usage does not.

Clearly that's a surmountable hurdle, but it hasn't yet bubbled to the top of the priority list...
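
To illustrate the distinction, here is a toy entry-count LRU (plain functools.lru_cache, not JAX's actual cache machinery): three entries is well under any reasonable maxsize, yet the retained memory is large because eviction never looks at bytes.

import functools

@functools.lru_cache(maxsize=128)       # evicts by entry count, never by size in bytes
def fake_compile(n):
    return bytearray(n)                  # stands in for a large compiled executable

for size in (2**20, 2**25, 2**30):
    fake_compile(size)                   # only 3 cache entries, but ~1 GB now retained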

mattjj commented 1 year ago

#15448 and openxla/xla#2403 will fix this, finally! We needed some C++ changes, so after merging it will require building a new jaxlib, or waiting until we push new wheels.

mattjj commented 1 year ago

Maybe closing the issue was a little premature before we've heard any success stories...

To any folks who are running into such issues: can you try using jax.clear_caches()?
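
A minimal way to exercise it (assuming a build recent enough to include jax.clear_caches, per the notes below):

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return x + 1

f(jnp.zeros(10))      # compile and cache the executable
jax.clear_caches()    # intended to drop jit compilation and tracing caches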

VolodyaCO commented 1 year ago

Should we try this on the main branch @mattjj ? Or latest released version will suffice?

mattjj commented 1 year ago

I think it's only in the latest github HEAD at the moment, not the latest pypi release. Let me know if that's a pain and I should work on updating the pypi release.

sokrypton commented 1 year ago

I'm waiting for an official release to try it, as it appears a few things need to be compiled?

a123455392 commented 1 year ago

In my case, memory grows from alternating between 8GB and 1GB to alternating between 11GB and 3GB. Using @patrick-kidger's and @sokrypton's method worked well and kept the speed fast, but jax.clear_caches() didn't. Another method, jax.clear_backends(), worked, but the speed became very slow.

patrick-kidger commented 1 year ago

FWIW, I've since given up on trying to make something like this work. At some point my hack above started breaking.

Now I just fully restart the Python process every now and again. (E.g. when running a test suite, I run each file with a separate call to pytest.) Moreover this seems to have actually produced a speed improvement for me, for some reason.

f0uriest commented 10 months ago

The jax.clear_caches() method seems to work for me with jax v0.4.13, though not quite as well as I'd expect. Using it reduces CPU memory usage by quite a bit, but there's still more in host memory than I'd expect given that pretty much all the work is happening on the GPU.

zlsh80826 commented 10 months ago

jax.clear_backends() is not shown in the public API documentation; is this expected?

pipme commented 9 months ago

Something mysterious: when I called jax._src.api.clear_backends(), the program ran into an out-of-memory issue after a while. If I instead commented clear_backends out, it ran without a problem.

Also, it seems that after calling clear_backends(), jax.lib.xla_bridge.get_backend().live_arrays() gives wrong information about the arrays on the GPU. The arrays do exist (the corresponding Python objects are not deleted), but they are not captured by live_arrays(), i.e., len(live_arrays()) becomes zero.
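
A rough repro sketch of that check, using the same calls as above; the array shape is arbitrary:

import jax
import jax.numpy as jnp

x = jax.device_put(jnp.ones((1024, 1024)))                   # keep a live buffer around
print(len(jax.lib.xla_bridge.get_backend().live_arrays()))   # expected: at least 1

jax.clear_backends()
# x still exists as a Python object, yet the backend now reports no live arrays.
print(len(jax.lib.xla_bridge.get_backend().live_arrays()))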

My Jax version is 0.4.20.

Is this a bug, or am I missing something?