openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0

SPMD - Handle Gather/Scatter is Hardcoded to IndexParallel #13304

Closed. ptoulme-aws closed this issue 3 months ago

ptoulme-aws commented 4 months ago

In the SPMD partitioner, the gather/scatter handler hardcodes the cost to {0,0} for the IndexParallel strategy, so that strategy is always favored over the others. This is not ideal for all hardware, especially hardware with a 2D torus topology. This should be refactored so that the favored strategy can be passed in as config instead of being hardcoded.

Location of the hardcoded cost: https://github.com/openxla/xla/blob/eaed933666ca4b44ea96b6bdae13631c1edfea00/xla/service/spmd/gather_scatter_handler.cc#L736

cheshire commented 4 months ago

Thanks! Do you plan to provide a patch for this?

ptoulme-aws commented 4 months ago

I can submit a PR that moves this behind an optional config? The default will still be IndexParallel, to maintain the existing flow.
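For illustration, here is a minimal sketch of what such an optional config could look like. Every name in it (`PartitionStrategy`, `GatherScatterConfig`, `Candidate`, `EffectiveCost`, `PickCandidate`) is hypothetical and is not taken from the XLA code base. The meaning of the two cost components is also an assumption. Only the {0,0} override and the IndexParallel default come from the discussion above.

```cpp
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical strategy names, for illustration only (not XLA's actual enum).
enum class PartitionStrategy {
  kIndexParallel,
  kOperandPassthrough,
  kIndexPassthrough,
};

// Two-component cost compared lexicographically; lower is better.
// What each component measures is an assumption of this sketch.
using Cost = std::pair<int64_t, int64_t>;

struct Candidate {
  PartitionStrategy strategy;
  Cost estimated_cost;  // What a cost model would report for this strategy.
};

// Hypothetical config: the strategy whose cost is forced to {0, 0}.
// Defaults to IndexParallel so behavior is unchanged when the option is unset.
struct GatherScatterConfig {
  PartitionStrategy preferred = PartitionStrategy::kIndexParallel;
};

// Returns {0, 0} for the preferred strategy, the estimated cost otherwise.
Cost EffectiveCost(const Candidate& c, const GatherScatterConfig& config) {
  if (c.strategy == config.preferred) return Cost{0, 0};
  return c.estimated_cost;
}

// Returns the cheapest candidate under the config, or nullptr if there is none.
const Candidate* PickCandidate(
    const std::vector<Candidate>& candidates,
    const GatherScatterConfig& config = GatherScatterConfig()) {
  const Candidate* best = nullptr;
  Cost best_cost{0, 0};
  for (const Candidate& c : candidates) {
    const Cost cost = EffectiveCost(c, config);
    if (best == nullptr || cost < best_cost) {
      best = &c;
      best_cost = cost;
    }
  }
  return best;
}
```

With a default-constructed `GatherScatterConfig`, IndexParallel wins whenever it is among the candidates, which matches the current hardcoded behavior. Setting `preferred` to another strategy moves the zero-cost override to that strategy, and if the preferred strategy is not a candidate, the choice falls back to the estimated costs.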

frgossen commented 4 months ago

I think making these things configurable makes sense.