openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0

Fix `PropagateShardingAlongDimsAndReplicateOthers` and expose it as a public util function. #19825

Closed: copybara-service[bot] closed this 3 days ago

copybara-service[bot] commented 3 days ago

Fix PropagateShardingAlongDimsAndReplicateOthers and expose it as a public util function.

The function's original description is correct, but its implementation is wrong. Consider the following input:

source_sharding = {devices=[2,3,5,7,11]<=[2310]}
source_dims = [2, 4, 1]
target_dims = [2, 1, 3]
target_shape_rank = 5

The result should be {devices=[1,11,5,3,1,14]<=[2,3,5,7,11]T(4,2,1,0,3) last_tile_dim_replicate}.
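To make the expected result easier to check, here is a small pure-Python model of the described semantics. It is a sketch, not XLA's implementation, and the function names `propagate_along_dims_and_replicate_others` and `iota_devices` are hypothetical. It assumes the source sharding is a plain tiled iota sharding with no replication of its own, and it always appends a trailing replication dim (even one of size 1). Propagated source dims are placed at their target dims in target-dim order, and every other source dim is folded into the `last_tile_dim_replicate` dim. The notation `<=[2,3,5,7,11]T(4,2,1,0,3)` is read here as "iota over 2310 devices, reshaped to `[2,3,5,7,11]`, then transposed by `(4,2,1,0,3)`".

```python
import itertools
from math import prod


def propagate_along_dims_and_replicate_others(source_tile_dims, source_dims,
                                              target_dims, target_rank):
    """Hypothetical model of the described behavior.

    Returns (target_tile_dims, perm). The last entry of target_tile_dims is
    the size of the last_tile_dim_replicate dim; perm is the transpose applied
    to the source iota dims.
    """
    # Order the (target_dim, source_dim) pairs by target dim.
    pairs = sorted(zip(target_dims, source_dims))
    propagated = [s for _, s in pairs]
    # Source dims that are not propagated end up replicated.
    replicated = [d for d in range(len(source_tile_dims)) if d not in source_dims]

    target_tile_dims = [1] * target_rank
    for t, s in pairs:
        target_tile_dims[t] = source_tile_dims[s]
    # Fold all remaining source dims into one trailing replication dim.
    target_tile_dims.append(prod(source_tile_dims[d] for d in replicated))
    return target_tile_dims, propagated + replicated


def iota_devices(reshape_dims, perm):
    """Flattened iota(prod(reshape_dims)).reshape(reshape_dims).transpose(perm)."""
    n = len(reshape_dims)
    # Row-major strides of the reshaped iota array.
    strides = [1] * n
    for i in range(n - 2, -1, -1):
        strides[i] = strides[i + 1] * reshape_dims[i + 1]
    out_dims = [reshape_dims[p] for p in perm]
    # Output axis k corresponds to input axis perm[k].
    return [sum(idx[k] * strides[perm[k]] for k in range(n))
            for idx in itertools.product(*(range(d) for d in out_dims))]


tile_dims, perm = propagate_along_dims_and_replicate_others(
    [2, 3, 5, 7, 11], source_dims=[2, 4, 1], target_dims=[2, 1, 3], target_rank=5)
print(tile_dims, perm)  # [1, 11, 5, 3, 1, 14] [4, 2, 1, 0, 3]
devices = iota_devices([2, 3, 5, 7, 11], perm)
print(devices[:14])  # replication group of the first shard
```

The first 14 devices form the replication group of shard `(0,0,0,0,0)`: they vary only along source dims 0 and 3 (sizes 2 and 7), which is why those two dims collapse into the replicated dim of size 14.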