jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

`device_put_sharded` seems to increase host memory use. #7840

Open fedelebron opened 3 years ago

fedelebron commented 3 years ago

Using the free TPU Colab instances:

import jax.tools.colab_tpu  
jax.tools.colab_tpu.setup_tpu()

import psutil
def available_ram_mb():
  return psutil.virtual_memory().available // (1024**2)

import numpy as np

devices = jax.local_devices()
n_devices = len(devices)

input = np.random.rand(n_devices * 128, 224, 224, 3)
shards = np.split(input, n_devices, axis = 0)

print(f"Before: {available_ram_mb()}MB of RAM available.")
device_input = jax.device_put_sharded(shards, devices)
print(f"After: {available_ram_mb()}MB of RAM available.")

This shows:

Before: 11003MB of RAM available.
After: 9702MB of RAM available.

I expected no significant host memory use from the call to device_put_sharded: device_input should just be a lightweight handle to on-TPU buffers holding copies of the pre-existing shards. Instead, roughly 1.3 GB of extra host RAM is being used somewhere.

[Screenshot of Colab host RAM usage, 2021-09-07 8:48 AM]
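For reference, the total size of the shards can be computed from the host arrays themselves; a minimal sketch (run in the same session as the snippet above), which shows that the ~1.3 GB host delta is roughly the size of the data itself, i.e. consistent with an extra host-side copy being made somewhere:

# np.random.rand produces float64, so the full batch is
# n_devices * 128 * 224 * 224 * 3 * 8 bytes.
total_shard_mb = sum(shard.nbytes for shard in shards) // (1024**2)
print(f"Total shard size on host: {total_shard_mb} MB")  # ~1176 MB for 8 devices

# device_input should only be a handle: its buffers live in TPU HBM, so host
# RAM would not be expected to grow by a comparable amount.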

In fact, doing this repeatedly crashes the Colab kernel with an OOM within a few seconds:

import numpy as np

devices = jax.local_devices()
n_devices = len(devices)
import gc

for batch in range(30):
  input = np.random.rand(n_devices * 128, 224, 224, 3)
  sharded_batch = np.split(input, n_devices, axis = 0)
  del input
  gc.collect()
  b2 = available_ram_mb()
  device_input = jax.device_put_sharded(sharded_batch, devices)
  b3 = available_ram_mb()
  print(f"Used {b2 - b3} MB of host RAM.")
  del sharded_batch
  del device_input

Results in:

Used 962 MB of host RAM.
Used -366 MB of host RAM.
Used 402 MB of host RAM.
Used 515 MB of host RAM.
Used -74 MB of host RAM.
Used 1579 MB of host RAM.
Used 1630 MB of host RAM.
Used 94 MB of host RAM.
Used -199 MB of host RAM.
Used 75 MB of host RAM.
Used 198 MB of host RAM.
Used -73 MB of host RAM.
Used 440 MB of host RAM.
Used 515 MB of host RAM.
Used 516 MB of host RAM.
Used 368 MB of host RAM.
Used -367 MB of host RAM.
Used 293 MB of host RAM.
Used 589 MB of host RAM.
Used 515 MB of host RAM.
Used 494 MB of host RAM.

And OOM.
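One diagnostic (not a fix, just to narrow things down) would be to transfer the shards one device at a time with jax.device_put and watch the host RAM delta, to see whether the extra usage is specific to device_put_sharded. A sketch, reusing shards, devices, and available_ram_mb from the first snippet:

import gc

b_before = available_ram_mb()

# Copy each shard to its own device individually instead of in a single
# device_put_sharded call.
per_device = [jax.device_put(shard, device)
              for shard, device in zip(shards, devices)]
for buf in per_device:
  buf.block_until_ready()  # wait for the host-to-device copies to finish

gc.collect()
print(f"Per-device transfer used {b_before - available_ram_mb()} MB of host RAM.")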

rajasekharporeddy commented 1 month ago

Hi @fedelebron

It seems that a fix for this was introduced in a later JAX version. Using JAX 0.4.33 on a Colab TPU v2-8, I observed that device_put_sharded's host memory usage is now stable and does not grow over time, and repeated executions of the reproduction no longer run out of memory.

import jax
import numpy as np

print(jax.__version__, jax.devices())

import psutil
def available_ram_mb():
  return psutil.virtual_memory().available // (1024**2)

devices = jax.local_devices()
n_devices = len(devices)

input = np.random.rand(n_devices * 128, 224, 224, 3)
shards = np.split(input, n_devices, axis = 0)

print(f"Before: {available_ram_mb()}MB of RAM available.")
device_input = jax.device_put_sharded(shards, devices)
print(f"After: {available_ram_mb()}MB of RAM available.")

Output:

0.4.33 [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
Before: 335170MB of RAM available.
After: 334585MB of RAM available.

devices = jax.local_devices()
n_devices = len(devices)
import gc

for batch in range(30):
  input = np.random.rand(n_devices * 128, 224, 224, 3)
  sharded_batch = np.split(input, n_devices, axis = 0)
  del input
  gc.collect()
  b2 = available_ram_mb()
  device_input = jax.device_put_sharded(sharded_batch, devices)
  b3 = available_ram_mb()
  print(f"Used {b2 - b3} MB of host RAM.")
  del sharded_batch
  del device_input

Output:

Used 601 MB of host RAM.
Used 589 MB of host RAM.
Used 587 MB of host RAM.
Used 602 MB of host RAM.
Used 600 MB of host RAM.
Used 584 MB of host RAM.
Used 588 MB of host RAM.
Used 598 MB of host RAM.
Used 591 MB of host RAM.
Used 584 MB of host RAM.
Used 590 MB of host RAM.
Used 590 MB of host RAM.
Used 593 MB of host RAM.
Used 589 MB of host RAM.
Used 586 MB of host RAM.
Used 585 MB of host RAM.
Used 588 MB of host RAM.
Used 590 MB of host RAM.
Used 588 MB of host RAM.
Used 588 MB of host RAM.
Used 585 MB of host RAM.
Used 585 MB of host RAM.
Used 590 MB of host RAM.
Used 583 MB of host RAM.
Used 588 MB of host RAM.
Used 616 MB of host RAM.
Used 585 MB of host RAM.
Used 573 MB of host RAM.
Used 585 MB of host RAM.
Used 584 MB of host RAM.

Attaching a gist for reference.
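As an aside, on these newer JAX versions the same sharded transfer can also be expressed with jax.device_put and a NamedSharding rather than device_put_sharded; a minimal sketch (the axis name "batch" and the host_batch array are only for illustration):

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.local_devices()
host_batch = np.random.rand(len(devices) * 128, 224, 224, 3)

# One-dimensional mesh over all local devices; "batch" is an arbitrary axis name.
mesh = Mesh(np.array(devices), axis_names=("batch",))
sharding = NamedSharding(mesh, P("batch"))

# device_put with a Sharding splits the host array across devices directly,
# one shard per device along the leading axis.
device_input = jax.device_put(host_batch, sharding)
print(device_input.shape, device_input.sharding)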

Thank you.