jinmingyi1998 opened this issue 1 year ago
Notice that JAX will preallocate 90% of the currently available GPU memory when the first JAX operation is run, and then sub-allocate from that pool. See https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
When you say "jax still occupied memory as it running", maybe that's just the preallocation? To verify, try setting XLA_PYTHON_CLIENT_PREALLOCATE=false if you don't want JAX to preallocate.
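For example, a minimal sketch of running with preallocation disabled (the environment variable must be set before JAX is first imported, otherwise it has no effect):

import os
# Must be set before importing jax; setting it afterwards does nothing.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax
import jax.numpy as jnp

x = jnp.ones((1024, 1024))  # float32: allocates roughly 4 MB, not 90% of the GPU
print(jax.devices())

With this set, memory usage in nvidia-smi should grow with your actual allocations instead of jumping to 90% of the card up front.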
I'm sure it's not preallocation; I set that environment variable. The process still holds 17 GB of the 24 GB on my 3090 until it exits.
I'm developing an inference server, and I want to release all CUDA memory after each prediction.
Can you print jax.devices()[0].client.live_buffers() before and after jax.clear_backends() and see if it makes a difference?
One big hammer is [x.delete() for x in jax.devices()[0].client.live_buffers()]
I printed this:
print(jax.devices()[0].client.live_buffers())
jax.clear_backends()
print('======================================================================================')
print(jax.devices()[0].client.live_buffers())
It had only a small effect.
before:
# [0] NVIDIA GeForce RTX 3090 | 64°C, 100 %, 333 / 350 W | 17543 / 24576 MB |
after jax.clear_backends():
# [0] NVIDIA GeForce RTX 3090 | 50°C, 0 %, 118 / 350 W | 17539 / 24576 MB |
The output of the live_buffers() prints:
...,
[-0.11988804, -0.08419852, -0.25937328, ..., 0.0872755 ,
0.08518561, 0.06646829],
[-0.05617385, -0.07561915, -0.02690049, ..., 0.07006085,
-0.05562846, 0.09706916],
[ 0.1636534 , 0.02435764, 0.48934963, ..., 0.2926692 ,
-0.1360985 , -0.13793673]], dtype=float32)]
======================================================================================
[]
The device arrays were deleted, but CUDA memory only went down by about 4 MB.
What about [x.delete() for x in jax.devices()[0].client.live_buffers()]? If you know for sure that you are done with the existing buffers and won't use them again, the big hammer is something you can use.
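As a sketch, that big hammer wrapped in a helper (the name free_all_live_buffers is just illustrative; this is only safe if no existing JAX array will be touched again):

import gc
import jax

def free_all_live_buffers():
    # Delete every live device buffer on the first device.
    # Any later use of an already-created JAX array will fail.
    for buf in jax.devices()[0].client.live_buffers():
        buf.delete()
    gc.collect()  # also drop any lingering Python references

free_all_live_buffers()
print(jax.devices()[0].client.live_buffers())  # expect []

Note that, as the measurements above show, this deletes the buffers themselves, but the allocator may keep the freed memory pooled, so nvidia-smi can still report it as in use.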
Hi, I know it's been a while, but is there any solution to this problem? My GPU memory is not cleared even when there are no live buffers. Once the code fails with an OOM, I have no option but to restart the process, because the memory is never released. That's true even if I restart the backends.
Goal:
I am running AlphaFold, and I want to clear all CUDA memory after model inference. I am looking for something like torch.cuda.empty_cache().
Check for duplicate requests:
I searched the docs, issues, and discussions and found jax.clear_backends() (by the way, this method is not in the docs). I tried calling jax.clear_backends() and gc.collect() after prediction, but JAX still occupied memory while running.
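For reference, a minimal sketch of that attempt, using a large array as a stand-in for the AlphaFold inference outputs:

import gc
import jax
import jax.numpy as jnp

outputs = jnp.ones((4096, 4096))  # stand-in for the model's device outputs
del outputs           # drop the Python reference to the device array
jax.clear_backends()  # discard compiled computations and backend state
gc.collect()          # as described above, GPU memory is still not released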