google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Check failed in collective_pipeliner when using gradient accumulation with non-unrolled loop #22210

Closed qGentry closed 1 month ago

qGentry commented 2 months ago

Description

Hi, I have the following setup:

I'm using the following flags:

--xla_gpu_graph_level=0 
--xla_gpu_enable_triton_gemm=false 
--xla_gpu_enable_command_buffer= 
--xla_gpu_enable_latency_hiding_scheduler=true 
--xla_gpu_enable_all_gather_combine_by_dim=false 
--xla_gpu_enable_reduce_scatter_combine_by_dim=false 
--xla_gpu_enable_pipelined_all_gather=true 
--xla_gpu_enable_pipelined_reduce_scatter=true 
--xla_gpu_enable_pipelined_all_reduce=true 
--xla_gpu_enable_pipelined_collectives=false 
--xla_gpu_enable_while_loop_double_buffering=true 
--xla_gpu_enable_highest_priority_async_stream=true 
--xla_gpu_disable_async_collectives=collectivebroadcast,alltoall,collectivepermute
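
For anyone trying to reproduce this, here is a minimal sketch of one way the flags above could be passed, assuming they are supplied through the XLA_FLAGS environment variable before JAX initializes its backend. The list name `xla_flags` is only illustrative; the flag values are copied verbatim from the list above.

```python
import os

# Flags copied from the list above. The empty value for
# --xla_gpu_enable_command_buffer is kept exactly as reported.
xla_flags = [
    "--xla_gpu_graph_level=0",
    "--xla_gpu_enable_triton_gemm=false",
    "--xla_gpu_enable_command_buffer=",
    "--xla_gpu_enable_latency_hiding_scheduler=true",
    "--xla_gpu_enable_all_gather_combine_by_dim=false",
    "--xla_gpu_enable_reduce_scatter_combine_by_dim=false",
    "--xla_gpu_enable_pipelined_all_gather=true",
    "--xla_gpu_enable_pipelined_reduce_scatter=true",
    "--xla_gpu_enable_pipelined_all_reduce=true",
    "--xla_gpu_enable_pipelined_collectives=false",
    "--xla_gpu_enable_while_loop_double_buffering=true",
    "--xla_gpu_enable_highest_priority_async_stream=true",
    "--xla_gpu_disable_async_collectives=collectivebroadcast,alltoall,collectivepermute",
]

# XLA reads this variable when the backend starts, so it has to be set
# before the first JAX computation (most simply, before `import jax`).
os.environ["XLA_FLAGS"] = " ".join(xla_flags)
```

Which flags are accepted depends on the XLA version bundled with jaxlib, so this list may need adjusting on other releases.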

This works correctly and does hide the all-gather of layer weights and the reduce-scatter of gradients behind computation.

Problems arise when I try to use gradient accumulation in this setup. It is implemented like this:

    grads_sum = jax.tree_map(jnp.zeros_like, train_state.params)
    train_state, grads_sum = jax.lax.fori_loop(
        lower=0,
        upper=num_minibatches_in_batch,
        body_fun=_loop_body,
        init_val=(train_state, grads_sum),
        unroll=False,
    )

    mean_grads = jax.tree_map(lambda x: x / num_minibatches_in_batch, grads_sum)
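
Since `_loop_body` is not shown above, here is a self-contained sketch of the same accumulation pattern on a toy problem. Everything in it is an assumption made for illustration: the linear model, `loss_fn`, `accumulate_grads`, and the choice to carry only `grads_sum` through the loop (the real code also carries `train_state`). It is not the code that triggers the crash, only the shape of the loop.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Hypothetical toy model: linear regression with squared error.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


def accumulate_grads(params, xs, ys):
    """Average gradients over minibatches stacked on the leading axis."""
    num_minibatches_in_batch = xs.shape[0]

    def _loop_body(i, grads_sum):
        # Gradient for the i-th minibatch, added to the running sum.
        grads = jax.grad(loss_fn)(params, xs[i], ys[i])
        return jax.tree_util.tree_map(jnp.add, grads_sum, grads)

    grads_sum = jax.tree_util.tree_map(jnp.zeros_like, params)
    # The loop is left rolled (the default), so only one minibatch's
    # activations should need to be live at a time.
    grads_sum = jax.lax.fori_loop(
        0, num_minibatches_in_batch, _loop_body, grads_sum
    )
    return jax.tree_util.tree_map(
        lambda g: g / num_minibatches_in_batch, grads_sum
    )


params = {"w": jnp.array([0.5, -1.0, 2.0]), "b": jnp.array(0.1)}
xs = jnp.linspace(-1.0, 1.0, 24).reshape(4, 2, 3)  # 4 minibatches of 2 rows
ys = jnp.ones((4, 2))
mean_grads = accumulate_grads(params, xs, ys)
```

Because every minibatch here has the same size, the averaged result should match the gradient of the loss over the full batch, up to floating-point error.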

When I set the gradient accumulation factor (num_minibatches_in_batch in this snippet) to a value greater than 1, I get the following error during compilation:

2024-07-01 12:57:35.488299: F external/xla/xla/service/collective_pipeliner.cc:675] Check failed: last_cloned != nullptr (0 vs. nullptr)

Here is the --xla_dump_to result: xla_dump.tgz

One important fact here is that if I set unroll in jax.lax.fori_loop to True, there is no compilation error and everything works. But this obviously leads to additional memory usage proportional to the gradient accumulation factor, so this hack doesn't seem viable.

My hypothesis is that when --xla_gpu_enable_while_loop_double_buffering=true is used together with pipelined collectives and the latency hiding scheduler, the XLA compiler tries to double-buffer this fori_loop, which is undesired behavior.

Basically, there are two problems:

I've tested this on JAX 0.4.29 and 0.4.30.

System info (python version, jaxlib version, accelerator, etc.)

>>> import jax; jax.print_environment_info()
jax:    0.4.30
jaxlib: 0.4.30
numpy:  1.24.3
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (8 total, 8 local): [cuda(id=0) cuda(id=1) ... cuda(id=6) cuda(id=7)]
process_count: 1
platform: uname_result(system='Linux', node='ffisin-dev-8gpu', release='5.4.0-155-generic', version='#172-Ubuntu SMP Fri Jul 7 16:10:02 UTC 2023', machine='x86_64')

$ nvidia-smi
Mon Jul  1 13:21:57 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.12             Driver Version: 535.104.12   CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA H100 80GB HBM3          On  | 00000000:8D:00.0 Off |                    0 |
| N/A   68C    P0             141W / 700W |   1074MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA H100 80GB HBM3          On  | 00000000:91:00.0 Off |                    0 |
| N/A   48C    P0             121W / 700W |   1074MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA H100 80GB HBM3          On  | 00000000:95:00.0 Off |                    0 |
| N/A   69C    P0             137W / 700W |   1074MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA H100 80GB HBM3          On  | 00000000:99:00.0 Off |                    0 |
| N/A   50C    P0             126W / 700W |   1074MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   4  NVIDIA H100 80GB HBM3          On  | 00000000:AB:00.0 Off |                    0 |
| N/A   68C    P0             142W / 700W |   1074MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   5  NVIDIA H100 80GB HBM3          On  | 00000000:AF:00.0 Off |                    0 |
| N/A   49C    P0             124W / 700W |   1074MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   6  NVIDIA H100 80GB HBM3          On  | 00000000:B3:00.0 Off |                    0 |
| N/A   68C    P0             143W / 700W |   1074MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   7  NVIDIA H100 80GB HBM3          On  | 00000000:B7:00.0 Off |                    0 |
| N/A   48C    P0             121W / 700W |   1074MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
+---------------------------------------------------------------------------------------+

qGentry commented 2 months ago

related XLA issue: https://github.com/openxla/xla/issues/14332

mattjj commented 1 month ago

Thanks for raising this.

I think it's an XLA:GPU issue, and we don't have any way to fix it from JAX.

That said, the hard-to-parse error may be something we can get traction on from JAX... can you say a bit more about what would've helped in the error message? We attach Python source information to the HLO program, but it's up to XLA to raise errors that reference it... from JAX we could've at least told you which jitted function raised the compiler error, but I'm not sure if we have other information to provide...

mattjj commented 1 month ago

I think we should close this in favor of the XLA issue. Looks like it just got assigned yesterday!