openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0

[XLA:SPMD] Optimize the partitioning for element-wise operations when all operands share the same sharding. #14873

Open copybara-service[bot] opened 1 month ago

copybara-service[bot] commented 1 month ago


Take C with S3 = add(A with S1, B with S2) as an example, where A, B, and C are tensors and S1, S2, and S3 are their respective shardings.

Before this change, we always have

A with S3 = reshard(A with S1, new_sharding=S3)
B with S3 = reshard(B with S2, new_sharding=S3)
C with S3 = add(A with S3, B with S3)

With this CL, if S1 and S2 are the same, we instead have

C with S1 = add(A with S1, B with S1)
C with S3 = reshard(C with S1, new_sharding=S3)

The new partitioning method can reduce the number of reshards: in the example above, two operand reshards are replaced by a single reshard of the result.
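
The two strategies can be compared with a small toy model. This is a minimal sketch in plain Python, not XLA code: a "sharding" here is assumed to be just the axis along which a 2-D tensor is split, and the helper names (shard, unshard, reshard, add_old, add_new) are hypothetical and do not correspond to anything in the SPMD partitioner. It only illustrates that both orderings produce the same result while the new one performs fewer reshards when S1 == S2.

```python
# Toy model (assumption): a sharding is the axis (0 = rows, 1 = columns)
# along which a list-of-rows matrix is split into equal pieces.

def shard(tensor, axis, parts):
    """Split a matrix into `parts` equal pieces along `axis`."""
    if axis == 0:
        step = len(tensor) // parts
        return [tensor[i * step:(i + 1) * step] for i in range(parts)]
    step = len(tensor[0]) // parts
    return [[row[i * step:(i + 1) * step] for row in tensor]
            for i in range(parts)]

def unshard(shards, axis):
    """Reassemble the full matrix from its shards."""
    if axis == 0:
        return [row for s in shards for row in s]
    return [sum((s[r] for s in shards), []) for r in range(len(shards[0]))]

def reshard(shards, old_axis, new_axis, stats):
    """Move data to a new sharding; counts only real (non-identity) reshards."""
    if old_axis == new_axis:
        return shards
    stats["reshards"] += 1
    return shard(unshard(shards, old_axis), new_axis, len(shards))

def add(a_shards, b_shards):
    """Element-wise add, shard by shard; operands must share a sharding."""
    return [[[x + y for x, y in zip(ra, rb)] for ra, rb in zip(sa, sb)]
            for sa, sb in zip(a_shards, b_shards)]

def add_old(a, s1, b, s2, s3, stats):
    """Old ordering: reshard both operands to S3, then add."""
    return add(reshard(a, s1, s3, stats), reshard(b, s2, s3, stats))

def add_new(a, s1, b, s2, s3, stats):
    """New ordering: if S1 == S2, add in place and reshard only the result."""
    if s1 == s2:
        return reshard(add(a, b), s1, s3, stats)
    return add_old(a, s1, b, s2, s3, stats)

A = [[1, 2], [3, 4]]
B = [[10, 20], [30, 40]]
S1 = S2 = 0  # operands split by rows
S3 = 1       # result wanted split by columns

old_stats = {"reshards": 0}
new_stats = {"reshards": 0}
c_old = add_old(shard(A, S1, 2), S1, shard(B, S2, 2), S2, S3, old_stats)
c_new = add_new(shard(A, S1, 2), S1, shard(B, S2, 2), S2, S3, new_stats)

print(unshard(c_old, S3) == unshard(c_new, S3))       # same values either way
print(old_stats["reshards"], new_stats["reshards"])   # reshards: old vs. new
```

In this toy setup the old ordering performs two reshards and the new ordering performs one. When S1 and S2 differ, add_new falls back to the old ordering, mirroring the condition stated above.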