openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0

[XLA::SPMD] Propagating shardings between operands of element-wise operations. #14887

Closed copybara-service[bot] closed 1 month ago

copybara-service[bot] commented 1 month ago

[XLA::SPMD] Propagating shardings between operands of element-wise operations.

Let us use C = foo(A, B) as an example. We have two key steps in sharding propagation.

  1. Forward propagation. Infer C's sharding from A and B.
  2. Backward propagation. Infer A's or B's sharding from C.

Ideally, we would not need to propagate shardings between operands explicitly, since the paths A -> C -> B and B -> C -> A already exist. However, these paths may be unavailable if C has a pre-defined sharding, because a pre-defined sharding blocks forward propagation into C. To resolve this issue, we can add the propagations A -> B and B -> A directly.

Such operand-to-operand propagation was already considered for several operations, such as dot and reduce. This CL adds support for element-wise operations.
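
The effect of the direct operand-to-operand step can be illustrated with a small toy model. This is only a sketch under stated assumptions, not XLA's implementation: a sharding is modeled as one mesh-axis name (or None) per tensor dimension, and `merge`, `propagate`, `c_is_fixed`, and `operand_to_operand` are hypothetical names invented for this illustration.

```python
from typing import Optional, Tuple

# Toy model (assumption): one entry per tensor dimension, holding a mesh-axis
# name or None if that dimension is not partitioned. Not XLA's HloSharding.
Sharding = Tuple[Optional[str], ...]


def merge(dst: Sharding, src: Sharding) -> Sharding:
    """Refine dst with src: fill dimensions that dst leaves unpartitioned,
    skipping mesh axes that dst already uses on another dimension."""
    used = {axis for axis in dst if axis is not None}
    out = []
    for d, s in zip(dst, src):
        if d is None and s is not None and s not in used:
            out.append(s)
            used.add(s)
        else:
            out.append(d)
    return tuple(out)


def propagate(a, b, c, c_is_fixed, operand_to_operand):
    """Repeat the propagation steps for C = foo(A, B) until nothing changes."""
    while True:
        before = (a, b, c)
        if not c_is_fixed:
            c = merge(merge(c, a), b)        # forward: A, B -> C
        a, b = merge(a, c), merge(b, c)      # backward: C -> A, C -> B
        if operand_to_operand:
            a, b = merge(a, b), merge(b, a)  # direct: B -> A, A -> B
        if (a, b, c) == before:
            return a, b, c


# A is partitioned on dimension 0, B has no sharding yet, and C carries a
# pre-defined sharding on dimension 1 that propagation must not change.
a, b, c = ("x", None), (None, None), (None, "y")
without_direct = propagate(a, b, c, c_is_fixed=True, operand_to_operand=False)
with_direct = propagate(a, b, c, c_is_fixed=True, operand_to_operand=True)
print("without direct step: B =", without_direct[1])
print("with direct step:    B =", with_direct[1])
```

In this toy run, B only receives the "y" partitioning from C when the direct step is disabled, because A's sharding cannot pass through the fixed C. With the direct step enabled, B also picks up A's "x" partitioning on dimension 0.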