Open Joshuaalbert opened 2 months ago
Similar to https://github.com/google/jax/issues/17187, not sure I follow the logic of this comment:
That is, all batch dims of any mapped array should be identical, with shape broadcasting performed on JAX-side.
This actually isn't the behavior of vectorized! I know that the way it's presented in the docs is confusing, and I'm actually pushing to deprecate the vectorized behavior in favor of a more expressive API. I think that what you want is something like a "broadcasting vmap", which can be built using custom_vmap. Something like the following should do the trick:
def broadcasting_vmap(f):
    f = jax.custom_batching.custom_vmap(f)

    @f.def_vmap
    def rule(axis_size, in_batched, *args):
        batched_args = jax.tree.map(
            lambda x, b: x if b else jax.lax.broadcast(x, (axis_size,)), args,
            tuple(in_batched))
        out = f(*batched_args)
        out_batched = jax.tree.map(lambda _: True, out)
        return out, out_batched

    return f
It might be just the trick. However, can I suggest you make sure it passes this?
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.vmap, in_axes=(0, None, None))
@partial(jax.vmap, in_axes=(None, 0, None))
def cb_vec(x, y, z):
    def add(x, y, z):
        assert x.shape == (4, 5)
        assert y.shape == (4, 5)
        assert z.shape == ()
        return x + y + z

    return jax.pure_callback(add, jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype), x, y, z, vectorized=True)


if __name__ == '__main__':
    x = jnp.arange(4, dtype=jnp.float32)
    y = jnp.arange(5, dtype=jnp.float32)
    z = jnp.array(1, dtype=jnp.float32)
    assert cb_vec(x, y, z).shape == (4, 5)
Are you sure you want assert z.shape == ()? My suggestion was that you write:
@partial(jax.vmap, in_axes=(0, None, None))
@partial(jax.vmap, in_axes=(None, 0, None))
@broadcasting_vmap  # <--------------------------- HERE
def cb_vec(x, y, z):
    def add(x, y, z):
        assert x.shape == (4, 5)
        assert y.shape == (4, 5)
        assert z.shape == (4, 5)
        return x + y + z

    return jax.pure_callback(add, jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype), x, y, z)
The problem is that z should be a scalar inside the func, not broadcasted. Note this is not ufunc behaviour but is what I am looking for. Mapped args are broadcasted. Unmapped are not. (Sent by email in reply to Dan Foreman-Mackey's comment of Sept 13, 2024, 16:21.)
I don't think there's any good way to get that behavior. The inner vmap doesn't "know" about the outer one so I expect you'll be hard pressed to come up with consistent logic to end up with z a scalar. One thing you probably could get would be to get shapes (4, 1), (1, 5), and (1, 1) if that's better for your use case:
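A minimal sketch of that size-1 variant, assuming the same custom_vmap pattern as broadcasting_vmap above. The names expand_dims_vmap, seen_shapes, add and cb are hypothetical and chosen only for this illustration; it is not necessarily the snippet the commenter had in mind. Each vmap gives unbatched arguments a leading axis of length 1 instead of tiling them, and the host callback records the shapes it finally receives.

```python
import jax
import jax.numpy as jnp
import numpy as np


def expand_dims_vmap(f):
    """Hypothetical helper: unbatched args gain a size-1 axis per vmap."""
    f = jax.custom_batching.custom_vmap(f)

    @f.def_vmap
    def rule(axis_size, in_batched, *args):
        # Batched args keep their leading axis; unbatched ones get a
        # broadcastable axis of length 1 rather than axis_size copies.
        expanded = jax.tree.map(
            lambda x, b: x if b else jax.lax.broadcast(x, (1,)),
            args, tuple(in_batched))
        out = f(*expanded)
        out_batched = jax.tree.map(lambda _: True, out)
        return out, out_batched

    return f


seen_shapes = []  # shapes observed by the host callback


def add(x, y, z):
    # Runs on the host; convert to NumPy and rely on NumPy broadcasting.
    x, y, z = np.asarray(x), np.asarray(y), np.asarray(z)
    seen_shapes.append((x.shape, y.shape, z.shape))
    return x + y + z


@expand_dims_vmap
def cb(x, y, z):
    out_shape = jnp.broadcast_shapes(x.shape, y.shape, z.shape)
    return jax.pure_callback(add, jax.ShapeDtypeStruct(out_shape, x.dtype), x, y, z)


x = jnp.arange(4, dtype=jnp.float32)
y = jnp.arange(5, dtype=jnp.float32)
z = jnp.array(1, dtype=jnp.float32)

nested = jax.vmap(jax.vmap(cb, in_axes=(None, 0, None)), in_axes=(0, None, None))
out = nested(x, y, z).block_until_ready()
print(out.shape, seen_shapes)
```

This only works when the callback's output broadcasts up to the full batch shape; a function whose result depended on x alone would come back with a length-1 axis where the rule promises a batched one.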
The issue is that there needs to be some logic for which arguments to broadcast in each vmap and that can't depend on whether or not an argument is going to be mapped in the future. "vectorized" handles this by never mapping anything that isn't mapped, and I think that it's unlikely that we could come up with sensible logic to get exactly what you're asking for here. All that to say, I do think that you might be able to come up with something that works for your use case using custom_vmap and maybe that will help clarify your feature request.
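The point that each vmap only sees its own batch axis can be observed by logging what a custom_vmap rule receives. The following is a small sketch under that assumption; add3 and rule_calls are hypothetical names used only here.

```python
import jax
import jax.numpy as jnp

rule_calls = []  # (axis_size, in_batched) for each rule invocation


@jax.custom_batching.custom_vmap
def add3(x, y, z):
    return x + y + z


@add3.def_vmap
def add3_rule(axis_size, in_batched, x, y, z):
    # The rule is told only about the vmap currently being applied.
    rule_calls.append((axis_size, tuple(in_batched)))
    # Tile unbatched args so every arg carries the current batch axis.
    x, y, z = (a if b else jax.lax.broadcast(a, (axis_size,))
               for a, b in zip((x, y, z), in_batched))
    return add3(x, y, z), True


x = jnp.arange(4, dtype=jnp.float32)
y = jnp.arange(5, dtype=jnp.float32)
z = jnp.array(1, dtype=jnp.float32)

out = jax.vmap(jax.vmap(add3, in_axes=(None, 0, None)), in_axes=(0, None, None))(x, y, z)
print(out.shape)
for call in rule_calls:
    print(call)
```

The log should show the inner vmap's rule firing first with axis_size 5 and only y marked as batched, then the outer one with axis_size 4 and only x marked as batched. In neither call is there anything that distinguishes z from an argument a later vmap might map, which is why the rule cannot decide to leave z alone on that basis.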
I understand the constraint. Hmm, perhaps there is another middle ground. In principle, if an argument should never be broadcast, then it can be curried. The remaining args can then receive broadcasting to convert the function into a ufunc-style func. I think, in an effort to make the API clear, you might merge both of the above broadcast choices and rename to convert_to_ufunc, with a tile boolean which determines whether the array shapes should be broadcast beforehand.
def convert_to_ufunc(f, tile: bool = True):
    f = jax.custom_batching.custom_vmap(f)

    @f.def_vmap
    def rule(axis_size, in_batched, *args):
        # Unbatched args get a leading axis of axis_size (tile=True) or 1 (tile=False).
        batched_args = jax.tree.map(
            lambda x, b: x if b else jax.lax.broadcast(x, ((axis_size if tile else 1),)), args,
            tuple(in_batched))
        out = f(*batched_args)
        out_batched = jax.tree.map(lambda _: True, out)
        return out, out_batched

    return f


def cb(x, y, z):
    def add(x, y, z):
        assert x.shape == (4, 5)  # if tile=True; (4, 1) if tile=False
        assert y.shape == (4, 5)  # if tile=True; (1, 5) if tile=False
        assert z.shape == ()
        return x + y + z

    return jax.pure_callback(add, jax.ShapeDtypeStruct(shape=jnp.broadcast_shapes(x.shape, y.shape), dtype=x.dtype),
                             x, y, z, vectorized=True)


# Curry z first
assert jax.vmap(jax.vmap(convert_to_ufunc(partial(cb, z=z)), in_axes=(None, 0)), in_axes=(0, None))(x, y).shape == (4, 5)
With this setup the original intent of this issue is resolved, i.e. we can now trust that applying vmap multiple times gives consistent shapes inside the callback, which allows easier reasoning.
Description
When vectorized=True the expectation is that the callback of pure_callback should vectorise over common leading batch dims. That is, all batch dims of any mapped array should be identical, with shape broadcasting performed on JAX-side. If an array has not been mapped then it should not receive a batch dim. If this is violated then it is impossible for the callback to construct the proper output shape.
System info (python version, jaxlib version, accelerator, etc.)