jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Improve documentation for multi-node/host training #20099

Open neel04 opened 9 months ago

neel04 commented 9 months ago

As laid out in #20053, I feel there needs to be a minimal example showing how to use the new sharding API with device_put for multi-process training, e.g. on a TPU Pod slice.
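For reference, here is a minimal sketch of what such an example could look like, built only from public APIs (jax.distributed.initialize, Mesh, NamedSharding, jax.make_array_from_process_local_data, jax.jit). It is not an official recipe. Assumptions: the JAX_MULTIHOST environment variable is a hypothetical switch invented for this sketch (not a JAX setting), load_local_batch is a hypothetical per-process data loader, and the model is a toy linear regression. A plain 1D mesh over jax.devices() keeps it short; real Pod slices may prefer jax.experimental.mesh_utils.create_device_mesh for a better device layout. The same script runs unchanged on a single host.

```python
import os

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical switch (JAX_MULTIHOST is not a JAX setting): only initialize the
# distributed runtime when launched as one of several processes, e.g. once per
# host of a TPU Pod slice. It must run before any other JAX call.
if os.environ.get("JAX_MULTIHOST") == "1":
    jax.distributed.initialize()  # on Cloud TPU the arguments are auto-detected

# A 1D mesh over every device of every process; the batch is split along "data".
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
data_sharding = NamedSharding(mesh, P("data"))
replicated = NamedSharding(mesh, P())

FEATURES = 4


def load_local_batch(step, rows):
    """Hypothetical loader: each process reads only its own rows of the data."""
    rng = np.random.default_rng(1000 * step + jax.process_index())
    x = rng.standard_normal((rows, FEATURES)).astype(np.float32)
    y = x.sum(axis=1, keepdims=True)  # toy target, solved exactly by w = ones
    return x, y


def loss_fn(w, x, y):
    return jnp.mean((x @ w - y) ** 2)


@jax.jit
def train_step(w, x, y):
    # jit sees global arrays; XLA inserts the cross-device gradient reduction.
    loss, grad = jax.value_and_grad(loss_fn)(w, x, y)
    return w - 0.1 * grad, loss


def train(num_steps):
    # Create the parameters directly with a replicated sharding on all devices.
    w = jax.jit(lambda: jnp.zeros((FEATURES, 1)), out_shardings=replicated)()
    rows = 8 * jax.local_device_count()  # per-process batch, divisible by local devices
    losses = []
    for step in range(num_steps):
        x_local, y_local = load_local_batch(step, rows)
        # Stitch each process's local rows into one global, sharded jax.Array.
        x = jax.make_array_from_process_local_data(data_sharding, x_local)
        y = jax.make_array_from_process_local_data(data_sharding, y_local)
        w, loss = train_step(w, x, y)
        losses.append(float(loss))  # loss is a replicated scalar
    return w, losses


if __name__ == "__main__":
    w, losses = train(20)
    if jax.process_index() == 0:
        print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}")
```

On a Pod slice the same script would be launched once per host (with JAX_MULTIHOST=1 under this sketch's convention); every process executes the same jitted step on the same global arrays, which is the multi-controller model the sharding API assumes.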

So far, I've seen several different ways of doing this: some rely on ideas that are out of date (like pjit), and some straight up don't work.

For example, I was following this discussion, where make_array_from_single_device_arrays apparently wasn't returning the host's shard of the array but rather the global array, which was n times as big (where n is the number of hosts).
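A jax.Array is a global view: its .shape is always the global shape passed to make_array_from_single_device_arrays, even though each process only holds the shards living on its own devices (available via arr.addressable_shards). One possible reading of the behaviour above is that the global shape was being inspected where the per-host shard was expected. The sketch below shows both views on a 1D data mesh; global_rows is a hypothetical stand-in for a real data source.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
sharding = NamedSharding(mesh, P("data"))

# The shape argument is the GLOBAL shape across all processes, not the local one.
global_shape = (4 * jax.device_count(), 3)


def global_rows(start, stop):
    """Hypothetical data source: global row i is filled with the value i."""
    rows = np.arange(start, stop, dtype=np.float32)[:, None]
    return np.repeat(rows, global_shape[1], axis=1)


# The map only lists devices this process can address, each with the global
# index (a tuple of slices) of the block it owns.
local_arrays = []
for device, index in sharding.addressable_devices_indices_map(global_shape).items():
    start, stop, _ = index[0].indices(global_shape[0])
    local_arrays.append(jax.device_put(global_rows(start, stop), device))

arr = jax.make_array_from_single_device_arrays(global_shape, sharding, local_arrays)

print("global shape:", arr.shape)  # same on every host
print("local shard shapes:", [s.data.shape for s in arr.addressable_shards])
```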

This is a major pain point: given that scalability is a strong focus for the JAX team, the parallelization API should be top-notch. Instead, in practice the end user ends up wading through a mix of xmap, pmap, and sharding (and combinations of them), some of which are deprecated and no longer recommended.

This seriously needs to be improved. I feel the sharding API is really well designed for parallelizing across multiple devices on a single host, and it only took me 15 minutes to integrate. However, multi-node/multi-process training definitely needs a minimal example.

ChenAo-Phys commented 6 months ago

I couldn't agree more! I'm also trying to write some multi-node code, but it gets confusing when I want to stick to the automatic-parallelization style across multiple hosts.

A minimal example would be extremely helpful. At the very least, it would tell us the recommended way to implement parallelization. For now, I could spend a month adding make_array_from_single_device_arrays everywhere in my code, but I'm afraid it may be deprecated in the future, in which case all that effort would be wasted.
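For comparison (a sketch, not an official recommendation), jax.make_array_from_callback takes the global shape, a sharding, and a callback that receives the global index of each device's block. JAX runs the per-device loop itself and only calls the callback for devices the current process can address, so the same code works on one host or many. read_block is a hypothetical reader standing in for real data loading.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
sharding = NamedSharding(mesh, P("data"))
global_shape = (2 * jax.device_count(), 5)


def read_block(index):
    """Hypothetical reader: return the block of the dataset selected by `index`.

    `index` is a tuple of slices into the global array. JAX only calls this for
    devices addressable by the current process, so each host only materializes
    the rows its own devices need.
    """
    start, stop, _ = index[0].indices(global_shape[0])
    rows = np.arange(start, stop, dtype=np.float32)[:, None]
    return np.repeat(rows, global_shape[1], axis=1)


arr = jax.make_array_from_callback(global_shape, sharding, read_block)
print(arr.shape, [s.data.shape for s in arr.addressable_shards])
```

Whichever constructor ends up being used, routing it through one helper in your own code is one way to keep a future API change to a single place.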

yashk2810 commented 6 months ago

You can also use jax.make_array_from_process_local_data to create a global jax.Array from the data on your hosts.
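A minimal sketch of that call, assuming a 1D data-parallel mesh over all devices: each process passes only its own rows, and when global_shape is omitted JAX infers it under the assumption that every process supplies equally sized local data. The row contents (filled with the process index) are made up purely so the origin of each row is visible.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
sharding = NamedSharding(mesh, P("data"))

# Each process holds only its own rows: here 2 per local device, tagged with
# the process index so the origin of each row is visible.
rows_per_process = 2 * jax.local_device_count()
local = np.full((rows_per_process, 3), jax.process_index(), dtype=np.float32)

# With global_shape omitted, JAX infers it assuming every process passes
# equally sized local data; the batch dim grows by the number of processes.
global_arr = jax.make_array_from_process_local_data(sharding, local)

print(global_arr.shape)  # (jax.process_count() * rows_per_process, 3)
print([s.data.shape for s in global_arr.addressable_shards])  # (2, 3) each
```

Compared with make_array_from_single_device_arrays, this avoids the explicit per-device loop; the sharding alone determines how the local rows are split across the process's devices.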

neel04 commented 6 months ago

FWIW, I currently use scalax for multi-node training. It does have a bit of a learning curve, but it's pretty easy once you figure it out.

It's rather disappointing that third-party libraries have to provide support for something as fundamental as multi-node parallelization. I feel it goes against the JAX spirit.