jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Involuntary full rematerialization for advanced data movements #21562

Closed chaserileyroberts closed 3 months ago

chaserileyroberts commented 4 months ago

This code

import jax
import numpy as np
import jax.numpy as jnp
from jax.sharding import PartitionSpec as PS, NamedSharding, Mesh

# Requires 8 devices, arranged as a 4x2 mesh.
devices = np.asarray(jax.devices()).reshape((4, 2))
mesh = Mesh(devices, axis_names=('x', 'y'))

shardtype2 = NamedSharding(mesh, PS(None, ('x', 'y'), None))
shardtype1 = NamedSharding(mesh, PS('y', None, 'x'))

def f(a, b, c):
    d = a + b
    # Reshard d from shardtype1 (the layout of a and b) to shardtype2.
    d = jax.lax.with_sharding_constraint(d, shardtype2)
    return c + d

fjit = jax.jit(f, in_shardings=(shardtype1, shardtype1, shardtype2), out_shardings=shardtype2)

a = jnp.arange(16 * 16 * 16).reshape((16, 16, 16)).astype(np.float32)
b = jnp.arange(16 * 16 * 16).reshape((16, 16, 16)).astype(np.float32)
c = jnp.arange(16 * 16 * 16).reshape((16, 16, 16)).astype(np.float32)

print(fjit(a, b, c).block_until_ready())

Gives this warning

E0531 10:54:04.832741 2609008 spmd_partitioner.cc:569] [spmd] Involuntary full rematerialization. The compiler was not able to go from sharding {devices=[2,1,4]<=[4,2]T(1,0)} to {devices=[1,8,1]<=[8]} without doing a full rematerialization of the tensor. You probably want to enrich the sharding annotations to prevent this from happening.
E0531 10:54:04.832805 2609008 spmd_partitioner.cc:569] [spmd] Involuntary full rematerialization. The compiler was not able to go from sharding {devices=[2,1,4]<=[4,2]T(1,0)} to {devices=[1,8,1]<=[8]} without doing a full rematerialization of the tensor. You probably want to enrich the sharding annotations to prevent this from happening.

Theoretically, any transition from one full sharding to another can be done with a single all-to-all, without any rematerialization. Doing a full rematerialization instead of a single all-to-all has obvious performance implications.

ZixuanJiang commented 4 months ago

Thank you, Chase, for reporting this issue.

It is a known issue that the SPMD partitioner cannot handle this pattern effectively. Instead of an all-to-all from Sharding 1 to Sharding 2, the current lowering goes Sharding 1 -> replicated sharding -> Sharding 2. We are actively working to enhance the partitioner.

As a temporary solution, you can enrich the sharding annotations to guide the partitioner.
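
For illustration (a sketch added here, not part of the original comment), the replicated intermediate the partitioner currently falls back to would look roughly like this if written as explicit constraints; the actual fallback happens inside XLA, not in user code. `mesh` and `shardtype2` are from the reproducer above.

```python
import jax
from jax.sharding import PartitionSpec as PS, NamedSharding

# Sketch of the current fallback: instead of a direct shardtype1 -> shardtype2
# all-to-all, the tensor is first fully replicated (an all-gather of the whole
# 16x16x16 array onto every device) and then re-sliced into shardtype2.
replicated = NamedSharding(mesh, PS())  # no named axes -> fully replicated

def f_fallback(a, b, c):
    d = a + b
    d = jax.lax.with_sharding_constraint(d, replicated)  # full rematerialization
    d = jax.lax.with_sharding_constraint(d, shardtype2)  # re-slice into target sharding
    return c + d
```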

ZixuanJiang commented 4 months ago

For this specific example, we can add reshapes and sharding annotations to guide the partitioner so that it generates all-to-all instructions.

import jax
import numpy as np
import jax.numpy as jnp
from jax.sharding import PartitionSpec, NamedSharding, Mesh

devices = np.asarray(jax.devices()).reshape((4, 2))
mesh = Mesh(devices, axis_names=('x', 'y'))

shardtype1 = NamedSharding(mesh, PartitionSpec('y', None, 'x'))
shardtype2 = NamedSharding(mesh, PartitionSpec(None, ('x', 'y'), None))

# Intermediate shardings on a reshaped (16, 4, 4, 16) tensor. Splitting the
# middle axis lets the partitioner move from shardtype1 to shardtype2 with
# all-to-alls instead of a full rematerialization.
shardtype3 = NamedSharding(mesh, PartitionSpec('y', None, None, 'x'))
shardtype4 = NamedSharding(mesh, PartitionSpec(None, 'x', 'y', None))

def f(a, b, c):
    d = a + b

    d = d.reshape(16, 4, 4, 16)
    d = jax.lax.with_sharding_constraint(d, shardtype3)
    d = jax.lax.with_sharding_constraint(d, shardtype4)
    d = d.reshape(16, 16, 16)

    d = jax.lax.with_sharding_constraint(d, shardtype2)
    return c + d

fjit = jax.jit(f, in_shardings=(shardtype1, shardtype1, shardtype2), out_shardings=shardtype2)

a = jnp.arange(16 * 16 * 16).reshape((16, 16, 16)).astype(np.float32)
b = jnp.arange(16 * 16 * 16).reshape((16, 16, 16)).astype(np.float32)
c = jnp.arange(16 * 16 * 16).reshape((16, 16, 16)).astype(np.float32)

print(fjit(a, b, c).block_until_ready())
chaserileyroberts commented 4 months ago

In your example, do you know how many all-to-alls this inserts?

bartchr808 commented 4 months ago

If you do

print(fjit.lower(a, b, c).compile().as_text())

you will get the HLO of the program after partitioning, and you will see two all-to-alls.
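
As a quick way to count them without reading the whole dump (a sketch added here, not from the thread; it assumes the opcode appears as `all-to-all(` in the HLO text, which can vary across XLA versions):

```python
# Rough count of all-to-all ops in the partitioned HLO for the workaround above.
hlo_text = fjit.lower(a, b, c).compile().as_text()
print(hlo_text.count("all-to-all("))
```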

mattjj commented 3 months ago

IIUC this is more of an XLA-level issue than a JAX-level one. Should we keep this issue open, even though (AIUI) we can't make progress on it by working on JAX itself? Or should we e.g. move it to openxla?

chaserileyroberts commented 3 months ago

I think we should likely move it to openxla.

mattjj commented 3 months ago

I'm going to close this thread on the JAX issue tracker (I can't transfer issues between repos, it seems), but @chaserileyroberts if you want to reopen it on OpenXLA please do so!