Conchylicultor opened 9 months ago
I think we can introduce some wrappers to create arrays like that.
Something like: jax.make_array_from_data_parallel_config(local_shape, mesh, np_array)
which would basically do what the docstring does.
You do need devices to device_put so just specifying jax.sharding.REPLICATED is not enough. But a simple wrapper around it should work out. WDYT?
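Here is a minimal sketch of what such a wrapper could look like, built only on existing calls (NamedSharding, Sharding.addressable_devices_indices_map, jax.make_array_from_single_device_arrays). The function name and its simplified (local_data, mesh) signature are hypothetical and not a JAX API. It assumes a 1D mesh, the same local batch shape on every process, and a mesh whose devices are grouped by process in process-index order.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def make_array_from_data_parallel_config(local_data: np.ndarray, mesh: Mesh) -> jax.Array:
    """Hypothetical wrapper (not a JAX API): shard dim 0 of the global batch over `mesh`.

    Assumes a 1D mesh, identical local batch shapes on all processes, and that the
    global batch is the concatenation of local batches in process-index order
    (i.e. each process's devices are contiguous in the mesh; check your topology).
    """
    axis = mesh.axis_names[0]
    sharding = NamedSharding(mesh, PartitionSpec(axis))
    local_rows = local_data.shape[0]
    global_shape = (local_rows * jax.process_count(),) + local_data.shape[1:]
    offset = local_rows * jax.process_index()  # first global row owned by this process

    device_buffers = []
    # Only devices attached to this process are visited here.
    for device, index in sharding.addressable_devices_indices_map(global_shape).items():
        # Normalize the slice (it may use None bounds) to the global row range of this device.
        start, stop, _ = index[0].indices(global_shape[0])
        shard = local_data[start - offset : stop - offset]
        device_buffers.append(jax.device_put(shard, device))
    return jax.make_array_from_single_device_arrays(global_shape, sharding, device_buffers)
```

On a single process this produces the same array as jax.device_put(local_data, NamedSharding(mesh, PartitionSpec(axis))); the per-device bookkeeping only matters when each host holds a different slice of the global batch.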
Thank you for the answer!
> Something like: jax.make_array_from_data_parallel_config(local_shape, mesh, np_array) which would basically do what the docstring does.
Without seeing the end-to-end code (with the sharding, local_shape,... definitions), it's difficult to compare. I'm not sure this would fix the issue:
- It still requires defining local_shape and mesh.
- There would then be several ways to create arrays: jax.make_array_from_single_device_arrays (usage non-trivial), jax.device_put, and now jax.make_array_from_data_parallel_config. To keep the API minimal, why not only keep jax.device_put for most usages?
> You do need devices to device_put so just specifying jax.sharding.REPLICATED is not enough. But a simple wrapper around it should work out. WDYT?
We implemented this API in our codebase, so it is possible: http://google3/third_party/py/kauldron/evals/evaluators.py;l=177;rcl=598616568
If the constant is too confusing, then having a callable version is fine too, like:
x = jax.device_put(x, jax.sharding.sharding_first_dim())
x = jax.device_put(x, jax.sharding.sharding_replicated())
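As a rough illustration that such helpers can be layered on today's API, here is a minimal sketch. The names sharding_replicated and sharding_first_dim come from the proposal above and are hypothetical (they are not part of jax.sharding); the single flat mesh over jax.devices() and the "devices" axis name are assumptions of this sketch.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Hypothetical axis name for the single flat mesh axis.
_AXIS = "devices"


def _flat_mesh() -> Mesh:
    # One mesh axis spanning every device visible to JAX.
    return Mesh(np.array(jax.devices()), (_AXIS,))


def sharding_replicated() -> NamedSharding:
    # Empty PartitionSpec: no dimension is split, so every device holds a full copy.
    return NamedSharding(_flat_mesh(), PartitionSpec())


def sharding_first_dim() -> NamedSharding:
    # Split dimension 0 across all devices; the remaining dimensions are not split.
    return NamedSharding(_flat_mesh(), PartitionSpec(_AXIS))


x = np.arange(32, dtype=np.float32).reshape(8, 4)
batch = jax.device_put(x, sharding_first_dim())  # assumes the device count divides 8
params = jax.device_put(x, sharding_replicated())  # every device gets all 8 rows
```

This runs as-is in a single process. In a multi-host job the mesh also contains other processes' devices, so placing data on it depends on the multi-host jax.device_put support that this issue asks for.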
I think this issue can be decomposed into 2 sub-parts:
[ ] Have jax.device_put support multi-host sharding
[ ] Provide standard sharding, auto-computed for the common use-cases: Replicated and sharded, without having to create the mesh,...
The usage you pointed out in your library does create a mesh. The problem is that meshes usually come from mesh_utils, not from the way you created it. It might work for your use case but providing such an API in JAX would not work for majority of stuff.
> Have jax.device_put support multi-host sharding
Yes, that is on the table (I have a small prototype to do this) but it's more complicated than just device_put.
> It might work for your use case but providing such an API in JAX would not work for majority of stuff.
Of course, it would not work for complex use cases and LLMs. The goal of this issue is not to fix all use cases, just to make the 2 most standard use cases work out of the box, meaning:
- Replicated
- Sharded along the first dimension
I believe those 2 use cases already cover many needs. Flax, for instance, provides flax.jax_utils.unreplicate for pmap, which was widely used: https://source.corp.google.com/search?q=flax.*replicate
So this is a good signal that there's a need for this kind of high-level convenience.
The jax sharding API is quite flexible, but for the very common use cases it requires too many lines of code.
For example, take the Orbax tutorial: https://orbax.readthedocs.io/en/latest/orbax_checkpoint_101.html#a-standard-recipe
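The tutorial's code is not reproduced in this thread. As a rough stand-in, the usual single-process data-parallel recipe with the current API looks something like the following sketch (not the Orbax code; the shapes, the "data" axis name, and the loss function are made up for illustration).

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# 1. Arrange every device into a 1D grid and give its axis a name.
devices = mesh_utils.create_device_mesh((jax.device_count(),))
mesh = Mesh(devices, axis_names=("data",))

# 2. Describe how each array maps onto that grid.
batch_sharding = NamedSharding(mesh, PartitionSpec("data"))  # split dim 0 over "data"
param_sharding = NamedSharding(mesh, PartitionSpec())  # no split: replicated

# 3. Place the arrays (this sketch assumes a single process).
batch = jax.device_put(np.ones((8, 4), np.float32), batch_sharding)
params = jax.device_put(np.zeros((4,), np.float32), param_sharding)


# 4. jit follows the input shardings, so the data parallelism is never spelled out.
@jax.jit
def loss(params, batch):
    return jnp.mean(batch @ params)
```

Nothing in these lines says "data parallel" directly; the intent has to be inferred from the combination of mesh, PartitionSpec("data") on the batch, and the empty PartitionSpec() on the parameters.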
Understanding what this code does requires parsing each line individually and reverse-engineering the intent. At first glance, it's not obvious that this code is doing data parallelism.
Supporting multi-host in a generic way is even more verbose and much harder to parse: https://jax.readthedocs.io/en/latest/_autosummary/jax.make_array_from_single_device_arrays.html
Instead, jax should provide default shardings for the most common use cases. Something like:
Unlike the first examples, this code is immediately understandable, even for people coming from PyTorch who don't know anything about JAX or sharding.