jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Functionality to chunk `vmap`. #11319

Open joeryjoery opened 2 years ago

joeryjoery commented 2 years ago

Occasionally I run into the problem that batch sizes are too large for GPU memory, and with the current public JAX API I have to commit to one of two extremes: using vmap, or using the much more limited scan. (I'm still relatively new to JAX, so please correct me if I'm wrong.) Could there perhaps be a chunk argument to the vmap function that limits how much is passed to the vmapped function at any one time?

I tried implementing this and have the following mock-up. At the moment it is based on a `num_chunks` parameter, but this could be made more flexible, e.g. in terms of a `chunk_size`. It splits or repeats the inputs of the vmapped function according to the given `in_axes`, pads all input arrays uniformly with zeros along the batch dimension, and then for-loops the chunks through a conventional `vmap`. Finally, the results are concatenated based on their canonical slices.

I haven't implemented all of the functionality of the normal `vmap`, but it works in a slightly narrower scope. Could something like this be supported by the public JAX API?

from functools import partial

import jax
import jax.numpy as jnp

def pad_along_axis(array: jnp.ndarray, axis_length: int, axis: int = 0, *args, **kwargs) -> jnp.ndarray:
    # Zero-pad `array` along `axis` so that its length along that axis becomes `axis_length`.
    target_size = axis_length - jnp.shape(array)[axis]

    padding = [(0, 0)] * jnp.ndim(array)
    padding[axis] = (0, target_size)

    return jnp.pad(array, padding, *args, **kwargs)

def chunked_vmap(fun, num_chunks: int = 1, in_axes=0, out_axes=0, axis_name=None, axis_size=None):
    # TODO: Compatibility on flattened in_axes. Implementation for out_axes, axis_name, axis_size.

    # Note, num_chunks == 1 is equivalent to just using `vmap_fun`.
    vmap_fun = jax.vmap(fun, in_axes=in_axes, out_axes=out_axes, axis_name=axis_name, axis_size=axis_size)

    # Leaf structure of input splitting: ([chunk_a, chunk_b, ...], [pad_a, pad_b, ...]) 
    split_treedef = jax.tree_structure(([1] * num_chunks,) * 2)

    def split_fun(arg, ax):  
        # Operates on pytree leaves.

        if ax is None:
            return [arg] * num_chunks, [0] * num_chunks

        chunks = jnp.array_split(arg, num_chunks, axis=ax)

        leading_size = jnp.shape(chunks[0])[ax]
        batch_dims = jax.tree_map(lambda a: jnp.shape(a)[ax], chunks)

        padded_chunks = jax.tree_map(partial(pad_along_axis, axis_length=leading_size, axis=ax), chunks)

        return padded_chunks, batch_dims

    def vmap_f(*args, **kwargs):  # TODO: Incorporate kwargs?
        split_args = jax.tree_map(split_fun, args, in_axes)

        input_chunks, canonical_sizes = jax.tree_transpose(
            jax.tree_structure(args), split_treedef, split_args
        )
        out_sizes = [max(jax.tree_leaves(s)) for s in canonical_sizes]

        # TODO: use jax.lax.scan? Note the dynamic shapes of jax.lax.slice and that in_axes is not yet supported.
        results = [jax.lax.slice(vmap_fun(*c), (0, ), (s,)) for c, s in zip(input_chunks, out_sizes)]

        # TODO: collect all outputs immediately, or use a generator with `yield`?
        out = jax.tree_map(lambda *a: jnp.concatenate(a), *results)
        return out

    return vmap_f

def myfun(a, b, c):
    return jnp.square(a) * c + jnp.squeeze(b['val'])

v = jnp.arange(100)

args = (v, {'val': v.reshape(1, 1, -1, 100)}, 0.4)
in_axes = (0, {'val': -1}, None)

vmap_fun = jax.vmap(myfun, in_axes=in_axes)
out = vmap_fun(*args)

for num_chunks in [1, 2, 5, 10, 50]:
    chunk_out = chunked_vmap(myfun, num_chunks, in_axes=in_axes)(*args)

    assert jnp.isclose(out, chunk_out).all()  # Runs fine

shoyer commented 2 years ago

Take a look at `jax.experimental.maps.xmap`, which is designed for exactly this sort of thing:
https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html
https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.xmap.html

You can chunk vmap by making use of SerialLoop.
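
For concreteness, a rough sketch of that suggestion follows. It assumes the experimental API from the linked tutorial (`xmap` and `SerialLoop` in `jax.experimental.maps`), so the exact spelling may differ between JAX versions.

import jax.numpy as jnp
from jax.experimental.maps import xmap, SerialLoop

def f(x):
    return jnp.square(x) * 0.4

x = jnp.arange(100.0)

# Map f over a named 'batch' axis of size 100, but evaluate it as 4
# sequential chunks of 25 rather than one big vmapped call.
chunked_f = xmap(
    f,
    in_axes=['batch', ...],
    out_axes=['batch', ...],
    axis_resources={'batch': SerialLoop(4)},
)
out = chunked_f(x)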

ynotzort commented 1 year ago

This here might be something you could use: https://netket.readthedocs.io/en/stable/api/_generated/jax/netket.jax.vmap_chunked.html#netket.jax.vmap_chunked

joeryjoery commented 1 year ago

> This here might be something you could use: https://netket.readthedocs.io/en/stable/api/_generated/jax/netket.jax.vmap_chunked.html#netket.jax.vmap_chunked

Thanks, that is an excellent reference. I am of the opinion, though, that JAX could implement this by default given how generally useful it is.
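
For anyone landing here later, a minimal usage sketch of that netket helper; the `chunk_size` keyword and call pattern are assumed from the linked documentation, so double-check against it.

import jax.numpy as jnp
from netket.jax import vmap_chunked

def f(x):
    return jnp.square(x)

x = jnp.arange(1000.0)

# Behaves like jax.vmap(f)(x), but only maps over 100 elements at a time.
f_chunked = vmap_chunked(f, in_axes=0, chunk_size=100)
out = f_chunked(x)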

joeryjoery commented 11 months ago

Hi, more than a year later ;P. Are there any plans to implement this by default in JAX? Or would the authors be open to a PR?

shoyer commented 11 months ago

@mattjj we were just talking about this :)

PhilipVinc commented 11 months ago

If you do get to it, I would love it if you also added support for vjp over 'chunked' axes, like we did in https://netket.readthedocs.io/en/stable/api/_generated/jax/netket.jax.vjp_chunked.html

f0uriest commented 11 months ago

+1 for adding this feature; it would be super useful in several of my projects.

froystig commented 11 months ago

Discussion https://github.com/google/jax/discussions/18398 asks for this as well.

XLY43 commented 11 months ago

+1 upvote for implementing this

carlosgmartin commented 3 months ago

It would be nice if JAX's compiler could automatically convert computations from parallel to sequential when necessary, given a known memory constraint. Has this possibility been discussed anywhere?

cgarciae commented 3 months ago

jax.lax.map now has a batch_size argument that will chunk the computation and internally use vmap to operate over each batch in parallel. See #19614.
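
A minimal sketch of that option (the function and input sizes here are just illustrative):

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) ** 2

xs = jnp.arange(1000.0)

# Sequential loop over 10 batches; each batch of 100 elements is evaluated with vmap.
ys = jax.lax.map(f, xs, batch_size=100)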

arunoruto commented 1 month ago

I tried using JAX as an alternative to numba, since it is much easier to work in and I can switch between CPU and GPU using a flag. The biggest problem I am facing now is that vmap tries to put all of the data onto the GPU at the same time. I am already using vmap to apply a scalar function to a 2D array, so I was wondering if and how map could help me batch the data so that I don't run out of memory!

froystig commented 1 month ago

Try the batch_size argument to jax.lax.map? (Added in #19614.)

Think of it as a sequential loop over batches, where each batch is vmapped.
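
As a hedged sketch of how that could look for the 2D-array case above (the function, shapes, and chunk size are just illustrative):

import jax
import jax.numpy as jnp

def scalar_fn(x):
    return jnp.exp(-x) * jnp.sin(x)

data = jnp.ones((8192, 512))   # (rows, cols); too large to vmap over all rows at once

row_fn = jax.vmap(scalar_fn)   # applies scalar_fn elementwise across one row

# Loop over the rows sequentially in chunks of 256, vmapping row_fn within each
# chunk, so only one chunk's intermediates live on the GPU at a time.
out = jax.lax.map(row_fn, data, batch_size=256)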