google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Fine-grained remat policy makes async/pipelined collectives execute in the main stream #22252

Open qGentry opened 3 days ago

qGentry commented 3 days ago

Description

Hi, I have the following setup:

I'm using the following flags:

--xla_gpu_graph_level=0 
--xla_gpu_enable_latency_hiding_scheduler=true 
--xla_gpu_enable_all_gather_combine_by_dim=false 
--xla_gpu_enable_reduce_scatter_combine_by_dim=false 
--xla_gpu_enable_pipelined_all_gather=true 
--xla_gpu_enable_pipelined_reduce_scatter=true 
--xla_gpu_enable_pipelined_all_reduce=true 
--xla_gpu_enable_pipelined_collectives=false 
--xla_gpu_enable_while_loop_double_buffering=true 
--xla_gpu_enable_highest_priority_async_stream=true 
--xla_gpu_all_reduce_combine_threshold_bytes=2147483648 
--xla_gpu_all_gather_combine_threshold_bytes=2147483648  
--xla_gpu_reduce_scatter_combine_threshold_bytes=2147483648
--xla_gpu_disable_async_collectives=collectivebroadcast,alltoall,collectivepermute
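For anyone trying to reproduce this, here is a minimal sketch of one way to pass these flags, assuming they are supplied through the `XLA_FLAGS` environment variable before JAX is imported. The list name `XLA_FLAG_LIST` is only illustrative and is not part of the original setup.

```python
import os

# Flags copied from the list above. XLA_FLAG_LIST is a hypothetical name
# used only for this sketch.
XLA_FLAG_LIST = [
    "--xla_gpu_graph_level=0",
    "--xla_gpu_enable_latency_hiding_scheduler=true",
    "--xla_gpu_enable_all_gather_combine_by_dim=false",
    "--xla_gpu_enable_reduce_scatter_combine_by_dim=false",
    "--xla_gpu_enable_pipelined_all_gather=true",
    "--xla_gpu_enable_pipelined_reduce_scatter=true",
    "--xla_gpu_enable_pipelined_all_reduce=true",
    "--xla_gpu_enable_pipelined_collectives=false",
    "--xla_gpu_enable_while_loop_double_buffering=true",
    "--xla_gpu_enable_highest_priority_async_stream=true",
    "--xla_gpu_all_reduce_combine_threshold_bytes=2147483648",
    "--xla_gpu_all_gather_combine_threshold_bytes=2147483648",
    "--xla_gpu_reduce_scatter_combine_threshold_bytes=2147483648",
    "--xla_gpu_disable_async_collectives=collectivebroadcast,alltoall,collectivepermute",
]

# XLA reads this variable when the backend is initialized, so it should be
# set before the first `import jax` in the process.
os.environ["XLA_FLAGS"] = " ".join(XLA_FLAG_LIST)
```

The same string can equally be exported in the shell before launching the training script.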

To speed up the backward pass by reducing activation recomputation in a fine-grained way, I tagged the output of each dense layer in the transformer block with a specific name:

result = jax.lax.dot_general(
    inputs,
    kernel,
    dimension_numbers=((axis, contract_ind), ((), ())),
    precision=self.precision,
    preferred_element_type=self.accumulator_dtype,
)
result = jax.ad_checkpoint.checkpoint_name(result, self.activation_dot_name)

So, for example, in the attention layer I have "dot_attention_query", "dot_attention_key", "dot_attention_value", "dot_attention_out".

Then I apply a checkpoint policy, which accepts a list of activation names to save, to the scanned function:

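def rematted_layer(layer):
    return nn.remat(
        layer,
        # Save only the named activations; everything else is recomputed
        # during the backward pass.
        policy=jax.checkpoint_policies.save_only_these_names(
            *self.config.save_names_for_bwd
        ),
        # CSE prevention is not needed when the layer runs under scan.
        prevent_cse=not self.config.scan,
    )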
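The naming and policy steps above can be sketched in a self-contained form with plain `jax.checkpoint` instead of Flax's `nn.remat`. This is a toy two-layer MLP, not the model from this issue: `dense`, `block`, the parameter shapes, and the name "dot_mlp_in" are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name


def dense(x, w, name):
    # Contract the last axis of x with the first axis of w, then tag the
    # result so that a checkpoint policy can refer to it by name.
    y = jax.lax.dot_general(
        x, w, dimension_numbers=(((x.ndim - 1,), (0,)), ((), ()))
    )
    return checkpoint_name(y, name)


def block(x, params):
    # Hypothetical toy block: two dense layers with a nonlinearity between.
    h = jnp.tanh(dense(x, params["w_in"], "dot_mlp_in"))
    return dense(h, params["w_out"], "dot_mlp_out")


# Only "dot_mlp_out" is saved for the backward pass; "dot_mlp_in" and the
# tanh output are recomputed.
policy = jax.checkpoint_policies.save_only_these_names("dot_mlp_out")
rematted_block = jax.checkpoint(block, policy=policy)


def loss(params, x):
    return jnp.sum(rematted_block(x, params) ** 2)


params = {"w_in": jnp.full((4, 8), 0.1), "w_out": jnp.full((8, 4), 0.1)}
x = jnp.ones((2, 4))
grads = jax.grad(loss)(params, x)
```

Rematerialization only changes what is stored between the forward and backward passes, so the gradients should match those of the unwrapped `block`.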

and then scan it over the embeddings:

apply_block = rematted_layer(apply_block)
apply_block = nn.scan(
    apply_block,
    length=self.config.num_layers,
    variable_axes={
        "params": 0,
    },
    variable_broadcast=False,
    split_rngs={"params": True},
    metadata_params={nn.PARTITION_NAME: "layers"},
)
block = TransformerBlock(
    name="scan",
    config=self.config.block,
)
embeddings, _ = apply_block(block, embeddings, None)
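For readers without the Flax model at hand, here is a rough equivalent of the scan step written with `jax.lax.scan` over stacked per-layer weights. It is a sketch under assumptions: the single-matmul `block`, `NUM_LAYERS`, and `D` are hypothetical stand-ins for `TransformerBlock` and its config.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

NUM_LAYERS = 3  # hypothetical stand-in for config.num_layers
D = 4           # hypothetical embedding width


def block(x, w):
    # Toy layer: one named matmul followed by a nonlinearity.
    y = checkpoint_name(x @ w, "dot_mlp_out")
    return jnp.tanh(y)


policy = jax.checkpoint_policies.save_only_these_names("dot_mlp_out")
# prevent_cse=False mirrors `prevent_cse=not self.config.scan` above.
rematted_block = jax.checkpoint(block, policy=policy, prevent_cse=False)


def scan_body(carry, w):
    # carry holds the embeddings; w is one layer's slice of the stacked
    # weights.
    return rematted_block(carry, w), None


def forward(stacked_w, x):
    out, _ = jax.lax.scan(scan_body, x, stacked_w)
    return out


def loss(stacked_w, x):
    return jnp.sum(forward(stacked_w, x))


stacked_w = jnp.stack([jnp.eye(D) * 0.5] * NUM_LAYERS)  # shape (3, 4, 4)
x = jnp.ones((2, D))
grads = jax.grad(loss)(stacked_w, x)
```

The leading axis of `stacked_w` plays the role of `variable_axes={"params": 0}` in the `nn.scan` call above.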

If I set self.config.save_names_for_bwd to an empty list (which is basically equivalent to the "nothing_saveable" policy), communication works correctly: all-gathers, reduce-scatters, and all-reduces are overlapped with computation, as can be seen in this Perfetto trace:

Screenshot 2024-07-03 at 14 33 27

nothing_saveable.tgz

But as soon as I specify some names in self.config.save_names_for_bwd, for example,

    save_names_for_bwd:
      - dot_mlp_out
      - dot_attention_value
      - dot_attention_query
      - dot_attention_key

these activations are indeed no longer recomputed during the backward pass, but all communications are executed in the main stream without any overlap with computation:

Screenshot 2024-07-03 at 14 35 06

save_only_these_names_trace.tgz

System info (python version, jaxlib version, accelerator, etc.)

>>> import jax; jax.print_environment_info()
jax:    0.4.30
jaxlib: 0.4.30
numpy:  1.24.3
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (8 total, 8 local): [cuda(id=0) cuda(id=1) ... cuda(id=6) cuda(id=7)]
process_count: 1
platform: uname_result(system='Linux', node='ffisin-dev-8gpu', release='5.4.0-155-generic', version='#172-Ubuntu SMP Fri Jul 7 16:10:02 UTC 2023', machine='x86_64')

$ nvidia-smi
Mon Jul  1 13:21:57 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.12             Driver Version: 535.104.12   CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA H100 80GB HBM3          On  | 00000000:8D:00.0 Off |                    0 |
| N/A   68C    P0             141W / 700W |   1074MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA H100 80GB HBM3          On  | 00000000:91:00.0 Off |                    0 |
| N/A   48C    P0             121W / 700W |   1074MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA H100 80GB HBM3          On  | 00000000:95:00.0 Off |                    0 |
| N/A   69C    P0             137W / 700W |   1074MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA H100 80GB HBM3          On  | 00000000:99:00.0 Off |                    0 |
| N/A   50C    P0             126W / 700W |   1074MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   4  NVIDIA H100 80GB HBM3          On  | 00000000:AB:00.0 Off |                    0 |
| N/A   68C    P0             142W / 700W |   1074MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   5  NVIDIA H100 80GB HBM3          On  | 00000000:AF:00.0 Off |                    0 |
| N/A   49C    P0             124W / 700W |   1074MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   6  NVIDIA H100 80GB HBM3          On  | 00000000:B3:00.0 Off |                    0 |
| N/A   68C    P0             143W / 700W |   1074MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   7  NVIDIA H100 80GB HBM3          On  | 00000000:B7:00.0 Off |                    0 |
| N/A   48C    P0             121W / 700W |   1074MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
+---------------------------------------------------------------------------------------+

qGentry commented 3 days ago

Corresponding XLA issue: https://github.com/openxla/xla/issues/14397