Description

By combining capture + sharding + mapping behavior in a sufficiently exciting way, it is possible to induce JAX's batching machinery to witness inconsistent sizes for the batch axis. The following code snippet is a full repro on Google-internal TPU-backed colab runtimes:
```python
import jax
from jax import numpy as jnp
import numpy as np
from jax.sharding import PartitionSpec as P
from jax.experimental import shard_map

# Requires 8 devices (reshaped into a 2x4 mesh).
device_mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape([2, 4]), ('x', 'y'))
x_sharding = jax.sharding.NamedSharding(mesh=device_mesh, spec=P('x'))
xy_sharding = jax.sharding.NamedSharding(mesh=device_mesh, spec=P('x', 'y'))

def foo_with_float_arg_no_cond(to_add_slice, global_state):
  # We're going to vmap this function over the shard_map slice of state.
  def foo_capturing_something(state_slice):
    return to_add_slice + state_slice
  vmap_capture = jax.vmap(foo_capturing_something)
  # Note: check_rep=True or False results in the same error.
  shmap_vmap_capture = shard_map.shard_map(
      vmap_capture, mesh=device_mesh, in_specs=P('y'), out_specs=P('y'),
      check_rep=True)
  result = shmap_vmap_capture(global_state)
  return result

global_state = jax.device_put(jnp.ones(shape=[2, 4]), xy_sharding)
float_vector = jax.device_put(jnp.array([0.0, 1.0]), x_sharding)

# Works
jax.vmap(foo_with_float_arg_no_cond)(float_vector, global_state)
# Doesn't
jax.vmap(foo_with_float_arg_no_cond, spmd_axis_name='x')(float_vector, global_state)
```
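For comparison, here is a minimal no-capture variant of the same pattern. This is my own sketch (the names `foo_no_capture` and `identity_slice` are hypothetical, and it reuses `device_mesh`, `global_state`, `P`, and `shard_map` from the repro above); it differs only in that nothing is captured from the enclosing vmapped frame, which is where I would expect batching to stay consistent:

```python
# Sketch (assumption): the same shard_map-of-vmap structure, but with no value
# captured from the surrounding vmapped function.
def foo_no_capture(global_state):
  def identity_slice(state_slice):
    return state_slice
  vmap_no_capture = jax.vmap(identity_slice)
  shmap_vmap_no_capture = shard_map.shard_map(
      vmap_no_capture, mesh=device_mesh, in_specs=P('y'), out_specs=P('y'),
      check_rep=True)
  return shmap_vmap_no_capture(global_state)

# I would expect this to run under spmd_axis_name='x', since the only batched
# value enters shard_map through its explicit arguments rather than a closure.
jax.vmap(foo_no_capture, spmd_axis_name='x')(global_state)
```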
I'm also sending out a failing test that reproduces this problem.
System info (python version, jaxlib version, accelerator, etc.)