jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0
30.27k stars 2.78k forks

Jax sharding API is too verbose #19360

Open Conchylicultor opened 9 months ago

Conchylicultor commented 9 months ago

The jax sharding API is quite flexible, but for the most common use cases it requires too many lines of code.

For example, if we take the orbax tutorial: https://orbax.readthedocs.io/en/latest/orbax_checkpoint_101.html#a-standard-recipe

sharding = jax.sharding.NamedSharding(
    jax.sharding.Mesh(jax.devices(), ('model',)),
    jax.sharding.PartitionSpec('model',)
)
x = jax.device_put(x, sharding)

Understanding what this code does requires parsing each line individually and trying to reverse engineer the intent. It's not obvious at first glance that this code is doing data parallelism.
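For comparison, here is the same data-parallel placement written with the real jax.experimental.mesh_utils helper and with comments spelling out the intent (a sketch; the resulting NamedSharding is equivalent to the snippet above):

```python
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# One mesh axis named 'model' spanning every available device.
devices = mesh_utils.create_device_mesh((jax.device_count(),))
mesh = Mesh(devices, axis_names=('model',))

# Shard the leading dimension of x across the 'model' axis
# (i.e. plain data parallelism over the batch dimension).
sharding = NamedSharding(mesh, PartitionSpec('model'))

x = jnp.arange(8.0).reshape(8, 1)
x = jax.device_put(x, sharding)
```

Even with the comments, the reader still has to assemble mesh, axis names, and partition spec by hand to express "shard the first dim".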

Supporting multi-host in a generic way is even more verbose and very difficult to parse: https://jax.readthedocs.io/en/latest/_autosummary/jax.make_array_from_single_device_arrays.html

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

local_shape = local_array.shape
global_mesh = Mesh(np.array(jax.devices()), ('devices',))
global_shape = (jax.process_count() * local_shape[0],) + local_shape[1:]
arrays = jax.device_put(
    np.split(local_array, len(global_mesh.local_devices), axis=0),
    global_mesh.local_devices,
)
sharding = NamedSharding(global_mesh, P('devices'))
array = jax.make_array_from_single_device_arrays(global_shape, sharding, arrays)

Instead, jax should provide default sharding for the most common use-case:

Something like:

x = jax.device_put(x, jax.sharding.SHARDED_FIRST_DIM)
x = jax.device_put(x, jax.sharding.REPLICATED)

Contrary to the first examples, this code is immediately understandable, even to people coming from PyTorch who know nothing about JAX or sharding.

yashk2810 commented 9 months ago

I think we can introduce some wrappers to create arrays like that.

Something like: jax.make_array_from_data_parallel_config(local_shape, mesh, np_array) which would basically do what the docstring does.

You do need devices to device_put so just specifying jax.sharding.REPLICATED is not enough. But a simple wrapper around it should work out. WDYT?
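On a single host, such a wrapper can be sketched in a few lines (the name and signature of make_array_from_data_parallel_config are hypothetical, taken from the comment above; everything it calls is existing JAX API):

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def make_array_from_data_parallel_config(
    mesh: Mesh, local_array: np.ndarray
) -> jax.Array:
    """Builds a global array sharded along axis 0 (data parallelism)."""
    # Global shape: concatenate every process's local batch along axis 0.
    global_shape = (
        jax.process_count() * local_array.shape[0],
    ) + local_array.shape[1:]
    # One shard per local device.
    shards = jax.device_put(
        np.split(local_array, len(mesh.local_devices), axis=0),
        mesh.local_devices,
    )
    sharding = NamedSharding(mesh, PartitionSpec('devices'))
    return jax.make_array_from_single_device_arrays(
        global_shape, sharding, shards
    )


mesh = Mesh(np.array(jax.devices()), ('devices',))
arr = make_array_from_data_parallel_config(mesh, np.arange(8.0).reshape(8, 1))
```

This is essentially the docstring recipe behind one call; on a single process it degenerates to a plain sharded device_put.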

Conchylicultor commented 9 months ago

Thank you for the answer!

Something like: jax.make_array_from_data_parallel_config(local_shape, mesh, np_array) which would basically do what the docstring does.

Without seeing the end-to-end code (with the sharding, local_shape, etc. definitions), it's difficult to compare. I'm not sure this would fix the issue:

You do need devices to device_put so just specifying jax.sharding.REPLICATED is not enough. But a simple wrapper around it should work out. WDYT?

We implemented this API in our codebase, so it is possible: http://google3/third_party/py/kauldron/evals/evaluators.py;l=177;rcl=598616568

If the constant is too confusing, then having a callable version is fine too, like:

x = jax.device_put(x, jax.sharding.sharding_first_dim())
x = jax.device_put(x, jax.sharding.sharding_replicated())
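For what it's worth, on a single host these callables can be sketched on top of the existing NamedSharding API (the names mirror the proposal above; nothing here exists in jax.sharding today):

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def _default_mesh() -> Mesh:
    # One mesh axis spanning all local devices.
    return Mesh(np.array(jax.devices()), ('devices',))


def sharding_replicated() -> NamedSharding:
    """Fully replicate the array on every device."""
    # An empty PartitionSpec means: replicate on all mesh axes.
    return NamedSharding(_default_mesh(), PartitionSpec())


def sharding_first_dim() -> NamedSharding:
    """Shard axis 0 of the array across all devices."""
    return NamedSharding(_default_mesh(), PartitionSpec('devices'))


x = jax.device_put(np.ones((8, 4)), sharding_first_dim())
y = jax.device_put(np.ones((8, 4)), sharding_replicated())
```

The point is that the user-facing call site reads as intent ("replicated", "sharded on the first dim") rather than as mesh plumbing.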

I think this issue can be decomposed into two sub-parts:

[ ] Have jax.device_put support multi-host sharding
[ ] Provide standard shardings, auto-computed for the common use cases: replicated and sharded, without having to create the mesh, etc.

yashk2810 commented 9 months ago

The usage in the library you pointed to does create a mesh. The problem is that meshes usually come from mesh_utils, not the way you have created it. It might work for your use case, but providing such an API in JAX would not work for the majority of use cases.

Have jax.device_put support multi-host sharding

Yes, that is on the table (I have a small prototype), but it's more complicated than just device_put.

Conchylicultor commented 9 months ago

It might work for your use case, but providing such an API in JAX would not work for the majority of use cases.

Of course, it would not work for complex use cases or LLMs. The goal of this issue is not to fix all use cases, just to make the two most standard ones work out of the box: replicated and sharded along the first dimension.

I believe those two use cases already cover many needs. Flax, for instance, provides flax.jax_utils.unreplicate for pmap, which was widely used: https://source.corp.google.com/search?q=flax.*replicate So that should be a good signal that there's a need for this kind of high-level convenience.