jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

How do you remat GSPMD inserted all-gathers? #25010

Open ptoulme-aws opened 2 days ago

ptoulme-aws commented 2 days ago

Problem: I have some JAX code that does sequence parallelism, somewhat similar to this:

activation = jax.lax.with_sharding_constraint(activation, NamedSharding(mesh, PartitionSpec('data', 'tensor', None)))
activation = norm(activation)
activation = jax.lax.with_sharding_constraint(activation, NamedSharding(mesh, PartitionSpec(None, 'tensor', None)))
# I want to remat this one ^
activation = attention(activation)
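
For context, a self-contained version of this pattern would look roughly like the sketch below; the mesh shape, array sizes, and the toy norm/attention bodies are placeholders for illustration, not my actual model.

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Assumes 8 devices arranged as a (data=2, tensor=4) mesh; adjust to your topology.
devices = mesh_utils.create_device_mesh((2, 4))
mesh = Mesh(devices, axis_names=('data', 'tensor'))

def norm(x):
    # toy RMS-style norm, stand-in for the real layer
    return x / (jnp.sqrt(jnp.mean(x * x, axis=-1, keepdims=True)) + 1e-6)

def attention(x):
    # toy full-sequence contraction, stand-in for real attention
    scores = jnp.einsum('sbh,tbh->bst', x, x)
    return jnp.einsum('bst,tbh->sbh', jax.nn.softmax(scores, axis=-1), x)

@jax.jit
def forward(activation):
    # sequence dim (0) sharded over 'data', dim 1 over 'tensor'
    activation = jax.lax.with_sharding_constraint(
        activation, NamedSharding(mesh, PartitionSpec('data', 'tensor', None)))
    activation = norm(activation)
    # un-shard the sequence dim: GSPMD inserts the all-gather here
    activation = jax.lax.with_sharding_constraint(
        activation, NamedSharding(mesh, PartitionSpec(None, 'tensor', None)))
    return attention(activation)

out = forward(jnp.zeros((128, 8, 512)))  # (seq, batch, hidden); shapes are placeholders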

I have tried everything I can think of to remat the activation directly before attention, including JAX remat policies and explicitly applying jax.checkpoint around that exact computation, but nothing seems to make it remat. The activation directly before attention is produced by a GSPMD-inserted all-gather on the sequence dimension (dim=0).

I ended up writing an XLA pass to rematerialize large all-gathers and submitted it as a PR: https://github.com/openxla/xla/pull/19163

Question: Is this possible to do from the JAX end, or is my pass really needed?

mattjj commented 2 days ago

Thanks for the question.

No, I don't think a new pass is needed.

As I understand it, the standard way to spell this is to use a remat policy to mark the with_sharding_constraint that induces the all-gather as not-saveable. One way to do that would be to use save_only_these_names and to name only other arrays (either upstream of the all-gather-inducing with_sharding_constraint, or downstream of the operations that use the output of attention). Following your snippet, that might look something like:

activation = jax.lax.with_sharding_constraint(activation, NamedSharding(mesh, PartitionSpec('data', 'tensor', None)))
activation = checkpoint_name(norm(activation), 'scattered_activations')
activation = jax.lax.with_sharding_constraint(activation, NamedSharding(mesh, PartitionSpec(None, 'tensor', None)))
activation = attention(activation)

together with a save_only_these_names policy that mentions 'scattered_activations' or something upstream of it.
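
Concretely, the wiring might look something like the sketch below (reusing the mesh, norm, and attention from your snippet above; the wrapper function and shapes are just placeholders):

from functools import partial

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name
from jax.sharding import NamedSharding, PartitionSpec

# Anything not named 'scattered_activations' is not saveable under this policy,
# so the all-gathered activation should be rematerialized in the backward pass.
policy = jax.checkpoint_policies.save_only_these_names('scattered_activations')

@partial(jax.checkpoint, policy=policy)
def block(activation):
    # mesh, norm, and attention assumed defined as in the snippet above
    activation = jax.lax.with_sharding_constraint(
        activation, NamedSharding(mesh, PartitionSpec('data', 'tensor', None)))
    activation = checkpoint_name(norm(activation), 'scattered_activations')
    activation = jax.lax.with_sharding_constraint(
        activation, NamedSharding(mesh, PartitionSpec(None, 'tensor', None)))
    return attention(activation)

# differentiate through it as usual
grads = jax.grad(lambda x: block(x).sum())(jnp.zeros((128, 8, 512)))

The intent is that only the named residual can be saved, so the output of the all-gather-inducing with_sharding_constraint should be recomputed on the backward pass rather than stored.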

Did you try something like that? If you already tried it, we should put together a minimal example to debug what's going on.