jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Operator `jax.lax.gather` consumes a lot of memory on CPU #23457

Open daskol opened 2 weeks ago

daskol commented 2 weeks ago

Description

I'm working with large embedding layers, e.g. 3,000,000×300 or even bigger, on CPU-only devices. I noticed that JAX takes much more memory than expected.

import os

os.environ['JAX_PLATFORM_NAME'] = 'cpu'  # Make sure JAX uses only the CPU.

import jax
import jax.numpy as jnp
import pytest

SHAPE = (3_000_000, 300)

@pytest.mark.parametrize('initializer', ['zeros', 'uniform'])
def test_index(initializer: str):
    ix = jnp.array([1, 2])
    match initializer:
        case 'zeros':
            xs = jnp.zeros(SHAPE)
        case 'uniform':
            key = jax.random.PRNGKey(42)
            xs = jax.random.uniform(key, SHAPE)

    def fn(ix: jax.Array) -> jax.Array:
        return xs[ix]

    jax.jit(fn)(ix)

With the code snippet above, one can reproduce the issue as follows.

$ /usr/bin/time -f 'RSS: %M kB' -- pytest -v 'embed_test.py::test_index[zeros]'
RSS: 3827568 kB
$ /usr/bin/time -f 'RSS: %M kB' -- pytest -v 'embed_test.py::test_index[uniform]'
RSS: 38990788 kB

One can see that memory usage differs by roughly 10x, even though the jitted function essentially fetches a couple of rows from a 2D array. This operation is common in embedding layers.

System info (python version, jaxlib version, accelerator, etc.)

The issue reproduces with two JAX versions on two different hosts.

$ python -c 'import jax; jax.print_environment_info()'
jax:    0.4.31
jaxlib: 0.4.31
numpy:  2.0.1
python: 3.12.5 (main, Aug  9 2024, 08:20:41) [GCC 14.2.1 20240805]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='f3647874f8f3', release='6.10.6-arch1-1', version='#1 SMP PREEMPT_DYNAMIC Mon, 19 Aug 2024 17:02:39 +0000', machine='x86_64')
$ python -c 'import jax; jax.print_environment_info()'
jax:    0.4.30
jaxlib: 0.4.30
numpy:  2.0.0
python: 3.12.3 (main, Apr 10 2024, 05:33:47) [GCC 13.2.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='a71544e08fb6', release='5.15.0-107-generic', version='#117-Ubuntu SMP Fri Apr 26 12:26:49 UTC 2024', machine='x86_64')
jakevdp commented 2 weeks ago

What is time -f? I can't find any references to it, and when I try it in a bash shell I get /bin/bash: line 1: -f: command not found

Is this somehow printing the memory consumption of the command that follows?

daskol commented 2 weeks ago

Here time is not a bash builtin but a standalone executable; try installing it with something like apt install time. I use it to report peak resident memory usage.

jakevdp commented 2 weeks ago

OK - to clarify: when you mention 10x memory usage, are you comparing the memory usage of the snippet that uses zeros with that of the snippet that uses uniform? If so, that seems likely due to uniform being more costly than zeros, not anything related to the jit-compiled function executed later on.

daskol commented 2 weeks ago

Good point. The original issue occurred in a situation where I loaded the weights of the embedding layer xs from a checkpoint. The reason I included initialization with zeros is that I thought the JIT transformation might perform constant elimination and constant folding.

Here is a refined test case.

@pytest.mark.parametrize('jitted', [False, True])
@pytest.mark.parametrize('initializer', ['zeros', 'uniform'])
def test_index(initializer: str, jitted: bool):
    ix = jnp.array([1, 2])
    match initializer:
        case 'zeros':
            xs = jnp.zeros(SHAPE)
        case 'uniform':
            key = jax.random.PRNGKey(42)
            xs = jax.random.uniform(key, SHAPE)
    xs.block_until_ready()

    def fn(ix: jax.Array) -> jax.Array:
        return xs[ix]

    if jitted:
        jax.jit(fn)(ix)
    else:
        fn(ix)

I reran the test with uniform initialization with JIT (True) and without it (False). I also added .block_until_ready() to make sure the weights xs are precomputed. In this case, memory usage differs by a factor of ~4 (37.2 GB vs 8.7 GB, respectively). Note that a 3_000_000 x 300 array of floats takes ~3.35 GB, whilst the indexing operation under the JIT transformation peaks at 37.2 GB (about 10 times more memory). The array elements are in row-major order, and slicing or indexing along the first dimension should be cheap, so a 30 GB overhead is too much.

$ /usr/bin/time -f 'RSS: %M kB' -- pytest -v 'embed_test.py::test_index[uniform-True]'
RSS: 38992700 kB
$ /usr/bin/time -f 'RSS: %M kB' -- pytest -v 'embed_test.py::test_index[uniform-False]'
RSS: 9108472 kB
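As a side check, one can inspect the lowered StableHLO to see whether xs is baked into the computation as a constant. A minimal sketch, using a small stand-in array (the shape here is shrunk for illustration, not the real embedding table):

```python
import jax
import jax.numpy as jnp

xs = jnp.zeros((1000, 300))  # small stand-in for the large embedding table
ix = jnp.array([1, 2])

def fn(ix: jax.Array) -> jax.Array:
    return xs[ix]  # xs is captured by closure, not passed as an argument

# The closed-over array shows up as a constant in the lowered module,
# meaning its contents are embedded into the compiled executable.
text = jax.jit(fn).lower(ix).as_text()
print('constant' in text)
```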
jakevdp commented 2 weeks ago

Ah, thanks. So the issue here, I think, is that you've constant-folded a very large array, which leads to it being embedded in the compiled executable in an inefficient manner. If you pass the large array as an explicit function argument, it should improve the memory characteristics, because xs is no longer constant-folded into the function.

def fn(ix: jax.Array, xs: jax.Array) -> jax.Array:
  return xs[ix]
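A complete runnable sketch of that workaround, with a small stand-in shape for illustration:

```python
import jax
import jax.numpy as jnp

xs = jnp.zeros((1000, 300))  # stand-in for the large embedding table
ix = jnp.array([1, 2])

def fn(ix: jax.Array, xs: jax.Array) -> jax.Array:
    return xs[ix]

# xs is now a traced argument rather than a baked-in constant, so the
# compiled executable references the existing buffer instead of
# embedding its own copy.
out = jax.jit(fn)(ix, xs)
print(out.shape)  # (2, 300)
```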
daskol commented 2 weeks ago

Good thing there's a workaround, though it is a little inconvenient, since I use closures as arguments to higher-order functions. But this little alternative to plain jax.jit seems to do the trick.

from functools import wraps

def cjit(fun, *args, static_argnums=None, static_argnames=None, **kwargs):
    """Curry and jit a function `fun`."""
    curried_args = args
    curried_kwargs = kwargs
    fn = jax.jit(fun, static_argnums=static_argnums,
                 static_argnames=static_argnames)

    @wraps(fun)
    def inner(*args, **kwargs):
        all_kwargs = {**curried_kwargs, **kwargs}
        return fn(*curried_args, *args, **all_kwargs)

    return inner

So, instead of writing jax.jit(fn)(xs, ix), I do cjit(fn, xs)(ix).
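A minimal end-to-end check of this pattern (the cjit definition is repeated here so the snippet is self-contained; shapes are shrunk for illustration):

```python
from functools import wraps

import jax
import jax.numpy as jnp

def cjit(fun, *args, static_argnums=None, static_argnames=None, **kwargs):
    """Curry and jit a function `fun`."""
    curried_args = args
    curried_kwargs = kwargs
    fn = jax.jit(fun, static_argnums=static_argnums,
                 static_argnames=static_argnames)

    @wraps(fun)
    def inner(*args, **kwargs):
        all_kwargs = {**curried_kwargs, **kwargs}
        return fn(*curried_args, *args, **all_kwargs)

    return inner

def fn(xs: jax.Array, ix: jax.Array) -> jax.Array:
    return xs[ix]

xs = jnp.zeros((1000, 300))  # stand-in for the large embedding table
ix = jnp.array([1, 2])

# xs is forwarded to the already-jitted function at call time,
# so it stays a runtime argument rather than a trace-time constant.
out = cjit(fn, xs)(ix)
print(out.shape)  # (2, 300)
```

Note the design choice: cjit applies jax.jit to fun first and forwards the curried arguments at call time. Simply wrapping with functools.partial(fn, xs) before jitting would not help, since the bound xs would again be a closed-over constant at trace time.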