jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Fully automatic parallelization of entire JAX programs #15710

Open carlosgmartin opened 1 year ago

carlosgmartin commented 1 year ago

In another issue, I wrote the following:

I dream of a compiler that is powerful enough to let users focus solely on the semantics of a program (what is to be computed), while the compiler figures out how to distribute the computation efficiently over a set of available resources (how it is to be computed). So no more pmap, xmap, shmap, pjit, pmean, etc. Just write your program as if it ran on a single device (i.e., specify its semantics), and let the compiler figure out the rest. As the shmap page says: Compiler, take the wheel!

Later, I came across a project that seeks to do just that:

Alpa: Automating Inter- and Intra-Operator Parallelism for Distributed Deep Learning (repo):

Alpa is built on top of the tensor computation framework JAX. Alpa can automatically parallelize JAX functions and run them on a distributed cluster. Alpa analyzes the computational graph and generates a distributed execution plan tailored to the computational graph and target cluster. The generated execution plan can combine state-of-the-art distributed training techniques including data parallelism, operator parallelism, and pipeline parallelism.

Alpa provides a simple API alpa.parallelize and automatically generates the best execution plan by solving optimization problems. Therefore, you can efficiently scale your jax computation on a distributed cluster, without any expertise in distributed computing.

Alpa provides a transformation alpa.parallelize to parallelize a jax function. alpa.parallelize is similar to jax.jit. jax.jit compiles a jax function for a single device, while alpa.parallelize compiles a jax function for a distributed device cluster. You may know that jax has some built-in transformations for parallelization, such as pmap, pjit, and xmap. However, these transformations are not fully automatic, because they require users to manually specify the parallelization strategies such as parallelization axes and device mapping schemes. You also need to manually call communication primitives such as lax.pmean and lax.all_gather, which is nontrivial if you want to do advanced model parallelization. Unlike these transformations, alpa.parallelize can do all things automatically for you. alpa.parallelize finds the best parallelization strategy for the given jax function and does the code transformation. You only need to write the code as if you are writing for a single device.
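For concreteness, usage looks roughly like this (a sketch based on the Alpa quickstart; the model, loss, and batch names are illustrative, not part of Alpa's API):

```python
import alpa
import jax
import jax.numpy as jnp

# Written exactly as single-device JAX code; alpa.parallelize decides the
# distributed execution plan (data/operator/pipeline parallelism) itself.
@alpa.parallelize
def train_step(params, batch):
    def loss_fn(p):
        pred = jnp.dot(batch["x"], p["w"]) + p["b"]
        return jnp.mean((pred - batch["y"]) ** 2)
    grads = jax.grad(loss_fn)(params)
    # Plain SGD update; no pmean/all_gather calls are needed.
    return jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)

params = {"w": jnp.zeros((128, 1)), "b": jnp.zeros((1,))}
batch = {"x": jnp.ones((512, 128)), "y": jnp.ones((512, 1))}
params = train_step(params, batch)
```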

I was wondering if there's any interest in adopting/integrating a similar system into JAX itself. This could free the vast majority of JAX users from having to manually tune and restructure their programs to parallelize them efficiently for a particular situation and set of computational resources. It could greatly increase the ease and user-friendliness of distributed computing with JAX, resulting in more widespread adoption and popularity.

I'd love to hear your thoughts and comments.

yashk2810 commented 1 year ago

pjit or jax.jit does exactly what you want :)

You write your program as if you are writing for a single device and the XLA GSPMD partitioner should take care of distributing your program and adding collectives for you! See https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html
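For example, a minimal sketch (assuming 8 local devices; the mesh shape and array sizes are just illustrative):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Lay out the available devices as a 2D mesh (assumes 8 devices).
mesh = Mesh(np.array(jax.devices()).reshape(4, 2), axis_names=("data", "model"))

# Shard the inputs along 'data' and the weights along 'model'.
x = jax.device_put(jnp.ones((32, 1024)), NamedSharding(mesh, P("data", None)))
w = jax.device_put(jnp.ones((1024, 1024)), NamedSharding(mesh, P(None, "model")))

# The function itself is ordinary single-device code; under jit, the GSPMD
# partitioner propagates the input shardings and inserts any collectives.
@jax.jit
def predict(x, w):
    return jnp.tanh(x @ w)

y = predict(x, w)
print(y.sharding)  # sharding chosen by the partitioner
```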

JAX has also integrated with the Alpa compiler: you can specify pjit.AUTO for in_shardings and out_shardings, which invokes the auto-SPMD (Alpa) compiler pass. You can look at these tests to see how to use it: https://github.com/google/jax/blob/782d90dc8501e7148e1fd1dbd4757b4ca0b3ca4d/tests/pjit_test.py#L1261

carlosgmartin commented 1 year ago

@yashk2810 The Alpa docs say that pjit lacks some features that alpa.parallelize has ("Pipeline Parallelism" and "Automated"):

In summary, alpa.parallelize supports more parallelism techniques in a more automatic way.

Is this out of date?

yashk2810 commented 1 year ago

Yes probably. JAX supports SPMD pipeline parallelism too. I don't know what automated means.

Lime-Cakes commented 1 year ago

Yes probably. JAX supports SPMD pipeline parallelism too. I don't know what automated means.

Is there any documentation about pipeline parallelism in JAX? And how does it work? Is it basic pipeline parallelism, as in GPipe, or interleaved pipeline parallelism?

maxidl commented 1 year ago

I would also love a guide on how best to use pipeline parallelism in JAX. It is especially useful when working with GPUs without a fast interconnect, since tensor parallelism (parameter sharding) is very slow in that case.

MoFHeka commented 1 year ago

After communicating with the Alpa team, they said that the JAX maintainers were not open to merging Alpa's pipeline parallelism functionality into jaxlib. What is the story there? @yashk2810

MoFHeka commented 1 year ago

Any news? It seems that only Alpa is able to support pipeline parallelism. @Lime-Cakes

aniquetahir commented 1 year ago

@hawkinsp Is pipeline parallelism possible with the current JAX version? There are a number of questions about this, e.g. #17246. To the best of my knowledge, tensor parallelism works fine but pipeline parallelism does not.

The usual answer appears to be links to Praxis, but there is no simple example showing whether it is possible.