Open chaserileyroberts opened 4 months ago
This warning means there was a sharding conflict, potentially involving a sharding that was propagated from an intermediate value. The XLA compiler had to rematerialize the full tensor to reshard it.
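As a toy illustration of why two conflicting tilings force data movement, the sketch below compares where each device's shard lives under a source grid and a target grid. It is not XLA's algorithm: the row-major device ordering, the 8x8 tensor shape, and the helper names `shard_bounds`, `overlap`, and `local_fraction` are all assumptions made for the example.

```python
def shard_bounds(device_id, grid, shape):
    """Half-open (row, col) ranges held by a device.

    Assumes devices are laid out row-major over `grid` (an assumption of
    this sketch, not something XLA guarantees).
    """
    rows, cols = grid
    r, c = divmod(device_id, cols)
    h, w = shape[0] // rows, shape[1] // cols
    return (r * h, (r + 1) * h), (c * w, (c + 1) * w)


def overlap(a, b):
    """Length of the intersection of two half-open intervals."""
    return max(0, min(a[1], b[1]) - max(a[0], b[0]))


def local_fraction(src_grid, dst_grid, shape):
    """Average fraction of its target shard that each device already holds."""
    n = src_grid[0] * src_grid[1]
    assert n == dst_grid[0] * dst_grid[1], "both grids must use the same devices"
    total = 0.0
    for d in range(n):
        src_rows, src_cols = shard_bounds(d, src_grid, shape)
        dst_rows, dst_cols = shard_bounds(d, dst_grid, shape)
        kept = overlap(src_rows, dst_rows) * overlap(src_cols, dst_cols)
        needed = (dst_rows[1] - dst_rows[0]) * (dst_cols[1] - dst_cols[0])
        total += kept / needed
    return total / n
```

With `shape=(8, 8)`, `local_fraction((4, 8), (4, 8), shape)` is 1.0 because nothing has to move, while `local_fraction((4, 8), (8, 4), shape)` is below 1.0: most devices must fetch data they do not hold, which is the kind of mismatch that leads to a reshard.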
This PR improves the logging to show which HLO instruction has the conflict: https://github.com/openxla/xla/pull/15402
If you want to debug this further, look at the HLO dump after the ShardingPropagation pass but before SPMD partitioning. Then, using my logging PR, look at the HloSharding metadata of that HloInstruction and of the instructions around it. Most likely you will see a conflict such as (4,8)->(8,4) that triggered the reshard.
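One way to get those per-pass dumps from a JAX program is to set XLA's dump flags through the `XLA_FLAGS` environment variable before JAX initializes. The sketch below is a minimal helper under stated assumptions: the dump directory `/tmp/xla_dump` and the function names are hypothetical, and the `sharding-propagation` substring used to filter file names is an assumption about how the pass appears in dump file names, so adjust it to what you actually see in the directory.

```python
import os
from pathlib import Path

DUMP_DIR = "/tmp/xla_dump"  # hypothetical location; any writable directory works


def enable_hlo_dump(dump_dir=DUMP_DIR):
    """Append XLA dump flags to XLA_FLAGS and return the resulting value.

    Call this before JAX compiles anything; calling it before
    `import jax` is the safest option.
    """
    flags = [
        f"--xla_dump_to={dump_dir}",   # where dump files are written
        "--xla_dump_hlo_as_text",      # human-readable HLO text
        "--xla_dump_hlo_pass_re=.*",   # dump HLO around every pass
    ]
    existing = os.environ.get("XLA_FLAGS", "")
    os.environ["XLA_FLAGS"] = " ".join([existing, *flags]).strip()
    return os.environ["XLA_FLAGS"]


def dumps_mentioning(dump_dir=DUMP_DIR, needle="sharding-propagation"):
    """Return dump files whose name contains `needle`, sorted by name.

    The default `needle` is an assumed pass name, not a verified one.
    """
    return sorted(p for p in Path(dump_dir).glob("*.txt") if needle in p.name)
```

A typical session would call `enable_hlo_dump()`, then import JAX and run the jitted function that produces the warning, and finally inspect the files returned by `dumps_mentioning()` for the sharding annotations on the instruction named in the log.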
Ported this issue from https://github.com/google/jax/issues/21562
This code
Gives this warning
We've had to write our own resharding logic instead of relying solely on `with_sharding_constraint` to avoid this issue.