openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators

Involuntary Full Rematerialization #15207

Open chaserileyroberts opened 4 months ago

chaserileyroberts commented 4 months ago

Ported this issue from https://github.com/google/jax/issues/21562

This code

import jax
import numpy as np
import jax.numpy as jnp
from jax.sharding import PartitionSpec as PS, NamedSharding, Mesh

devices = np.asarray(jax.devices()).reshape((4, 2))
mesh = Mesh(devices, axis_names=('x', 'y'))

# shardtype1 is used for the inputs a and b; shardtype2 is the constraint/output sharding.
shardtype2 = NamedSharding(mesh, PS(None, ('x', 'y'), None))
shardtype1 = NamedSharding(mesh, PS('y', None, 'x'))

def f(a, b, c):
    d = a + b
    # Reshard d from the sharding propagated from a and b (shardtype1) to shardtype2.
    d = jax.lax.with_sharding_constraint(d, shardtype2)
    return c + d

fjit = jax.jit(f, in_shardings=(shardtype1, shardtype1, shardtype2), out_shardings=shardtype2)

a = jnp.arange(16 * 16 * 16).reshape((16, 16, 16)).astype(np.float32)
b = jnp.arange(16 * 16 * 16).reshape((16, 16, 16)).astype(np.float32)
c = jnp.arange(16 * 16 * 16).reshape((16, 16, 16)).astype(np.float32)

print(fjit(a, b, c).block_until_ready())

gives this warning:

E0531 10:54:04.832741 2609008 spmd_partitioner.cc:569] [spmd] Involuntary full rematerialization. The compiler was not able to go from sharding {devices=[2,1,4]<=[4,2]T(1,0)} to {devices=[1,8,1]<=[8]} without doing a full rematerialization of the tensor. You probably want to enrich the sharding annotations to prevent this from happening.
E0531 10:54:04.832805 2609008 spmd_partitioner.cc:569] [spmd] Involuntary full rematerialization. The compiler was not able to go from sharding {devices=[2,1,4]<=[4,2]T(1,0)} to {devices=[1,8,1]<=[8]} without doing a full rematerialization of the tensor. You probably want to enrich the sharding annotations to prevent this from happening.
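
Aside (not part of the original report): the HLO sharding strings in the warning can be matched back to the Python-level NamedShardings by inspecting the lowered program, which carries the same sharding annotations; the exact textual form depends on the JAX/XLA version. A minimal way to do that with the reproducer above:

# Hedged aside: dump the lowered program and look for the sharding annotations
# corresponding to shardtype1 ({devices=[2,1,4]<=[4,2]T(1,0)}) and
# shardtype2 ({devices=[1,8,1]<=[8]}).
print(fjit.lower(a, b, c).as_text())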

To avoid this issue, we've had to write our own resharding logic instead of relying solely on with_sharding_constraint.
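
For illustration only (this is not our actual workaround, just a minimal sketch of the general idea, with made-up stage names): one way to keep the partitioner from having to do the reshard inside a single program is to split the computation into two jitted stages and move the intermediate to the new sharding explicitly with jax.device_put at the boundary.

from functools import partial

# Hypothetical two-stage split of f: do the add under shardtype1,
# reshard explicitly, then finish under shardtype2.
@partial(jax.jit, in_shardings=(shardtype1, shardtype1), out_shardings=shardtype1)
def stage1(a, b):
    return a + b

@partial(jax.jit, in_shardings=(shardtype2, shardtype2), out_shardings=shardtype2)
def stage2(d, c):
    return c + d

d = stage1(a, b)
d = jax.device_put(d, shardtype2)  # explicit cross-device reshard
print(stage2(d, c).block_until_ready())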

ptoulme-aws commented 4 months ago

This means there was a sharding conflict, potentially a sharding that was propagated from an intermediate. The XLA compiler had to rematerialize the full tensor in order to reshard it.

This PR improves the logging to show which HloInstruction has the conflict: https://github.com/openxla/xla/pull/15402

ptoulme-aws commented 4 months ago

If you want to debug this further, look at the HLO dump after the ShardingPropagation pass but before SPMD partitioning. Then, using my logging PR, look at the HloSharding metadata of that HloInstruction and the instructions around it. Most likely you will see that a conflict like (4,8)->(8,4) triggered the reshard.
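
A hedged note on getting that dump from JAX (these are the standard XLA dump flags; the exact pass names that appear in the dump may differ between XLA versions):

import os

# Ask XLA to dump HLO before/after the sharding-related passes. This must be
# set before JAX initializes its backend (e.g. at the very top of the script).
os.environ["XLA_FLAGS"] = (
    "--xla_dump_to=/tmp/xla_dump "
    "--xla_dump_hlo_pass_re=sharding|spmd"
)

# Re-run the reproducer, then inspect the per-pass dumps in /tmp/xla_dump to
# find the module text after ShardingPropagation but before SPMD partitioning.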