carlosgmartin opened this issue 1 year ago
`pjit` or `jax.jit` does exactly what you want :)
You write your program as if you are writing for a single device and the XLA GSPMD partitioner should take care of distributing your program and adding collectives for you! See https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html
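For reference, here is a minimal sketch of that workflow from the linked notebook (the mesh axis name and the array sizes are illustrative, and it assumes the device count divides the batch axis):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-dimensional mesh over whatever devices are available.
mesh = Mesh(np.array(jax.devices()), ('data',))

# Shard the leading batch axis of x across the 'data' mesh axis.
x = jax.device_put(jnp.ones((8, 1024)), NamedSharding(mesh, P('data', None)))

@jax.jit
def f(x):
    # Ordinary single-device code; GSPMD parallelizes it and inserts
    # whatever collectives the shardings require.
    return jnp.sin(x) * 2.0

y = f(x)
print(y.sharding)  # the output carries a sharding chosen by the compiler
```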
JAX has also integrated with the Alpa compiler: you can pass `pjit.AUTO` as `in_shardings` and `out_shardings`, which will invoke the auto-SPMD (Alpa) compiler pass. See these tests for how to use it: https://github.com/google/jax/blob/782d90dc8501e7148e1fd1dbd4757b4ca0b3ca4d/tests/pjit_test.py#L1261
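A hedged sketch of that auto-SPMD path follows. The exact spelling has moved around across JAX versions (`in_axis_resources` vs. `in_shardings`, `AUTO` as a bare sentinel vs. a mesh-parameterized object), so treat this as the pattern from the linked `pjit_test.py` rather than a stable API:

```python
import jax
import numpy as np
from jax.experimental.pjit import pjit, AUTO  # import path varies by version
from jax.sharding import Mesh

mesh = Mesh(np.array(jax.devices()), ('x',))

with mesh:
    f = pjit(lambda a: a @ a.T, in_shardings=AUTO, out_shardings=AUTO)
    a = np.ones((128, 128), dtype=np.float32)
    # Compilation runs the auto-spmd pass, which picks the shardings.
    compiled = f.lower(a).compile()
    print(compiled.input_shardings)
```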
@yashk2810 The Alpa docs say that pjit lacks some features that alpa.parallelize has ("Pipeline Parallelism" and "Automated"):

> In summary, alpa.parallelize supports more parallelism techniques in a more automatic way.

Is this out of date?
Yes, probably. JAX supports SPMD pipeline parallelism too. I don't know what "automated" means.
Is there any documentation about pipeline parallelism in JAX? How does it work? Is it basic pipeline parallelism as in GPipe, or an interleaved schedule?
I would also love a guide on how best to use pipeline parallelism in JAX. It is especially useful when working with GPUs without a fast interconnect, since tensor parallelism (parameter sharding) is very slow in that case.
After communicating with the Alpa team, they said that the JAX maintainers did not welcome merging Alpa's pipeline-parallelism functionality into jaxlib. Why is that? @yashk2810
Any news? It seems that only Alpa is able to support pipeline parallelism. @Lime-Cakes
@hawkinsp Is pipeline parallelism possible with the current JAX version? There are a number of questions about this, e.g. #17246. To the best of my knowledge, tensor parallelism works fine but pipeline parallelism does not.

The usual answer appears to be links to Praxis, but there is no simple example showing whether it is possible.
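For what it's worth, a pipeline schedule can be expressed with today's public primitives even without an official example. Below is a hedged, minimal GPipe-style sketch, not the Praxis implementation: one illustrative dense layer per device, activations rotated stage-to-stage with `lax.ppermute` inside `shard_map`. All names and sizes are made up for illustration; a real implementation would add remainder handling and an interleaved/circular schedule.

```python
from functools import partial
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

n_stages = jax.device_count()   # one pipeline stage per device
n_micro, batch, d = 4, 2, 16    # illustrative sizes
mesh = Mesh(np.array(jax.devices()), ('stage',))

@partial(shard_map, mesh=mesh, in_specs=(P('stage'), P(None)),
         out_specs=P(None), check_rep=False)
def pipeline(w, xs):
    w = w[0]                                  # (d, d): this stage's weights
    stage = jax.lax.axis_index('stage')
    perm = [(i, (i + 1) % n_stages) for i in range(n_stages)]
    state = jnp.zeros((batch, d))             # activation held by this stage
    outs = jnp.zeros((n_micro, batch, d))

    def step(t, carry):
        state, outs = carry
        # Stage 0 feeds in microbatch t; later stages use the shifted-in state.
        inp = jnp.where(stage == 0, xs[jnp.minimum(t, n_micro - 1)], state)
        out = jnp.tanh(inp @ w)
        # The last stage finishes microbatch t - (n_stages - 1) at step t.
        i = t - (n_stages - 1)
        done = (stage == n_stages - 1) & (i >= 0)
        outs = jnp.where(done, outs.at[jnp.maximum(i, 0)].set(out), outs)
        # Shift every stage's activation to the next stage.
        state = jax.lax.ppermute(out, 'stage', perm)
        return state, outs

    _, outs = jax.lax.fori_loop(0, n_micro + n_stages - 1, step, (state, outs))
    # Only the last stage holds real outputs; sum across stages to replicate.
    return jax.lax.psum(outs, 'stage')

w = jax.random.normal(jax.random.PRNGKey(0), (n_stages, d, d)) * 0.1
xs = jax.random.normal(jax.random.PRNGKey(1), (n_micro, batch, d))
print(pipeline(w, xs).shape)  # (n_micro, batch, d)
```

This only demonstrates the communication pattern (a `ppermute` ring with a GPipe-style fill/drain loop); it says nothing about performance relative to Alpa or Praxis.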
In another issue, I wrote the following:
Later, I came across a project that seeks to do just that:
Alpa: Automating Inter- and Intra-Operator Parallelism for Distributed Deep Learning (repo):
I was wondering if there's any interest in adopting/integrating a similar system into JAX itself. This could free the vast majority of JAX users from having to manually tune and restructure their programs to parallelize them efficiently for a particular situation and set of computational resources. It could greatly increase the ease and user-friendliness of distributed computing with JAX, resulting in more widespread adoption and popularity.
I'd love to hear your thoughts and comments.