Description
Hi, I have the following setup:
I'm using the following flags:
To speed up the backward pass through a fine-grained reduction of activation recomputation, I marked the output of each dense layer in the transformer block with a specific name:

```python
result = jax.lax.dot_general(
    inputs,
    kernel,
    dimension_numbers=((axis, contract_ind), ((), ())),
    precision=self.precision,
    preferred_element_type=self.accumulator_dtype,
)
# Tag the matmul output so that a checkpoint policy can refer to it by name.
result = jax.ad_checkpoint.checkpoint_name(result, self.activation_dot_name)
```
So, for example, in the attention layer I have "dot_attention_query", "dot_attention_key", "dot_attention_value", and "dot_attention_out".
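For illustration, the naming scheme can be sketched in a self-contained form. This is a minimal sketch and not the actual module code: `named_dense`, `attention_projections`, the parameter dictionary, and the shapes are hypothetical, and the contraction is assumed to be over the last input axis.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name


def named_dense(inputs, kernel, name):
    """Hypothetical helper: dense projection whose output is tagged with `name`."""
    result = jax.lax.dot_general(
        inputs,
        kernel,
        # Contract the last axis of `inputs` with the first axis of `kernel`.
        dimension_numbers=(((inputs.ndim - 1,), (0,)), ((), ())),
        preferred_element_type=jnp.float32,
    )
    # The tag is an identity op; it only matters to name-based checkpoint policies.
    return checkpoint_name(result, name)


def attention_projections(x, params):
    """Hypothetical attention input projections, one name per dense layer."""
    q = named_dense(x, params["query"], "dot_attention_query")
    k = named_dense(x, params["key"], "dot_attention_key")
    v = named_dense(x, params["value"], "dot_attention_value")
    return q, k, v


batch, seq, d_model = 2, 4, 8
x = jnp.ones((batch, seq, d_model))
params = {n: jnp.full((d_model, d_model), 0.1) for n in ("query", "key", "value")}
q, k, v = attention_projections(x, params)
```

Outside of a `jax.checkpoint` region the names have no effect, so the tagged function computes the same values as the untagged one.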
I then apply a checkpoint policy, which accepts a list of activation names to save, to the scanned function:
and then scan it over the embeddings:
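The two steps above (a name-based checkpoint policy on the layer function, then a scan over the layers) can be sketched as follows. This is a minimal sketch under stated assumptions, not the code from the actual model: `block`, `make_scanned_forward`, the name "dot_mlp", and the toy shapes are hypothetical, and `save_names_for_bwd` stands in for `self.config.save_names_for_bwd`.

```python
import functools

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name


def block(x, w):
    """Hypothetical one-matmul 'transformer block' with a named activation."""
    y = checkpoint_name(jnp.dot(x, w), "dot_mlp")
    return jnp.tanh(y)


def make_scanned_forward(save_names_for_bwd):
    # Only activations tagged with one of these names are saved for the
    # backward pass; everything else inside the layer is recomputed.
    policy = jax.checkpoint_policies.save_only_these_names(*save_names_for_bwd)

    @functools.partial(jax.checkpoint, policy=policy)
    def layer(carry, w):
        return block(carry, w), None

    def forward(x, stacked_w):
        # Scan the rematerialized layer over the stacked per-layer weights.
        out, _ = jax.lax.scan(layer, x, stacked_w)
        return jnp.sum(out)

    return forward


x = jnp.ones((4, 8))
stacked_w = jnp.stack([jnp.eye(8) * 0.5] * 3)  # 3 layers

loss_fn = make_scanned_forward(["dot_mlp"])
grads = jax.grad(loss_fn, argnums=1)(x, stacked_w)
```

Passing an empty list to `make_scanned_forward` gives a policy that saves no named activations; the gradients are the same in both cases, and only the save-versus-recompute trade-off changes.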
If I set self.config.save_names_for_bwd to an empty list (which is basically equivalent to the "nothing_saveable" policy), communication works correctly: all-gathers, reduce-scatters, and all-reduces are overlapped with computation, as can be seen in this Perfetto trace:
nothing_saveable.tgz
But as soon as I start to specify some names in self.config.save_names_for_bwd, for example,
While these activations are indeed not recomputed during the backward pass, all communication is executed on the main stream without any overlap with computation:
save_only_these_names_trace.tgz
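As a side note, one way to check which activations a name-based policy actually keeps is `jax.ad_checkpoint.print_saved_residuals`. The following is a minimal sketch on a toy, non-scanned layer; `make_layer` and the name "dot_mlp" are hypothetical, and it says nothing about communication overlap, only about what is saved versus recomputed.

```python
import functools

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name, print_saved_residuals


def make_layer(save_names):
    policy = jax.checkpoint_policies.save_only_these_names(*save_names)

    @functools.partial(jax.checkpoint, policy=policy)
    def layer(x, w):
        # Tag the matmul output so the policy can decide whether to save it.
        y = checkpoint_name(jnp.dot(x, w), "dot_mlp")
        return jnp.sum(jnp.sin(y))

    return layer


x = jnp.ones((4, 8))
w = jnp.ones((8, 8))

# Prints the residuals kept for the backward pass under each policy.
print_saved_residuals(make_layer([]), x, w)
print_saved_residuals(make_layer(["dot_mlp"]), x, w)
```

Comparing the two printed lists shows whether the named matmul output is stored as a residual or left to be recomputed.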
System info (python version, jaxlib version, accelerator, etc.)