jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Unused vmap GPU memory allocation causes RESOURCE_EXHAUSTED for versions >0.4.14 #23548

Open pwithams opened 2 months ago

pwithams commented 2 months ago

Description

Overview

The script below works on an NVIDIA GPU with JAX version 0.4.14, but after upgrading to 0.4.31 (and trying a few other versions in between) it triggers the following error:

E0910 20:24:00.097739 38257 pjrt_stream_executor_client.cc:3067] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate X bytes
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate X bytes.

where the value of X ranges from ~5GB (e.g. 4843897104) to 20GB+ depending on the shape of the dls variable (its leading dimension is set to 3540 in the script below).

jax <= 0.4.14: no error
jax > 0.4.14: error

I'm not sure if this is a bug, or if some code/syntax in the example below is no longer supported in versions > 0.4.14 and is responsible for this behavior.
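(Aside: to rule out XLA's default preallocation of 75% of GPU memory while debugging, the documented allocator environment variables can be set before importing jax. A minimal sketch, not a fix:)

import os

# Documented JAX GPU allocator knobs (see the gpu_memory_allocation docs);
# must be set before jax is imported.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # allocate on demand
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50" # or cap preallocation

import jax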

Allocation vs. pprof usage

The GPU has 6GB of memory. After some trial and error, it appears that setting the dls variable to a leading dimension of 1590 succeeds and uses only ~500kB of memory according to pprof (following https://jax.readthedocs.io/en/latest/device_memory_profiling.html), while a leading dimension of 1600 fails trying to allocate ~5GB. If pprof is in fact showing GPU memory usage, this suggests memory is being allocated but never actually used.
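For reference, the profiling workflow here follows the linked docs; a minimal sketch, assuming the pprof tool from github.com/google/pprof is installed:

import jax
import jax.numpy as jnp

x = jax.device_put(jnp.ones((1000, 1000)))  # put an array on the GPU
jax.profiler.save_device_memory_profile("memory.prof")
# then, from a shell:
#   pprof --web memory.prof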

jnp.exp removal

Trial and error also showed that removing the jnp.exp calls inside the function m seems to resolve the issue. For example, the script below with a dls leading dimension of 10000 fails trying to allocate 30GB, but removing the jnp.exp calls lets it succeed, with pprof reporting only ~2MB of usage.
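For concreteness, one reading of "removing the jnp.exp calls" from the function m defined in the script below (a sketch only; it obviously changes the math, and is just meant to pin down which ops were dropped):

def m_no_exp(s, d, d_mirror, rate):
    # same as m below, but with the two jnp.exp wrappers dropped
    scale = 0.5 * (15.0 / 75)
    return (
        scale
        * rate
        / ((2 * jnp.pi) ** (3 / 2) * jnp.prod(s))
        * (
            (-0.5 * jnp.sum(d**2 / s**2))
            + (-0.5 * jnp.sum(d_mirror**2 / s**2))
        )
    )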

Script

import jax
import jax.numpy as jnp
from jax import vmap

def wp(y0, ts, rng, tidx_offset):
    t0, t1 = ts[0], ts[-1]
    y = jnp.ones((11, 3)) * (t0 * t1)
    y = jnp.vstack(
        (
            y,
            jnp.ones((71, 3)) * y0,
        )
    )
    y = jnp.roll(y, tidx_offset, axis=0)
    y = y[:71]
    y = y.at[:, 2].set(jnp.abs(y[:, 2] - 0.03) + 0.03)
    return y

def ps(ys, ts, tidx_offset):
    t = jnp.maximum(ts - ts[0], 0)
    t = jnp.hstack(
        (
            t,
            jnp.zeros(71 - 11),
        )
    )
    t = jnp.roll(t, tidx_offset, axis=0)
    t = t[:71]
    ds = jnp.sqrt(jnp.sum((ys[1:, :] - ys[:-1, :]) ** 2, axis=-1))
    d = jnp.cumsum(jnp.hstack((jnp.array([0.0]), ds)), axis=-1)

    s_xyz = jnp.array([0.123, 0.345, 0.456])
    s_xyz = jnp.exp(jnp.array(-2.0)) + s_xyz
    scale = t * jnp.exp(jnp.array(-2.2)) + d * jnp.exp(jnp.array(-2.5)) + 1e-6
    s = jnp.einsum("i,x->ix", scale, s_xyz)

    return s

def m(s, d, d_mirror, rate):
    scale = 0.5 * (15.0 / 75)
    m = (
        scale
        * rate
        / ((2 * jnp.pi) ** (3 / 2) * jnp.prod(s))
        * (
            # removing these two jnp.exp calls appears to resolve the issue
            (jnp.exp(-0.5 * jnp.sum(d**2 / s**2)))
            + (jnp.exp(-0.5 * jnp.sum(d_mirror**2 / s**2)))
        )
    )
    return m

def func(y0, dl, tss, rng, rate, tidx_offset):
    ys = wp(y0, tss, rng, tidx_offset)
    d = ys - dl
    A = jnp.array([[1, 0, 0], [0, 1, 0], [0, 0, -1]])
    ys_mirror = jnp.matmul(ys, A)
    d_mirror = ys_mirror - dl
    scale = 0.5 * (15.0 / 75)
    rate = rate[tidx_offset]
    s = jnp.ones((71, 3)) * y0 * 2.3

    results = vmap(m, in_axes=(0, 0, 0, None))(s, d, d_mirror, rate)
    return results

@jax.jit
def run():
    y0s = jnp.ones(shape=(1, 3))
    # dls shape of ~1600+ fails on 6GB GPU trying to allocate 5GB+
    # dls shape of <1590 succeeds on 6GB GPU and uses only ~476kB memory according to pprof
    # dls shape of 10000 fails trying to allocate 30GB, but passes and only uses ~2MB when removing jnp.exp calls above
    dls = jnp.ones(shape=(3540, 3))
    rates = jnp.ones(shape=(1, 71))
    rngs = jnp.ones(shape=(71, 75, 2), dtype="uint32")
    tss = jnp.ones(shape=(71, 11, 75))
    tidx_offsets = jnp.arange(len(tss))

    output = vmap(
        vmap(
            vmap(
                vmap(func, in_axes=(None, None, 1, 0, None, None)),
                in_axes=(None, None, 0, 0, None, 0),
            ),
            in_axes=(None, 0, None, None, None, None),
        ),
        in_axes=(0, None, None, None, 0, None),
    )(y0s, dls, tss, rngs, rates, tidx_offsets)
    result = jnp.sum(output, axis=(0, 2, 3))
    return result

result = run()
jax.profiler.save_device_memory_profile("memory.prof")
print(result)
print(result.shape)

System info (python version, jaxlib version, accelerator, etc.)

Pip versions:

# jax
jax==0.4.31
jax-cuda12-pjrt==0.4.31
jax-cuda12-plugin==0.4.31
jaxlib==0.4.31
jaxtyping==0.2.34
# nvidia
nvidia-cublas-cu12==12.6.1.4
nvidia-cuda-cupti-cu12==12.6.68
nvidia-cuda-nvcc-cu12==12.6.68
nvidia-cuda-runtime-cu12==12.6.68
nvidia-cudnn-cu12==9.3.0.75
nvidia-cufft-cu12==11.2.6.59
nvidia-cusolver-cu12==11.6.4.69
nvidia-cusparse-cu12==12.5.3.3
nvidia-nccl-cu12==2.22.3
nvidia-nvjitlink-cu12==12.6.68

Output of jax.print_environment_info(), running inside a container based on nvidia/cuda:12.3.2-base-ubuntu22.04:

jax:    0.4.31
jaxlib: 0.4.31
numpy:  1.26.4
python: 3.10.12 (main, Jul 29 2024, 16:56:48) [GCC 11.4.0]
jax.devices (1 total, 1 local): [CudaDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='docker-desktop', release='5.15.153.1-microsoft-standard-WSL2', version='#1 SMP Fri Mar 29 23:14:13 UTC 2024', machine='x86_64')

$ nvidia-smi
Tue Sep 10 20:44:01 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.06              Driver Version: 545.92       CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 4050 ...    On  | 00000000:01:00.0 Off |                  N/A |
| N/A   43C    P3              11W /  35W |     78MiB /  6141MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A     46151      C   /python3.10                               N/A      |
+---------------------------------------------------------------------------------------+

Pip versions of the latest version that does not show the error (v0.4.14):

# jax 
jax==0.4.14
jaxlib==0.4.14+cuda12.cudnn89
jaxtyping==0.2.23
# nvidia
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvcc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.18.1
nvidia-nvjitlink-cu12==12.1.105
justinjfu commented 1 month ago

I checked the HLO when using dls=jnp.ones(shape=(10000, 3)), and it does indeed look like some very large tensors are being generated by your program (1 x 10000 x 71 x 75 x 71 x 3 ≈ 40GB as f32):

ENTRY main.152 {
  constant.27 = f32[] constant(1)
  broadcast.28 = f32[1,71]{1,0} broadcast(constant.27), dimensions={}
  iota.29 = s32[71]{0} iota(), iota_dimension=0
  ...
  constant.15 = f32[3,3]{1,0} constant({ { 1, 0, 0 }, { 0, 1, 0 }, { 0, 0, -1 } })
  dot.95 = f32[1,71,75,71,3]{4,3,2,1,0} dot(scatter.89, constant.15), lhs_contracting_dims={4}, rhs_contracting_dims={0}
  reshape.96 = f32[1,1,71,75,71,3]{5,4,3,2,1,0} reshape(dot.95)
  broadcast.97 = f32[1,1,71,75,71,3]{5,4,3,2,1,0} broadcast(reshape.96), dimensions={0,1,2,3,4,5}
  reshape.98 = f32[1,71,75,71,3]{4,3,2,1,0} reshape(broadcast.97)
  broadcast.99 = f32[1,10000,71,75,71,3]{5,4,3,2,1,0} broadcast(reshape.98), dimensions={0,2,3,4,5}
  subtract.100 = f32[1,10000,71,75,71,3]{5,4,3,2,1,0} subtract(broadcast.99, broadcast.17)
  multiply.132 = f32[1,10000,71,75,71,3]{5,4,3,2,1,0} multiply(subtract.100, subtract.100)
  divide.133 = f32[1,10000,71,75,71,3]{5,4,3,2,1,0} divide(multiply.132, broadcast.4)
  reduce.138 = f32[1,10000,71,75,71]{4,3,2,1,0} reduce(divide.133, constant.25), dimensions={5}, to_apply=region_3.134
  multiply.139 = f32[1,10000,71,75,71]{4,3,2,1,0} multiply(reduce.138, broadcast.2)
  exponential.140 = f32[1,10000,71,75,71]{4,3,2,1,0} exponential(multiply.139)
  add.141 = f32[1,10000,71,75,71]{4,3,2,1,0} add(exponential.131, exponential.140)
  multiply.146 = f32[1,10000,71,75,71]{4,3,2,1,0} multiply(broadcast.145, add.141)
  ROOT reduce.151 = f32[10000,71]{1,0} reduce(multiply.146, constant.25), dimensions={0,2,3}, to_apply=region_4.147
}
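As a back-of-envelope check, the f32[1,10000,71,75,71,3] intermediate above works out to roughly the allocation sizes being reported:

>>> 1 * 10000 * 71 * 75 * 71 * 3 * 4 / 1e9  # 4 bytes per f32 element
45.369

i.e. about 45GB (roughly 42 GiB), consistent with the ≈ 40GB estimate.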

After commenting out the two lines containing jnp.exp, these large tensors are no longer materialized:

  ...
  constant.12 = f32[] constant(1)
  reduce.27 = f32[1,71]{1,0} reduce(broadcast.6, constant.12), dimensions={2}, to_apply=region_0.23
  constant.1 = f32[] constant(15.7496099)
  broadcast.2 = f32[1,71]{1,0} broadcast(constant.1), dimensions={}
  multiply.28 = f32[1,71]{1,0} multiply(reduce.27, broadcast.2)
  reshape.29 = f32[1,1,71]{2,1,0} reshape(multiply.28)
  broadcast.34 = f32[1,1,71]{2,1,0} broadcast(reshape.29), dimensions={0,1,2}
  reshape.35 = f32[1,71]{1,0} reshape(broadcast.34)
  broadcast.36 = f32[1,71,71]{2,1,0} broadcast(reshape.35), dimensions={0,2}
  divide.37 = f32[1,71,71]{2,1,0} divide(broadcast.33, broadcast.36)
  broadcast.38 = f32[1,10000,71,75,71]{4,3,2,1,0} broadcast(divide.37), dimensions={0,2,4}
  constant.11 = f32[] constant(0)
  ROOT reduce.43 = f32[10000,71]{1,0} reduce(broadcast.38, constant.11), dimensions={0,2,3}, to_apply=region_1.39
}

I'm not sure why this code runs on JAX <0.4.14... it's possible some optimizations are being done differently. You can inspect the compiled code yourself using:

run.lower().compiler_ir(dialect='hlo').as_hlo_text() (for >=0.4.30)
jax.xla_computation(run)().as_hlo_text() (for <0.4.30)
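For concreteness, a minimal sketch of the newer (>=0.4.30) path, assuming the run function from the script above:

import jax

# `run` is already wrapped in jax.jit, so lower() traces and lowers the
# computation without executing it (no large GPU buffers are allocated)
print(run.lower().compiler_ir(dialect='hlo').as_hlo_text())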

pwithams commented 1 month ago

Thanks for the response. I'm starting to think some change in OpenXLA or lower in the stack is responsible rather than JAX itself. A few questions:

Do you think this seems like a bug, or just an old edge case that no longer works? When using dls=jnp.ones(shape=(1590, 3)) the program ran successfully and pprof reported ~500kB of memory usage, but increasing to dls=jnp.ones(shape=(1600, 3)) fails trying to allocate ~5GB, which seems like strange behavior.