jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

shard_map *much* faster than pjit for simple data parallelism #19657

Open kvablack opened 8 months ago

kvablack commented 8 months ago

Description

I'm trying to scale up some transformer training (currently ~400M params), so I've been experimenting with various ways to save memory and improve performance. On a whim, I replaced the jax.jit(in_shardings=..., out_shardings=...) setup I use for data parallelism with jax.experimental.shard_map, like so:

    from functools import partial

    import jax
    from jax.experimental import mesh_utils, shard_map

    mesh = jax.sharding.Mesh(mesh_utils.create_mesh([jax.device_count()]), ["dp"])
    dp_spec = jax.sharding.PartitionSpec("dp")
    rep_spec = jax.sharding.PartitionSpec()
    dp_sharding = jax.sharding.NamedSharding(mesh, dp_spec)
    rep_sharding = jax.sharding.NamedSharding(mesh, rep_spec)

    @partial(
        jax.jit,
        in_shardings=(rep_sharding, dp_sharding),
        out_shardings=(rep_sharding, rep_sharding),
        donate_argnums=0,
    )
    #### ADDED THE FOLLOWING 7 LINES ####
    @partial(
        shard_map.shard_map,
        mesh=mesh,
        in_specs=(rep_spec, dp_spec),
        out_specs=(rep_spec, rep_spec),
        check_rep=False,
    )
    ####
    def train_step(state: TrainState, batch: Data):
        rng, dropout_rng = jax.random.split(state.rng)
        (loss, info), grads = jax.value_and_grad(loss_fn, has_aux=True)(
            state.model.params, batch, dropout_rng, train=True
        )
        #### AND THE FOLLOWING LINE ####
        loss, info, grads = jax.lax.pmean((loss, info, grads), axis_name="dp")
        ####
        new_state = state.apply_gradients(grads=grads, rng=rng)
        return new_state, info

and I immediately saw a 2.8x (!) speedup. This is a problem because I would like to move on to more advanced parallelism techniques (tensor parallelism, fully-sharded data parallelism, etc.), and writing those manually with shard_map seems prohibitively difficult. But if I keep using pjit's automatic partitioning, I worry that I'm leaving a lot of performance on the table. In a case this simple, I would expect the automatic partitioner to produce code with roughly the same performance.
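Because TrainState, Data and loss_fn are not shown in the issue, here is a minimal self-contained sketch of the two variants being compared, under stated assumptions: a toy linear-regression model stands in for the transformer, plain SGD with a hypothetical learning rate LR stands in for apply_gradients, and the names loss_fn, step_jit and step_shard_map are hypothetical. The replication check is disabled as in the issue's check_rep=False; newer JAX releases expose shard_map at the top level, where I believe the flag is called check_vma, so the import falls back accordingly. It checks that both variants compute the same update and runs on however many devices are visible, including a single CPU.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

try:  # newer JAX: top-level shard_map (flag assumed to be check_vma)
    from jax import shard_map as _shard_map
    shard_map_dp = partial(_shard_map, check_vma=False)
except ImportError:  # older JAX (e.g. 0.4.x): experimental module, check_rep
    from jax.experimental.shard_map import shard_map as _shard_map
    shard_map_dp = partial(_shard_map, check_rep=False)

LR = 0.1  # hypothetical SGD learning rate

mesh = Mesh(np.array(jax.devices()), ("dp",))
dp_sharding = NamedSharding(mesh, P("dp"))
rep_sharding = NamedSharding(mesh, P())


def loss_fn(w, x, y):
    # Hypothetical loss: mean squared error over whatever batch it is given.
    return jnp.mean((x @ w - y) ** 2)


@partial(jax.jit, in_shardings=(rep_sharding, dp_sharding, dp_sharding),
         out_shardings=(rep_sharding, rep_sharding))
def step_jit(w, x, y):
    # Global view: the partitioner decides which collectives to insert.
    loss, grads = jax.value_and_grad(loss_fn)(w, x, y)
    return w - LR * grads, loss


@partial(jax.jit, in_shardings=(rep_sharding, dp_sharding, dp_sharding),
         out_shardings=(rep_sharding, rep_sharding))
@partial(shard_map_dp, mesh=mesh, in_specs=(P(), P("dp"), P("dp")),
         out_specs=(P(), P()))
def step_shard_map(w, x, y):
    # Per-device view: x and y are this device's shard of the batch.
    loss, grads = jax.value_and_grad(loss_fn)(w, x, y)
    # Explicit cross-device average; with equal shard sizes this equals
    # the global-batch mean computed by step_jit.
    loss, grads = jax.lax.pmean((loss, grads), axis_name="dp")
    return w - LR * grads, loss


n_dev = jax.device_count()
kx, ky, kw = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (8 * n_dev, 4))  # batch divisible by device count
y = jax.random.normal(ky, (8 * n_dev,))
w = jax.random.normal(kw, (4,))

w1, loss1 = step_jit(w, x, y)
w2, loss2 = step_shard_map(w, x, y)
print(jnp.allclose(w1, w2, atol=1e-5), jnp.allclose(loss1, loss2, atol=1e-5))
```

Both functions should produce the same numbers; the question in this issue is only how fast the code generated for each one runs.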

Here are the debugging steps I've tried so far:

I've attached the HLO below. I would really appreciate any guidance on this, thanks!

Attachments: no-shard-map.txt, shard-map.txt
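The issue does not say how the attached HLO files were produced. One way to get comparable dumps is JAX's ahead-of-time API: jax.jit(f).lower(...) traces and lowers the function, and .compile().as_text() returns the HLO after XLA's optimizations and SPMD partitioning, which is where any collectives inserted by the partitioner appear. A minimal sketch; the function f and its shapes are hypothetical stand-ins for the real training step:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ("dp",))
dp_sharding = NamedSharding(mesh, P("dp"))
rep_sharding = NamedSharding(mesh, P())


def f(w, x):
    # Hypothetical stand-in for a training step: a global-batch mean.
    return jnp.mean(x @ w)


jitted = jax.jit(f, in_shardings=(rep_sharding, dp_sharding),
                 out_shardings=rep_sharding)
w = jnp.ones((4,))
x = jnp.ones((8 * jax.device_count(), 4))

lowered = jitted.lower(w, x)        # traced and lowered, not yet optimized
compiled = lowered.compile()        # runs XLA, including the SPMD partitioner
optimized_hlo = compiled.as_text()  # post-optimization HLO as a string

# Lines mentioning collectives the partitioner inserted (none on 1 device).
collectives = [line for line in optimized_hlo.splitlines()
               if "all-reduce" in line or "all-gather" in line]
print(len(collectives))
```

Diffing the collective lines of the two variants is often a quicker first check than reading the full dumps.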

What jax/jaxlib version are you using?

0.4.23

Which accelerator(s) are you using?

TPUv4

Additional system info?

Python 3.10.12, tpu-vm-v4-base


jakevdp commented 8 months ago

Hi - I'm having trouble understanding your question. It sounds like you're comparing two implementations, but you've only shown us one implementation. Could you edit your question to show the code for both approaches?

kvablack commented 8 months ago

@jakevdp sorry, I've edited my question to hopefully make things clearer. The only difference between the two implementations is the addition of the 8 lines indicated (the shard_map itself and the corresponding pmean).

jakevdp commented 8 months ago

Thanks! Assigning to @yashk2810, who might have some insights here.

yashk2810 commented 8 months ago

Don't you need a jnp.mean for the jit version (without shard_map)?

kvablack commented 8 months ago

@yashk2810 The jnp.mean happens inside the loss function (a scalar is returned).
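Why the two versions should agree numerically: when every device holds the same number of examples, averaging the per-device means (what jax.lax.pmean does over the "dp" axis) gives the same value as the mean over the whole batch, which is what the loss function computes in the jit-only version. A quick check with hypothetical data, using jax.vmap with an axis_name so that the collective can run on a single host:

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical per-example losses for one global batch of 32 examples.
per_example_loss = jnp.asarray(
    np.random.default_rng(0).normal(size=32), dtype=jnp.float32)

# jit-only version: the loss function averages over the whole batch.
global_mean = jnp.mean(per_example_loss)

# shard_map-style version: 4 "devices" each average their own 8 examples,
# then pmean averages those per-device means across the "dp" axis.
shards = per_example_loss.reshape(4, 8)
pmeaned = jax.vmap(
    lambda s: jax.lax.pmean(jnp.mean(s), axis_name="dp"), axis_name="dp"
)(shards)

print(global_mean, pmeaned)  # every entry of pmeaned matches global_mean
```

The same argument applies to the gradients, since the gradient of a mean is the mean of the gradients.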

yashk2810 commented 8 months ago

I don't see that loss function :)

Can you create a minimal reproducer that we can run?

kvablack commented 8 months ago

Sure thing, here's my repro. With the shard_map version, I get 1.09 s/it, and with no shard_map, I get 2.95 s/it. This is on a v4-8 TPU VM.
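The repro itself is not included here, but a note on measuring s/it with JAX may help anyone rerunning it. Dispatch is asynchronous and the first call includes tracing and compilation, so a fair timing loop should warm up first and block on the result before stopping the clock. A minimal sketch; time_step and toy_step are hypothetical helpers, not part of the repro:

```python
import time

import jax
import jax.numpy as jnp


def time_step(step_fn, state, batch, n_warmup=2, n_iters=10):
    """Return mean seconds per iteration of state = step_fn(state, batch)."""
    for _ in range(n_warmup):       # first call includes trace + compile
        state = step_fn(state, batch)
    jax.block_until_ready(state)
    start = time.perf_counter()
    for _ in range(n_iters):
        state = step_fn(state, batch)
    jax.block_until_ready(state)    # wait for the queued work to finish
    return (time.perf_counter() - start) / n_iters


@jax.jit
def toy_step(w, x):
    # Hypothetical stand-in for train_step: one SGD step on a quadratic loss.
    grads = jax.grad(lambda w: jnp.mean((x @ w) ** 2))(w)
    return w - 0.01 * grads


sec_per_it = time_step(toy_step, jnp.ones((256,)), jnp.ones((1024, 256)))
print(f"{sec_per_it:.4f} s/it ({1 / sec_per_it:.1f} it/s)")
```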

yashk2810 commented 3 days ago

Hey, sorry for the late reply. Can you try with the latest jax and jaxlib versions? (Or better yet, try the nightly.)

Also, can you tell me which TPU you were using? TPUv4, but how many devices?
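For reference, the version and device information asked for here can be collected from Python with standard JAX calls. A small sketch; the helper name describe_setup is hypothetical and the values depend on the machine:

```python
import jax


def describe_setup():
    """Collect the details usually requested in a JAX performance report."""
    devices = jax.devices()
    return {
        "jax": jax.__version__,
        "backend": jax.default_backend(),         # e.g. "tpu", "gpu", "cpu"
        "device_count": jax.device_count(),       # devices across all hosts
        "local_device_count": jax.local_device_count(),
        "device_kind": devices[0].device_kind,    # hardware model string
    }


info = describe_setup()
for key, value in info.items():
    print(f"{key}: {value}")
```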

kvablack commented 3 days ago

This was a v4-8 VM (the smallest you can get, I think). I no longer have easy access to TPUs, but I replicated the issue with jax[cuda12]==0.4.33 on an 8xH100 DGX machine. With no shard_map, I get 1.10 s/it, and with shard_map, I get 1.70 it/s.
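The two H100 figures are quoted in different units (s/it versus it/s). Taking them at face value and converting both to seconds per iteration gives a shard_map speedup of roughly 1.9x on that machine, compared with about 2.7x from the earlier v4-8 repro numbers. The arithmetic, as a small sketch with hypothetical helper names:

```python
def to_sec_per_it(value, unit):
    """Convert a throughput figure to seconds per iteration."""
    if unit == "s/it":
        return value
    if unit == "it/s":
        return 1.0 / value
    raise ValueError(f"unknown unit: {unit}")


def speedup(baseline, candidate):
    """How many times faster `candidate` is than `baseline` (both in s/it)."""
    return baseline / candidate


# v4-8 numbers from the repro comment: no shard_map vs. shard_map.
tpu = speedup(to_sec_per_it(2.95, "s/it"), to_sec_per_it(1.09, "s/it"))
# 8xH100 numbers from the last comment, taken at face value.
gpu = speedup(to_sec_per_it(1.10, "s/it"), to_sec_per_it(1.70, "it/s"))
print(f"TPU v4-8: {tpu:.2f}x, 8xH100: {gpu:.2f}x")  # TPU v4-8: 2.71x, 8xH100: 1.87x
```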