jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

pure_callback is broken with multiple vmap #23624

Open Joshuaalbert opened 6 days ago

Joshuaalbert commented 6 days ago

Description

When vectorized=True, the expectation is that the callback passed to pure_callback is vectorised over common leading batch dims. That is, all batch dims of every mapped array should be identical, with shape broadcasting performed on the JAX side. An array that has not been mapped should not receive a batch dim. If this is violated, the callback cannot construct the proper output shape.

```python
from functools import partial
import jax
import jax.numpy as jnp

# Reference: plain nested vmap. x is mapped by the outer vmap, y by the inner one.
@partial(jax.vmap, in_axes=(0, None, None))
@partial(jax.vmap, in_axes=(None, 0, None))
def add_vmapped(x, y, z):
    return x + y + z

# vectorized=False: the callback is invoked per element, so every input is a scalar.
@partial(jax.vmap, in_axes=(0, None, None))
@partial(jax.vmap, in_axes=(None, 0, None))
def cb_no_vec(x, y, z):
    def add(x, y, z):
        assert x.shape == ()
        assert y.shape == ()
        assert z.shape == ()
        return x + y + z

    return jax.pure_callback(add, jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype), x, y, z, vectorized=False)

# vectorized=True: expected shapes are (4, 5) for the mapped x and y, and () for the unmapped z.
@partial(jax.vmap, in_axes=(0, None, None))
@partial(jax.vmap, in_axes=(None, 0, None))
def cb_vec(x, y, z):
    def add(x, y, z):
        assert x.shape == (4, 5)
        assert y.shape == (4, 5)
        assert z.shape == ()
        return x + y + z

    return jax.pure_callback(add, jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype), x, y, z, vectorized=True)

if __name__ == '__main__':
    x = jnp.arange(4, dtype=jnp.float32)
    y = jnp.arange(5, dtype=jnp.float32)
    z = jnp.array(1, dtype=jnp.float32)

    assert add_vmapped(x, y, z).shape == (4, 5)
    assert cb_no_vec(x, y, z).shape == (4, 5)
    assert cb_vec(x, y, z).shape == (4, 5)  # reported to fail on jax 0.4.31
```
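As a minimal sketch of the semantics requested above, the following pure-Python model computes the shapes the issue expects a vectorized callback to receive. It does not call JAX; it assumes the rule stated in the description (an argument mapped by any vmap gets the full common batch shape, an unmapped argument keeps its own shape), and the helper name requested_callback_shapes is hypothetical.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def requested_callback_shapes(
    own_shapes: Dict[str, Shape],
    vmap_levels: List[Tuple[int, Dict[str, bool]]],
) -> Dict[str, Shape]:
    """Shapes the issue expects a vectorized callback to see (hypothetical helper).

    vmap_levels lists (axis_size, mapped) pairs from the outermost vmap inward;
    mapped[name] says whether that vmap maps the argument `name`.
    """
    # Common batch shape, outermost vmap first (the layout of vmap's outputs).
    batch_shape = tuple(size for size, _ in vmap_levels)
    result = {}
    for name, shape in own_shapes.items():
        mapped_anywhere = any(mapped[name] for _, mapped in vmap_levels)
        # Mapped args are broadcast to the full batch shape; unmapped ones are untouched.
        result[name] = batch_shape + shape if mapped_anywhere else shape
    return result


# The repro: the outer vmap maps x (size 4), the inner maps y (size 5), z is never mapped.
levels = [(4, {"x": True, "y": False, "z": False}),
          (5, {"x": False, "y": True, "z": False})]
print(requested_callback_shapes({"x": (), "y": (), "z": ()}, levels))
# {'x': (4, 5), 'y': (4, 5), 'z': ()}
```

For the repro this yields (4, 5) for x and y and () for z, which are exactly the shapes asserted inside cb_vec.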

System info (python version, jaxlib version, accelerator, etc.)

jax==0.4.31
jaxlib==0.4.31
Joshuaalbert commented 6 days ago

This is similar to https://github.com/google/jax/issues/17187, though I'm not sure I follow the logic of this comment.

dfm commented 6 days ago

> That is, all batch dims of any mapped array should be identical, with shape broadcasting performed on JAX-side.

This actually isn't the behavior of vectorized! I know that the way it's presented in the docs is confusing, and I'm actually pushing to deprecate the vectorized behavior in favor of a more expressive API. I think that what you want is something like a "broadcasting vmap", which can be built using custom_vmap. Something like the following should do the trick:

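```python
import jax

def broadcasting_vmap(f):
  f = jax.custom_batching.custom_vmap(f)

  @f.def_vmap
  def rule(axis_size, in_batched, *args):
    # Give every unbatched argument a leading batch axis of length axis_size,
    # so that all arguments are batched identically.
    batched_args = jax.tree.map(
        lambda x, b: x if b else jax.lax.broadcast(x, (axis_size,)), args,
        tuple(in_batched))
    out = f(*batched_args)
    # Every output now carries the leading batch axis.
    out_batched = jax.tree.map(lambda _: True, out)
    return out, out_batched

  return f
```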
Joshuaalbert commented 6 days ago

It might do just the trick. However, could I suggest you make sure it passes this?

```python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.vmap, in_axes=(0, None, None))
@partial(jax.vmap, in_axes=(None, 0, None))
def cb_vec(x, y, z):
    def add(x, y, z):
        assert x.shape == (4, 5)
        assert y.shape == (4, 5)
        assert z.shape == ()
        return x + y + z
    return jax.pure_callback(add, jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype), x, y, z, vectorized=True)

if __name__ == '__main__':
    x = jnp.arange(4, dtype=jnp.float32)
    y = jnp.arange(5, dtype=jnp.float32)
    z = jnp.array(1, dtype=jnp.float32)

    assert cb_vec(x, y, z).shape == (4, 5)
```
dfm commented 6 days ago

Are you sure you want assert z.shape == ()? My suggestion was that you write:

```python
@partial(jax.vmap, in_axes=(0, None, None))
@partial(jax.vmap, in_axes=(None, 0, None))
@broadcasting_vmap  # <--------------------------- HERE
def cb_vec(x, y, z):
    def add(x, y, z):
        assert x.shape == (4, 5)
        assert y.shape == (4, 5)
        assert z.shape == (4, 5)
        return x + y + z
    return jax.pure_callback(add, jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype), x, y, z)
```
Joshuaalbert commented 6 days ago

The problem is that z should be a scalar inside the function, not broadcast. Note that this is not ufunc behaviour, but it is what I am looking for: mapped args are broadcast, unmapped args are not.


dfm commented 6 days ago

I don't think there's any good way to get that behavior. The inner vmap doesn't "know" about the outer one, so I expect you'll be hard-pressed to come up with consistent logic that ends up with z as a scalar. One thing you probably could get is shapes (4, 1), (1, 5), and (1, 1), if that's better for your use case:

A possible implementation:

```python
from functools import partial
import jax
import jax.numpy as jnp

def joshuaalbert_vmap(f):
  f = jax.custom_batching.custom_vmap(f)

  @f.def_vmap
  def rule(axis_size, in_batched, *args):
    batched_args = jax.tree.map(
        lambda x, b: x if b else jax.lax.broadcast(x, (1,)), args,  # <- 1 instead of axis_size
        tuple(in_batched))
    out = f(*batched_args)
    out_batched = jax.tree.map(lambda _: True, out)
    return out, out_batched

  return f

@partial(jax.vmap, in_axes=(0, None, None))
@partial(jax.vmap, in_axes=(None, 0, None))
@joshuaalbert_vmap
def cb_broadcasting(x, y, z):
  def add(x, y, z):
    assert x.shape == (4, 1)
    assert y.shape == (1, 5)
    assert z.shape == (1, 1)
    return x + y + z
  out_shape = jnp.broadcast_shapes(x.shape, y.shape, z.shape)  # <-- note here
  return jax.pure_callback(add, jax.ShapeDtypeStruct(shape=out_shape, dtype=x.dtype), x, y, z)
```
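The out_shape line in that implementation relies on NumPy-style broadcasting. As a minimal pure-Python sketch of what jnp.broadcast_shapes computes for these inputs (align shapes on their trailing dimensions; in each position the sizes must be equal or 1), here is a hypothetical broadcast_shapes_py; in real code, use jnp.broadcast_shapes or np.broadcast_shapes instead.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def broadcast_shapes_py(*shapes: Shape) -> Shape:
    """Pure-Python sketch of NumPy-style shape broadcasting (hypothetical helper)."""
    ndim = max((len(s) for s in shapes), default=0)
    result = []
    for axis in range(ndim):
        # Right-align every shape; missing leading dimensions count as size 1.
        offset = [ndim - len(s) for s in shapes]
        dims = [s[axis - off] if axis >= off else 1 for s, off in zip(shapes, offset)]
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"incompatible shapes for broadcasting: {shapes}")
        result.append(sizes.pop() if sizes else 1)
    return tuple(result)


print(broadcast_shapes_py((4, 1), (1, 5), (1, 1)))  # (4, 5)
print(broadcast_shapes_py((4, 5), (4, 5), ()))      # (4, 5)
```

Both the size-1 shapes from this implementation and the fully tiled shapes from broadcasting_vmap broadcast to the same (4, 5) output, which is why out_shape can be computed the same way in either case.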

The issue is that there needs to be some logic for which arguments to broadcast at each vmap level, and that logic can't depend on whether or not an argument is going to be mapped later by an outer vmap. "vectorized" handles this by never broadcasting anything that isn't mapped, and I think it's unlikely that we could come up with sensible logic to get exactly what you're asking for here. All that to say, I do think you might be able to build something that works for your use case using custom_vmap, and maybe that will help clarify your feature request.
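To illustrate the per-level logic described above, here is a simplified pure-Python model; it does not call JAX. It assumes each vmap level is handled independently from the innermost outward (the innermost batching rule runs first), that an argument batched at a level gets that level's axis prepended, and that unbatched arguments follow one of three policies: "vectorized" (never broadcast), "size1" (the (1,) broadcast in the implementation above), or "tile" (the axis_size broadcast in broadcasting_vmap). The function name and policy strings are hypothetical.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def simulate_callback_shapes(
    own_shapes: Dict[str, Shape],
    vmap_levels: List[Tuple[int, Dict[str, bool]]],
    policy: str,
) -> Dict[str, Shape]:
    """Hypothetical model of the shapes a callback sees under nested vmaps.

    vmap_levels lists (axis_size, batched) pairs from the innermost vmap outward.
    Each level only knows which args it batches, not what outer levels will do.
    """
    if policy not in ("vectorized", "size1", "tile"):
        raise ValueError(f"unknown policy: {policy}")
    shapes = dict(own_shapes)
    for axis_size, batched in vmap_levels:
        new_shapes = {}
        for name, shape in shapes.items():
            if batched[name] or policy == "tile":
                # Batched args get this level's axis in front; "tile" also expands unbatched ones.
                new_shapes[name] = (axis_size,) + shape
            elif policy == "size1":
                # Unbatched args get a broadcastable length-1 axis.
                new_shapes[name] = (1,) + shape
            else:
                # "vectorized": unbatched args pass through untouched.
                new_shapes[name] = shape
        shapes = new_shapes
    return shapes


# The thread's example, innermost vmap (over y, size 5) listed first.
levels = [(5, {"x": False, "y": True, "z": False}),
          (4, {"x": True, "y": False, "z": False})]
scalars = {"x": (), "y": (), "z": ()}
for policy in ("vectorized", "size1", "tile"):
    print(policy, simulate_callback_shapes(scalars, levels, policy))
# vectorized {'x': (4,), 'y': (5,), 'z': ()}
# size1 {'x': (4, 1), 'y': (1, 5), 'z': (1, 1)}
# tile {'x': (4, 5), 'y': (4, 5), 'z': (4, 5)}
```

Under this model the "size1" and "tile" policies reproduce the shapes asserted in the two custom_vmap examples in this thread, while the "vectorized" policy leaves x and y with ranks that do not line up, since no level ever adds an axis to an argument it does not map.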

Joshuaalbert commented 5 days ago

I understand the constraint. Hmm, perhaps there is another middle ground. In principle, if an argument should never be broadcast, then it can be curried. The remaining args can then be broadcast, converting the function into a ufunc-style function. To keep the API clear, you might merge both broadcast choices above into one function named convert_to_ufunc, with a tile boolean that determines whether unbatched args are tiled to the full batch size beforehand or only given a size-1 axis.

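```python
from functools import partial
import jax
import jax.numpy as jnp

def convert_to_ufunc(f, tile: bool = True):
    f = jax.custom_batching.custom_vmap(f)

    @f.def_vmap
    def rule(axis_size, in_batched, *args):
        # Unbatched args get a leading axis of length axis_size (tile=True) or 1 (tile=False).
        batched_args = jax.tree.map(
            lambda x, b: x if b else jax.lax.broadcast(x, ((axis_size if tile else 1),)), args,
            tuple(in_batched))
        out = f(*batched_args)
        out_batched = jax.tree.map(lambda _: True, out)
        return out, out_batched

    return f

def cb(x, y, z):
    def add(x, y, z):
        assert x.shape == (4, 5)  # with tile=True; (4, 1) with tile=False
        assert y.shape == (4, 5)  # with tile=True; (1, 5) with tile=False
        assert z.shape == ()      # z is curried, so it is never broadcast
        return x + y + z

    return jax.pure_callback(add, jax.ShapeDtypeStruct(shape=jnp.broadcast_shapes(x.shape, y.shape), dtype=x.dtype), x,
                             y, z, vectorized=True)

x = jnp.arange(4, dtype=jnp.float32)
y = jnp.arange(5, dtype=jnp.float32)
z = jnp.array(1, dtype=jnp.float32)

# Curry z first, so only x and y go through convert_to_ufunc's batching rule.
f = convert_to_ufunc(partial(cb, z=z))
assert jax.vmap(jax.vmap(f, in_axes=(None, 0)), in_axes=(0, None))(x, y).shape == (4, 5)
```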

With this setup, the original intent of this issue is resolved: we can now trust that applying vmap multiple times gives consistent shapes inside the callback, which makes the callback easier to reason about.
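The consistency claim can be checked with a small pure-Python model of convert_to_ufunc's rule; it does not call JAX. The sketch assumes vmap levels are processed innermost-first, that a batched argument gets the level's axis prepended, and that a curried argument such as z never reaches the batching rule, so it is simply left out. The helpers ufunc_shapes and check_consistency are hypothetical. With tile=True every remaining argument should end up with the full batch shape, and with tile=False every argument should have the same rank, with each axis holding either the batch size or 1, for any number of nested vmaps.

```python
from itertools import product
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def ufunc_shapes(
    names: List[str],
    vmap_levels: List[Tuple[int, Dict[str, bool]]],
    tile: bool,
) -> Dict[str, Shape]:
    """Hypothetical model of convert_to_ufunc for scalar, non-curried args.

    vmap_levels lists (axis_size, batched) pairs from the innermost vmap outward.
    """
    shapes: Dict[str, Shape] = {name: () for name in names}
    for axis_size, batched in vmap_levels:
        shapes = {
            name: ((axis_size if (batched[name] or tile) else 1),) + shape
            for name, shape in shapes.items()
        }
    return shapes


def check_consistency(sizes_inner_first: List[int]) -> None:
    """Assert the shape properties for every way of mapping x and y per level."""
    names = ["x", "y"]
    # Each vmap level maps at least one of the remaining args.
    choices = [c for c in product([False, True], repeat=len(names)) if any(c)]
    batch_shape = tuple(reversed(sizes_inner_first))  # outermost vmap first
    for pattern in product(choices, repeat=len(sizes_inner_first)):
        levels = [(size, dict(zip(names, mapped)))
                  for size, mapped in zip(sizes_inner_first, pattern)]
        tiled = ufunc_shapes(names, levels, tile=True)
        assert all(s == batch_shape for s in tiled.values())
        expanded = ufunc_shapes(names, levels, tile=False)
        for axis, size in enumerate(batch_shape):
            dims = {s[axis] for s in expanded.values()}
            # Only `size` or 1 appears, and some arg carries the full size.
            assert dims <= {size, 1} and size in dims


levels = [(5, {"x": False, "y": True}), (4, {"x": True, "y": False})]
print(ufunc_shapes(["x", "y"], levels, tile=True))   # {'x': (4, 5), 'y': (4, 5)}
print(ufunc_shapes(["x", "y"], levels, tile=False))  # {'x': (4, 1), 'y': (1, 5)}
check_consistency([5, 4, 3])  # three nested vmaps; raises AssertionError on any mismatch
```

The two printed cases correspond to the tile=True and tile=False shapes noted in the cb example, with z absent because it was curried.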