When I run the following code block, it produces different output with and without sharding:
```python
import jax
import jax.numpy as jnp
import numpy as np

mesh = jax.sharding.Mesh(np.array(jax.devices()), ('x',))
spec = jax.sharding.PartitionSpec('x')

# A = jnp.ones((4, 32))  # if we use this shape, then the sharded output is correct
A = jnp.ones((4, 1024))

print("without sharding:")
print(A.at[::2].set(0))

with mesh:
    B = jax.lax.with_sharding_constraint(A, spec)
    print("with sharding:")
    print(B.at[::2].set(0))
```
I observed similar problems with differently shaped arrays on different TPU topologies, but as the commented-out line shows, the problem depends on the array shape.
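For reference, the expected result (which the unsharded path produces) matches plain NumPy indexing semantics: `A.at[::2].set(0)` should return a copy of `A` with every even-indexed row zeroed. A minimal sketch of that expectation, using NumPy only so it runs without an accelerator:

```python
import numpy as np

# NumPy in-place analogue of jnp's functional A.at[::2].set(0)
A = np.ones((4, 1024))
A[::2] = 0

# Rows 0 and 2 are all zeros; rows 1 and 3 are untouched.
print(A[:, 0])  # first column of each row
```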
System info (python version, jaxlib version, accelerator, etc.)
I believe this is working at the latest stable release (jax + jaxlib 0.4.31, released about three days ago). Do you still see the issue if you upgrade?
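When reporting back, it helps to confirm the exact installed versions. A small sketch using the standard library's `importlib.metadata` (this only reads package metadata; it does not assume jax is importable):

```python
from importlib.metadata import version, PackageNotFoundError

# Print the installed versions of jax and jaxlib, if present.
for pkg in ("jax", "jaxlib"):
    try:
        print(pkg, version(pkg))
    except PackageNotFoundError:
        print(pkg, "not installed")
```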