jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Clear CUDA Cache #13330

Open jinmingyi1998 opened 1 year ago

jinmingyi1998 commented 1 year ago

Goal:

I am running AlphaFold, and I want to clear all CUDA memory after model inference. I am looking for something like torch.cuda.empty_cache()

Check for duplicate requests:

I searched the docs, issues, and discussions, and I found jax.clear_backends() (by the way, this method is not in the docs). I tried calling jax.clear_backends() and gc.collect() after prediction, but JAX still occupied memory while running.
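For reference, the cleanup I tried looks roughly like this (run_prediction is a hypothetical stand-in for the AlphaFold inference call):

import gc
import jax
import jax.numpy as jnp

def run_prediction():
    # Hypothetical stand-in for the actual AlphaFold model inference.
    return jnp.ones((4, 4)) @ jnp.ones((4, 4))

result = run_prediction()
del result            # drop the Python-side reference first
jax.clear_backends()  # undocumented; tears down the cached XLA backends
gc.collect()          # collect any remaining Python references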

zhangqiaorjc commented 1 year ago

Note that JAX will preallocate 90% of the currently-available GPU memory when the first JAX operation is run, and will then sub-allocate from that pool. See https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html

When you say "JAX still occupied memory while running", maybe that's just the preallocation? To verify this, try setting XLA_PYTHON_CLIENT_PREALLOCATE=false if you don't want JAX to preallocate.
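For example, a minimal sketch (the variable has to be set before the first JAX operation initializes the backend, so set it before importing jax or in the shell environment):

import os

# Must be set before JAX creates its GPU memory pool.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax
import jax.numpy as jnp

x = jnp.ones((1024, 1024))  # allocates on demand instead of grabbing a 90% pool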

jinmingyi1998 commented 1 year ago

Note that JAX will preallocate 90% of the currently-available GPU memory when the first JAX operation is run, and will then sub-allocate from that pool. See https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html

When you say "JAX still occupied memory while running", maybe that's just the preallocation?

I'm sure it's not preallocation; I set that env var.

The process allocates 17 GB of the 24 GB total GPU memory on a 3090 until the process stops.

I'm developing an inference server, and I want to release all CUDA memory after prediction.

zhangqiaorjc commented 1 year ago

Can you print jax.devices()[0].client.live_buffers() before and after jax.clear_backends() and see if it made a difference?

One big hammer is [x.delete() for x in jax.devices()[0].client.live_buffers()]
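Spelled out, the one-liner does this (note that live_buffers() is an internal client method, not a stable public API):

import jax

client = jax.devices()[0].client

# Inspect what the backend still holds.
print(f"{len(client.live_buffers())} live buffers on device 0")

# Forcibly delete every live device buffer. Any array still
# referencing one of these buffers becomes unusable afterwards.
for buf in client.live_buffers():
    buf.delete()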

jinmingyi1998 commented 1 year ago

Can you print jax.devices()[0].client.live_buffers() before and after jax.clear_backends() and see if it made a difference?

One big hammer is [x.delete() for x in jax.devices()[0].client.live_buffers()]

I printed this:

print(jax.devices()[0].client.live_buffers())
jax.clear_backends()
print('======================================================================================')
print(jax.devices()[0].client.live_buffers())

It had only a small effect.

before:
# [0] NVIDIA GeForce RTX 3090 | 64°C, 100 %,  333 / 350 W | 17543 / 24576 MB |
after jax.clear_backends():
# [0] NVIDIA GeForce RTX 3090 | 50°C,   0 %,  118 / 350 W | 17539 / 24576 MB |

Its output (truncated):

 ...,
[-0.11988804, -0.08419852, -0.25937328, ...,  0.0872755 ,
0.08518561,  0.06646829],
[-0.05617385, -0.07561915, -0.02690049, ...,  0.07006085,
-0.05562846,  0.09706916],
[ 0.1636534 ,  0.02435764,  0.48934963, ...,  0.2926692 ,
-0.1360985 , -0.13793673]], dtype=float32)]
======================================================================================
[]

The device arrays were deleted, but the CUDA memory only dropped by 4 MB.

zhangqiaorjc commented 1 year ago

What about [x.delete() for x in jax.devices()[0].client.live_buffers()]? If you know for sure that you are done with the existing buffers and won't use them again, the big hammer is something you can use.
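Putting the pieces from this thread together, a best-effort cleanup could look like the sketch below. Two caveats: live_buffers() and clear_backends() are internal/undocumented, and with JAX's default allocator the freed memory is returned to JAX's own pool rather than to the driver, so nvidia-smi may barely move (see the gpu_memory_allocation doc linked above for ways to change the allocator).

import gc
import jax

def release_device_memory():
    # Big hammer: invalidates every live array on every device.
    for device in jax.devices():
        for buf in device.client.live_buffers():
            buf.delete()
    jax.clear_backends()  # undocumented; resets the cached backends
    gc.collect()          # drop lingering Python references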

benjaminvatterj commented 9 months ago

Hi, I know it's been a while, but is there any solution to this problem? My GPU memory is not clearing even when there are no live buffers. Once the code fails due to OOM, I have no option but to restart it, because the memory is never released. That's true even if I restart the backends.