fedelebron opened this issue 3 years ago

Using the free TPU Colab instances, I expected no significant memory use from the call to device_put_sharded. That is, device_input should be a light handle to on-TPU memory buffers, which are copies of the pre-existing shards. Instead it's using 1.3GB of extra RAM, somewhere (?), and in fact doing this repeatedly crashes the Colab kernel with an OOM in a few seconds.
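The "light handle" expectation can be checked directly against recent JAX versions, where device_put_sharded returns a jax.Array that records its own placement. A minimal sketch, using small arbitrary shapes purely for illustration:

import jax
import numpy as np

devices = jax.local_devices()
# One small float32 shard per device; device_put_sharded stacks them along a new leading axis.
shards = [np.full((4,), i, dtype=np.float32) for i in range(len(devices))]
y = jax.device_put_sharded(shards, devices)

print(type(y), y.shape, y.nbytes)  # logical size of the array, not host RAM
print(y.sharding)                  # how the array is laid out across devices
for s in y.addressable_shards:     # each shard records which device holds its buffer
    print(s.device, s.data.shape)

If the buffers live on the TPUs, host RAM should only be touched transiently during the transfer, which is what the reply below measures.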
Hi @fedelebron,

It seems that a fix for this problem was introduced in later JAX versions. Using JAX 0.4.33 on a Colab TPU v2-8, I observed that device_put_sharded's memory usage is now stable and doesn't grow over time. Furthermore, repeated executions of the code did not result in out-of-memory issues.
import jax
import numpy as np
print(jax.__version__, jax.devices())

import psutil

def available_ram_mb():
    # Available host RAM, in MB.
    return psutil.virtual_memory().available // (1024**2)

devices = jax.local_devices()
n_devices = len(devices)

# 128 images per device, float64 on the host.
input = np.random.rand(n_devices * 128, 224, 224, 3)
shards = np.split(input, n_devices, axis=0)

print(f"Before: {available_ram_mb()}MB of RAM available.")
device_input = jax.device_put_sharded(shards, devices)
print(f"After: {available_ram_mb()}MB of RAM available.")
Output:
0.4.33 [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
Before: 335170MB of RAM available.
After: 334585MB of RAM available.
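Host RAM is only half of the picture; per-device memory can be queried as well. The sketch below relies on Device.memory_stats(), which recent jaxlib versions expose for TPU and GPU devices but which may return None or use different key names on some backends; device_hbm_mb is a hypothetical helper, not a JAX API.

import jax

def device_hbm_mb(device):
    # Allocator statistics for one device; unavailable on some backends.
    stats = device.memory_stats()
    if not stats or "bytes_in_use" not in stats:
        return None
    return stats["bytes_in_use"] // (1024**2)

for d in jax.local_devices():
    print(d, device_hbm_mb(d), "MB of device memory in use")

Sampled before and after the device_put_sharded call, each device should gain roughly one shard's worth of bytes; with the default jax_enable_x64=False the float64 NumPy shards are canonicalized to float32 on transfer, so that is about 128 * 224 * 224 * 3 * 4 bytes, roughly 74 MiB per device.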
devices = jax.local_devices()
n_devices = len(devices)

import gc

for batch in range(30):
    input = np.random.rand(n_devices * 128, 224, 224, 3)
    sharded_batch = np.split(input, n_devices, axis=0)
    del input  # the shards are views, so the batch stays alive until they are deleted
    gc.collect()
    b2 = available_ram_mb()
    device_input = jax.device_put_sharded(sharded_batch, devices)
    b3 = available_ram_mb()
    print(f"Used {b2 - b3} MB of host RAM.")
    del sharded_batch
    del device_input
Output:
Used 601 MB of host RAM.
Used 589 MB of host RAM.
Used 587 MB of host RAM.
Used 602 MB of host RAM.
Used 600 MB of host RAM.
Used 584 MB of host RAM.
Used 588 MB of host RAM.
Used 598 MB of host RAM.
Used 591 MB of host RAM.
Used 584 MB of host RAM.
Used 590 MB of host RAM.
Used 590 MB of host RAM.
Used 593 MB of host RAM.
Used 589 MB of host RAM.
Used 586 MB of host RAM.
Used 585 MB of host RAM.
Used 588 MB of host RAM.
Used 590 MB of host RAM.
Used 588 MB of host RAM.
Used 588 MB of host RAM.
Used 585 MB of host RAM.
Used 585 MB of host RAM.
Used 590 MB of host RAM.
Used 583 MB of host RAM.
Used 588 MB of host RAM.
Used 616 MB of host RAM.
Used 585 MB of host RAM.
Used 573 MB of host RAM.
Used 585 MB of host RAM.
Used 584 MB of host RAM.
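For what it's worth, these per-iteration deltas cluster right around the size of a single float32 copy of the batch, which is consistent with one bounded, transient host-side copy per transfer rather than an accumulating leak. A back-of-the-envelope check, assuming the (8 * 128, 224, 224, 3) batch shape above and the default jax_enable_x64=False:

# Size of one (1024, 224, 224, 3) batch in the two dtypes involved.
n_elements = 1024 * 224 * 224 * 3
print(f"float64 host batch: {n_elements * 8 / 2**20:.0f} MiB")  # the np.random.rand array
print(f"float32 copy:       {n_elements * 4 / 2**20:.0f} MiB")  # after dtype canonicalization
# -> float64 host batch: 1176 MiB
# -> float32 copy:       588 MiB

588 MiB matches both the ~573 to 616 MB figures printed above and the 585 MB drop (335170 minus 334585) in the one-shot measurement earlier in this comment.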
Attaching a gist for reference.
Thank you.