Open neel04 opened 9 months ago
I couldn't agree more! I'm also trying to write multi-node code, but it's unclear how to stick to the automatic-parallelization style across multiple hosts.
A minimal example would be extremely helpful. At the very least it would show the recommended way to implement parallelization. For now I could spend a month using make_array_from_single_device_arrays everywhere in my code, but I'm afraid it may be deprecated in the future, in which case the effort would be wasted.
You can also use jax.make_array_from_process_local_data to create a global jax.Array from data on your hosts.
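A minimal sketch of that suggestion, assuming a one-dimensional mesh over all devices with a hypothetical axis name "data" and a toy batch of 2 rows per device. On a real multi-host run each process would call jax.distributed.initialize() before anything else; it is left out here so the snippet also runs in a single process. How rows map to hosts depends on the device order in the mesh, so treat this as an illustration rather than a recommended recipe.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumption: one mesh axis ("data") spanning every device in the job.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
sharding = NamedSharding(mesh, P("data"))  # split axis 0 across the mesh

# Each process only builds its own rows (here: 2 rows per local device).
local_batch = 2 * jax.local_device_count()
local_data = np.full((local_batch, 4), jax.process_index(), dtype=np.float32)

# The global shape is inferred, assuming every process contributes
# the same number of rows.
global_batch = jax.make_array_from_process_local_data(sharding, local_data)

print(global_batch.shape)  # global shape, not the per-host shape
print([s.data.shape for s in global_batch.addressable_shards])  # this host's shards
```

The returned array reports the global shape, while addressable_shards exposes only the pieces that live on the calling host.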
FWIW, I currently use scalax for multi-node. It does have a bit of a learning curve, but it's pretty easy once you figure it out.
Rather disappointing that third-party libs have to provide support for something as fundamental as multi-node parallelization. I feel it goes against the JAX spirit.
As laid out in #20053, I feel that there is a need for a minimal example of how to use the new sharding API with device_put for training in a multi-process fashion, such as on a TPU Pod slice.
So far, I've seen multiple different methods to accomplish the same thing, some of which leverage ideas (like pjit) that are out of date and some of which straight up don't work.
For example, I was following this discussion where apparently make_array_from_single_device_arrays wasn't returning a sharded array for the host, but rather the global array, which was n times as big (n is the number of hosts).
This is a major pain point, as the API around parallelization should be top-notch considering scalability is a strong focus for the JAX team. In practice, the end user ends up wading through a sea of xmap/pmap/sharding approaches and mixtures of them, some of which are deprecated and not recommended.
This seriously needs to be improved. I feel the sharding API is really well written for parallelizing across multiple devices on a single host, and it only took me 15 minutes to integrate that. However, multi-node/multi-process definitely needs a minimal example.
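For illustration only, here is a rough sketch of what such a minimal data-parallel example might look like. It is not an official recipe. The axis name "data", the helper to_global, the toy linear model, and the learning rate are all hypothetical, and a real multi-process run would need jax.distributed.initialize() before any other JAX call. Which rows end up on which host depends on the device order in the mesh; that is assumed not to matter for plain data parallelism.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
batch_sharding = NamedSharding(mesh, P("data"))  # split axis 0 across all devices
replicated = NamedSharding(mesh, P())            # same value on every device


def to_global(local_np):
    """Hypothetical helper: wrap this process's rows in a global jax.Array."""
    devices = jax.local_devices()
    shards = np.split(local_np, len(devices), axis=0)
    per_device = [jax.device_put(s, d) for s, d in zip(shards, devices)]
    global_shape = (local_np.shape[0] * jax.process_count(),) + local_np.shape[1:]
    return jax.make_array_from_single_device_arrays(
        global_shape, batch_sharding, per_device
    )


def loss_fn(w, x, y):
    return jnp.mean((x @ w - y) ** 2)


@jax.jit
def train_step(w, x, y):
    loss, grad = jax.value_and_grad(loss_fn)(w, x, y)
    return w - 0.1 * grad, loss


rng = np.random.default_rng(jax.process_index())
local_rows = 2 * jax.local_device_count()
x = to_global(rng.normal(size=(local_rows, 4)).astype(np.float32))
y = to_global(rng.normal(size=(local_rows, 1)).astype(np.float32))

# Create parameters inside jit so every process ends up with the same
# replicated value without moving host data by hand.
w = jax.jit(lambda: jnp.zeros((4, 1)), out_shardings=replicated)()

w, loss = train_step(w, x, y)

print(x.shape)                                       # global shape
print([s.data.shape for s in x.addressable_shards])  # shards held by this host
```

Note that x.shape is always the global shape, so with n hosts it is n times the per-host batch even though each host only stores its own addressable shards. That may be what the linked discussion ran into.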