man2machine opened this issue 6 days ago.
Auto-sharding does implement SPMD pipeline parallelism if it decides to shard that way.
Thanks @ptoulme-aws for your reply! Does JAX support XLA SPMD auto-sharding? I have looked at the shard_map function as well as pjit with auto arguments, but it is unclear whether either of them implements SPMD pipeline parallelism combined with sharding (intra-op) parallelism.
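A minimal sketch of SPMD pipeline parallelism written by hand with `shard_map` may help frame the question. It appears that `shard_map` gives per-device control but does not choose a pipeline schedule on its own, so the stage loop and the `jax.lax.ppermute` hand-offs between stages have to be written explicitly. Assumptions (none of these come from the thread): 4 CPU devices simulated through `XLA_FLAGS`, a toy model of four `tanh(x @ W)` layers with one layer per device, a simple GPipe-style fill/drain schedule, and the hypothetical helper names `pipeline_stage` and `sequential`. The import fallback covers both the newer top-level `jax.shard_map` and the older `jax.experimental.shard_map` location.

```python
import os

# Assumption: simulate 4 CPU devices; this flag must be set before JAX is imported.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"
)

from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX exposes shard_map at the top level
    from jax import shard_map
except ImportError:  # older JAX keeps it under jax.experimental
    from jax.experimental.shard_map import shard_map


def pipeline_stage(w, xs, num_stages):
    """Hypothetical per-device body: this device owns one stage (one layer).

    w:  (1, D, D) local slice of the stacked layer weights
    xs: (M, B, D) all microbatches (replicated; only stage 0 reads them)
    """
    stage = jax.lax.axis_index("stage")
    num_micro = xs.shape[0]
    # Cyclic shift: stage i sends its activation to stage i + 1.
    perm = [(i, (i + 1) % num_stages) for i in range(num_stages)]
    state = jnp.zeros(xs.shape[1:], xs.dtype)
    finished = []
    # Fill/drain schedule: at step t, stage s works on microbatch t - s.
    for t in range(num_micro + num_stages - 1):
        fresh = xs[min(t, num_micro - 1)]
        inp = jnp.where(stage == 0, fresh, state)  # stage 0 injects a new microbatch
        out = jnp.tanh(inp @ w[0])  # toy layer: tanh(x @ W_s)
        if t >= num_stages - 1:
            # On the last stage, this is microbatch t - (num_stages - 1), fully processed.
            finished.append(out)
        state = jax.lax.ppermute(out, "stage", perm)  # hand off to the next stage
    # (1, M, B, D) per device; only the last stage's copy holds real results.
    return jnp.stack(finished)[None]


def sequential(weights, xs):
    """Reference: apply every layer in order on a single device."""
    h = xs
    for s in range(weights.shape[0]):
        h = jnp.tanh(h @ weights[s])
    return h


S, M, B, D = 4, 6, 2, 8  # stages, microbatches, microbatch size, hidden size
mesh = Mesh(np.array(jax.devices()[:S]), ("stage",))
key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
weights = jax.random.normal(key_w, (S, D, D)) / jnp.sqrt(D)
microbatches = jax.random.normal(key_x, (M, B, D))

pipelined = jax.jit(
    shard_map(
        partial(pipeline_stage, num_stages=S),
        mesh=mesh,
        in_specs=(P("stage"), P()),  # one layer per device; inputs replicated
        out_specs=P("stage"),
    )
)
result = pipelined(weights, microbatches)[S - 1]  # keep the last stage's outputs
print(np.allclose(result, sequential(weights, microbatches), atol=1e-4))
```

The Python loop is unrolled at trace time, which is acceptable for a toy example; a production pipeline would more likely use `jax.lax.scan` and would also need a schedule for the backward pass.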
I found that XLA auto-sharding is based on the Alpa paper (https://arxiv.org/abs/2201.12023), which proposed automating both inter-op and intra-op parallelism. However, it appears to implement only intra-op parallelism (I may be wrong), not pipeline/inter-op parallelism. Is this true? Does the experimental XLA auto-sharding only implement intra-op parallelism? Furthermore, is this XLA auto-sharding feature available in JAX?
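To make the intra-op side of that distinction concrete: intra-op parallelism splits the tensors of a single operator across devices, which is what XLA's SPMD partitioner does when `jax.jit` is given input and output shardings. The sketch below partitions one matmul over a 2x2 mesh of 4 simulated CPU devices. The axis names `data` and `model` and the array sizes are illustrative assumptions. Note that the shardings here are chosen by hand rather than by the auto-sharding pass, and there are no pipeline stages: inter-op parallelism would instead place different layers on different devices.

```python
import os

# Assumption: simulate 4 CPU devices; this flag must be set before JAX is imported.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"
)

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative 2x2 mesh: "data" splits the rows of x, "model" splits the columns of w.
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ("data", "model"))

x = jnp.linspace(-1.0, 1.0, 8 * 16, dtype=jnp.float32).reshape(8, 16)
w = jnp.linspace(-0.5, 0.5, 16 * 32, dtype=jnp.float32).reshape(16, 32)

x_sharded = jax.device_put(x, NamedSharding(mesh, P("data", None)))
w_sharded = jax.device_put(w, NamedSharding(mesh, P(None, "model")))

# A single operator (matmul + tanh) partitioned across all 4 devices: intra-op parallelism.
layer = jax.jit(
    lambda a, b: jnp.tanh(a @ b),
    out_shardings=NamedSharding(mesh, P("data", "model")),
)
y = layer(x_sharded, w_sharded)

# Each device holds one block of the (8, 32) result.
print([shard.data.shape for shard in y.addressable_shards])
```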