Open ptoulme-aws opened 2 days ago
Thanks for the question.
No, I don't think a new pass is needed.
As I understand it, the standard way to spell this is to use a remat policy to mark the with_sharding_constraint which induces the allgather as not-saveable. One way to do that would be to use save_only_these_names and to only name other arrays (that are either upstream of the allgather-inducing with_sharding_constraint, or downstream of the operations that use the output of attention). Following your snippet, that might look something like:
activation = jax.lax.with_sharding_constraint(activation, NamedSharding(mesh, PartitionSpec('data', 'tensor', None)))
activation = checkpoint_name(norm(activation), 'scattered_activations')
activation = jax.lax.with_sharding_constraint(activation, NamedSharding(mesh, PartitionSpec(None, 'tensor', None)))
activation = attention(activation)
together with a save_only_these_names policy that mentions 'scattered_activations' or something upstream of it.
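To make that concrete, here is a minimal sketch of how the snippet and the policy might be wired together. It is not taken from your code: norm and attention are hypothetical stand-ins, the 1x1 mesh and the (seq, batch, hidden) shape are assumptions chosen so it runs on a single device, and it only shows the wiring, not that the all-gather actually ends up rematerialized in the compiled program.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax.ad_checkpoint import checkpoint_name
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Assumption: a 1x1 mesh so the sketch runs on one device. A real setup
# would lay out all devices over the ('data', 'tensor') axes.
mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1), ('data', 'tensor'))


def norm(x):
    # Hypothetical stand-in for the real normalization layer.
    return x / jnp.sqrt(jnp.mean(x * x, axis=-1, keepdims=True) + 1e-6)


def attention(x):
    # Hypothetical stand-in for attention; x is (seq, batch, hidden).
    scores = jax.nn.softmax(jnp.einsum('sbh,tbh->bst', x, x), axis=-1)
    return jnp.einsum('bst,tbh->sbh', scores, x)


# Only arrays tagged with this name may be saved for the backward pass.
# The output of the second with_sharding_constraint is not named, so the
# policy marks it as not-saveable.
policy = jax.checkpoint_policies.save_only_these_names('scattered_activations')


@functools.partial(jax.checkpoint, policy=policy)
def block(activation):
    activation = jax.lax.with_sharding_constraint(
        activation, NamedSharding(mesh, PartitionSpec('data', 'tensor', None)))
    activation = checkpoint_name(norm(activation), 'scattered_activations')
    activation = jax.lax.with_sharding_constraint(
        activation, NamedSharding(mesh, PartitionSpec(None, 'tensor', None)))
    return attention(activation)


def loss(x):
    return jnp.sum(block(x))


x = jnp.ones((8, 2, 4))
grads = jax.jit(jax.grad(loss))(x)
```

Inspecting the compiled HLO of the jitted gradient function on the real mesh would be one way to check whether the all-gather is recomputed in the backward pass.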
Did you try something like that? If you already tried it, we should put together a minimal example to debug what's going on.
Problem: I have some JAX code that does sequence parallelism, so somewhat similar to this
I have tried everything I can to remat the activation directly before attention, including JAX remat policies and explicitly using jax.checkpoint on that exact tensor, but nothing seems to make it remat. The activation directly before attention is a GSPMD-inserted all-gather on the sequence dimension (dim=0).
I ended up writing an XLA pass to rematerialize large all-gathers and submitted a PR. https://github.com/openxla/xla/pull/19163
Question: Is this possible to do from the JAX end, or is my pass really needed?