openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0

Does XLA auto-sharding implement inter-op parallelism? #19103

Open man2machine opened 6 days ago

man2machine commented 6 days ago

I found that XLA auto-sharding is based on the Alpa paper https://arxiv.org/abs/2201.12023, which proposed an algorithm for both inter-op and intra-op parallelism. However, it appears (I may be wrong) that XLA implements only intra-op parallelism, not pipeline/inter-op parallelism. Is this true? Does XLA's experimental auto-sharding implement only intra-op parallelism? Furthermore, is this XLA auto-sharding feature available in JAX?
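The question hinges on the difference between the two kinds of parallelism from the Alpa paper. Here is a toy sketch in plain Python. Lists stand in for devices, and the names `matvec`, `intra_op` and `inter_op` are hypothetical. This is not how XLA or Alpa implement either strategy.

```python
# Toy illustration only: Python lists stand in for devices and tensors.

def matvec(w, x):
    """Dense matrix-vector product; w is a list of rows."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]

def intra_op(w, x, num_devices):
    """Intra-op parallelism: split ONE operator across devices.
    Each 'device' computes a contiguous block of output rows, and the blocks
    are concatenated (an all-gather in a real system)."""
    chunk = (len(w) + num_devices - 1) // num_devices
    shards = [w[i:i + chunk] for i in range(0, len(w), chunk)]
    partials = [matvec(shard, x) for shard in shards]  # would run concurrently
    return [y for part in partials for y in part]

def inter_op(layers, microbatches, num_stages):
    """Inter-op (pipeline) parallelism: assign whole layers to stages and
    stream microbatches through them. Returns the outputs plus the schedule
    as (time_step, stage, microbatch) triples."""
    per_stage = (len(layers) + num_stages - 1) // num_stages
    stages = [layers[i:i + per_stage] for i in range(0, len(layers), per_stage)]
    acts = list(microbatches)
    outputs = [None] * len(microbatches)
    schedule = []
    # GPipe-style forward pass: microbatch m is at stage s at time t = m + s.
    for t in range(len(microbatches) + len(stages) - 1):
        for s, stage in enumerate(stages):
            m = t - s
            if 0 <= m < len(microbatches):
                for w in stage:
                    acts[m] = matvec(w, acts[m])
                schedule.append((t, s, m))
                if s == len(stages) - 1:
                    outputs[m] = acts[m]
    return outputs, schedule
```

In `intra_op` every device takes part in every operator. In `inter_op` each stage owns whole layers, and the early and late time steps, where some stages sit idle, are the pipeline "bubble".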

ptoulme-aws commented 6 days ago

Auto sharding does implement SPMD pipeline parallelism if it decides to shard that way.
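"SPMD pipeline" means that every device runs the same program. The pipeline-stage dimension is what gets sharded, and activations are shifted to the next stage between steps. The toy sketch below shows that general idea in plain Python. The function `spmd_pipeline` is hypothetical, each stage is reduced to a single weight matrix for brevity, and this is not XLA auto-sharding's actual implementation.

```python
# Toy sketch of an SPMD-style pipeline: one program, applied per stage,
# with a shift of activations between steps. Not XLA's implementation.

def matvec(w, x):
    """Dense matrix-vector product; w is a list of rows."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]

def spmd_pipeline(stage_weights, microbatches):
    num_stages = len(stage_weights)
    zero = [0] * len(microbatches[0])
    # state[s] is the activation held by the device that owns stage s.
    state = [zero] * num_stages
    outputs = []
    for t in range(len(microbatches) + num_stages - 1):
        # Shift: every stage passes its activation to the next one (a
        # collective-permute in a real system); stage 0 takes a new microbatch,
        # or zeros once the inputs run out (the pipeline bubble).
        incoming = microbatches[t] if t < len(microbatches) else zero
        state = [incoming] + state[:-1]
        # Same program on every device: apply this device's stage weights.
        state = [matvec(w, a) for w, a in zip(stage_weights, state)]
        # Once the pipeline is full, the last stage emits a finished microbatch.
        if t >= num_stages - 1:
            outputs.append(state[-1])
    return outputs
```

Bubble steps still execute the program on zero-filled buffers. That wasted compute is the price of keeping every device on one identical program.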

man2machine commented 4 days ago

Thanks @ptoulme-aws for your reply! Does JAX support XLA SPMD auto-sharding? I have looked at the shard_map function as well as pjit with auto arguments, but it is unclear whether either of these implements the SPMD pipeline + shard parallelism.

ptoulme-aws commented 2 days ago

https://github.com/jax-ml/jax/blob/098d582e70dd1b5fe469fb9d808ae02e8b2ae809/jax/_src/compiler.py#L117