daskol opened 2 weeks ago
What is `time -f`? I can't find any references to it, and when I try it in a bash shell I get `/bin/bash: line 1: -f: command not found`. Is this somehow printing the memory consumption of the command that follows?
Here `time` is not a part of bash but a standalone executable. Try installing it, e.g. `apt install time`. I use it to report peak (resident) memory usage.
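If installing GNU time is not an option, an in-process alternative is Python's standard `resource` module (a minimal sketch; note that on Linux `ru_maxrss` is reported in kilobytes):

```python
import resource

# Peak resident set size of the current process so far
# (kilobytes on Linux, bytes on macOS).
peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
print(f"RSS: {peak} kB")
```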
OK - to clarify, when you mention 10x memory usage, are you comparing the memory usage of the snippet that uses `zeros` with that of the snippet that uses `uniform`? If so, that seems like it's likely due to `uniform` being more costly than `zeros`, not anything related to the jit-compiled function executed later on.
Good point. The original issue occurred in a situation where I loaded the weights of an embedding layer `xs` from a checkpoint. The reason I include initialization with zeros is that I think the JIT transformation could possibly perform constant elimination and constant folding. Here is a refined test case.
```python
import jax
import jax.numpy as jnp
import pytest

SHAPE = (3_000_000, 300)


@pytest.mark.parametrize('jitted', [False, True])
@pytest.mark.parametrize('initializer', ['zeros', 'uniform'])
def test_index(initializer: str, jitted: bool):
    ix = jnp.array([1, 2])
    match initializer:
        case 'zeros':
            xs = jnp.zeros(SHAPE)
        case 'uniform':
            key = jax.random.PRNGKey(42)
            xs = jax.random.uniform(key, (3_000_000, 300))
    xs.block_until_ready()

    def fn(ix: jax.Array) -> jax.Array:
        return xs[ix]

    if jitted:
        jax.jit(fn)(ix)
    else:
        fn(ix)
```
I reran the test for the `uniform` initialization with JIT (`True`) and without it (`False`). Also, I added `.block_until_ready()` to make sure the weights `xs` are precomputed. In this case, memory usage differs by about a factor of 4 (i.e. 37.2 GB vs 8.7 GB, respectively). Note that a `3_000_000 x 300` array of floats takes ~3.35 GB, whilst the indexing operation under the JIT transformation consumes 37.2 GB at peak (about 10 times more memory). I assume that array elements are in row-major order and that slicing or indexing along the first dimension is cheap, so a 30 GB overhead is too much.
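(For reference, the expected footprint of the array itself follows from a quick back-of-the-envelope calculation, assuming float32 elements:)

```python
# 3,000,000 rows x 300 columns x 4 bytes per float32 element.
n_bytes = 3_000_000 * 300 * 4
print(f"{n_bytes / 2**30:.2f} GiB")  # -> 3.35 GiB
```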
```console
$ /usr/bin/time -f 'RSS: %M kB' -- pytest -v 'embed_test.py::test_index[uniform-True]'
RSS: 38992700 kB
$ /usr/bin/time -f 'RSS: %M kB' -- pytest -v 'embed_test.py::test_index[uniform-False]'
RSS: 9108472 kB
```
Ah, thanks. So the issue here, I think, is that you've constant-folded a very large array, which leads to it being embedded in an inefficient manner. If you pass the large array as an explicit function argument, it should improve the memory characteristics, because `xs` is no longer constant-folded into the function.
```python
def fn(ix: jax.Array, xs: jax.Array) -> jax.Array:
    return xs[ix]
```
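One way to see the difference (a small diagnostic sketch, assuming the same `xs` and `ix` as in the test case above) is to compare the traces of the closure version and the explicit-argument version with `jax.make_jaxpr`:

```python
# xs captured by closure: it is baked into the trace as a large constant.
print(jax.make_jaxpr(lambda ix: xs[ix])(ix))
# xs passed explicitly: it appears as an ordinary traced input.
print(jax.make_jaxpr(lambda ix, xs: xs[ix])(ix, xs))

# The jitted call then takes xs as an argument.
out = jax.jit(fn)(ix, xs)
```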
Good thing there's a workaround. Although it is a little inconvenient, since I use closures as arguments to higher-order functions. But this little alternative built on `jax.jit` seems to do the trick.
```python
from functools import wraps

import jax


def cjit(fun, *args, static_argnums=None, static_argnames=None, **kwargs):
    """Curry and jit a function `fun`."""
    curried_args = args
    curried_kwargs = kwargs
    fn = jax.jit(fun, static_argnums=static_argnums,
                 static_argnames=static_argnames)

    @wraps(fun)
    def inner(*args, **kwargs):
        all_kwargs = {**curried_kwargs, **kwargs}
        return fn(*curried_args, *args, **all_kwargs)

    return inner
```
So, instead of writing `jax.jit(fn)(xs, ix)`, I do `cjit(fn, xs)(ix)`.
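For example, with the embedding lookup above (a minimal usage sketch; `xs` and `ix` as in the test case, with `fn` taking the curried array first):

```python
def fn(xs: jax.Array, ix: jax.Array) -> jax.Array:
    return xs[ix]

# xs is passed as a traced jit argument, not constant-folded into the trace.
lookup = cjit(fn, xs)
out = lookup(ix)
```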
Description
I'm working with large embedding layers like 3'000'000x300 or even bigger on CPU-only devices. I noticed that `jax` takes much more memory than expected. With the code snippet above, one can reproduce the issue as follows.
One can see that memory usage differs by about a factor of 10, while the jitted function essentially gets a single item from a 2D array. This operation is common in embedding layers.
System info (python version, jaxlib version, accelerator, etc.)
The issue is reproduced across two jax versions on two different hosts.