Closed: chaserileyroberts closed this issue 3 months ago.
Thank you, Chase, for reporting this issue.
It is a known issue that the SPMD partitioner cannot handle this pattern effectively. Instead of a single all-to-all from Sharding 1 to Sharding 2, the current solution is Sharding 1 -> Replicated Sharding -> Sharding 2. We are actively working to enhance the partitioner.
As a temporary workaround, you can add richer sharding annotations to guide the partitioner.
For this specific example, adding reshapes and sharding constraints guides the partitioner so that it generates the all-to-all instructions:
```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# The mesh must exist before the shardings that refer to it.
devices = np.asarray(jax.devices()).reshape((4, 2))
mesh = Mesh(devices, axis_names=('x', 'y'))

shardtype1 = NamedSharding(mesh, PartitionSpec('y', None, 'x'))
shardtype2 = NamedSharding(mesh, PartitionSpec(None, ('x', 'y'), None))
shardtype3 = NamedSharding(mesh, PartitionSpec('y', None, None, 'x'))
shardtype4 = NamedSharding(mesh, PartitionSpec(None, 'x', 'y', None))

def f(a, b, c):
    d = a + b
    # Expose the resharded dimension as two separate axes, then
    # constrain the intermediate shardings step by step.
    d = d.reshape(16, 4, 4, 16)
    d = jax.lax.with_sharding_constraint(d, shardtype3)
    d = jax.lax.with_sharding_constraint(d, shardtype4)
    d = d.reshape(16, 16, 16)
    d = jax.lax.with_sharding_constraint(d, shardtype2)
    return c + d

fjit = jax.jit(f, in_shardings=(shardtype1, shardtype1, shardtype2),
               out_shardings=shardtype2)

a = jnp.arange(16 * 16 * 16).reshape((16, 16, 16)).astype(np.float32)
b = jnp.arange(16 * 16 * 16).reshape((16, 16, 16)).astype(np.float32)
c = jnp.arange(16 * 16 * 16).reshape((16, 16, 16)).astype(np.float32)

print(fjit(a, b, c).block_until_ready())
```
In your example, do you know how many all-to-alls this inserts?
If you run

```python
print(fjit.lower(a, b, c).compile().as_text())
```

you will get the HLO of the program after partitioning, and you will see two all-to-alls.
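To experiment with this without eight accelerators, one option (my sketch, not from the thread) is to simulate eight host CPU devices with XLA's `--xla_force_host_platform_device_count` flag, then check both the numerics and the compiled HLO. Everything else mirrors the snippet above:

```python
import os
# Must be set before JAX initializes its backends.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.asarray(jax.devices()).reshape((4, 2))
mesh = Mesh(devices, axis_names=('x', 'y'))

shardtype1 = NamedSharding(mesh, PartitionSpec('y', None, 'x'))
shardtype2 = NamedSharding(mesh, PartitionSpec(None, ('x', 'y'), None))
shardtype3 = NamedSharding(mesh, PartitionSpec('y', None, None, 'x'))
shardtype4 = NamedSharding(mesh, PartitionSpec(None, 'x', 'y', None))

def f(a, b, c):
    d = a + b
    d = d.reshape(16, 4, 4, 16)
    d = jax.lax.with_sharding_constraint(d, shardtype3)
    d = jax.lax.with_sharding_constraint(d, shardtype4)
    d = d.reshape(16, 16, 16)
    d = jax.lax.with_sharding_constraint(d, shardtype2)
    return c + d

fjit = jax.jit(f, in_shardings=(shardtype1, shardtype1, shardtype2),
               out_shardings=shardtype2)

a = jnp.arange(16 * 16 * 16).reshape((16, 16, 16)).astype(np.float32)
b = jnp.arange(16 * 16 * 16).reshape((16, 16, 16)).astype(np.float32)
c = jnp.arange(16 * 16 * 16).reshape((16, 16, 16)).astype(np.float32)

# The reshapes are shape-preserving round trips, so the sharded result
# must match the dense computation exactly.
np.testing.assert_allclose(np.asarray(fjit(a, b, c)), np.asarray(a + b + c))

# Inspect the post-partitioning HLO; the exact op spelling can vary by
# XLA version, so this just prints the count rather than asserting on it.
hlo = fjit.lower(a, b, c).compile().as_text()
print(hlo.count("all-to-all"))
```

Whether the all-to-alls actually appear on the CPU backend depends on the XLA version in use, so treat the printed count as diagnostic output.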
IIUC this is more of an XLA-level issue than a JAX-level one. Should we keep this issue open, even though (AIUI) we can't make progress on it by working on JAX itself? Or should we e.g. move it to openxla?
I think we should likely move it to openxla.
I'm going to close this thread on the JAX issue tracker (I can't transfer issues between repos, it seems), but @chaserileyroberts if you want to reopen it on OpenXLA please do so!
This code gives this warning:
Theoretically, any full sharding -> full sharding transition can be done with a single all-to-all, without any rematerialization. Doing a full rematerialization instead of a single all-to-all has obvious performance implications.
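As a rough back-of-envelope illustration of those implications (my numbers, not from the thread): replicating first means every device all-gathers the shards it is missing, i.e. roughly the whole array, whereas a single all-to-all only re-exchanges each device's own 1/D shard. For the 16x16x16 float32 array on the 4x2 mesh above:

```python
# Back-of-envelope per-device communication: replicate-then-reshard
# versus a single direct all-to-all. Idealized costs, ignoring latency
# and interconnect topology (my simplification).
N = 16 * 16 * 16 * 4          # total array size in bytes (float32)
D = 8                         # number of devices in the 4x2 mesh

# Replication path: each device receives the (D-1)/D of the array it
# does not already hold.
replicate_bytes_per_device = N * (D - 1) // D

# Direct all-to-all: each device keeps 1/D of its own N/D shard and
# exchanges the remaining (D-1)/D of it.
all_to_all_bytes_per_device = (N // D) * (D - 1) // D

print(replicate_bytes_per_device)      # 14336
print(all_to_all_bytes_per_device)     # 1792
```

Under these idealized assumptions the replication path moves D times more data per device than the single all-to-all, which is why falling back to rematerialization is costly.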