openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0

[XLA:GPU] Check failed in collective_pipeliner when using gradient accumulation with non-unrolled loop #14332

Status: Open (opened by qGentry 3 months ago)

qGentry commented 3 months ago

Hi, I have the following setup:

I'm using the following flags:

--xla_gpu_graph_level=0 
--xla_gpu_enable_triton_gemm=false 
--xla_gpu_enable_command_buffer= 
--xla_gpu_enable_latency_hiding_scheduler=true 
--xla_gpu_enable_all_gather_combine_by_dim=false 
--xla_gpu_enable_reduce_scatter_combine_by_dim=false 
--xla_gpu_enable_pipelined_all_gather=true 
--xla_gpu_enable_pipelined_reduce_scatter=true 
--xla_gpu_enable_pipelined_all_reduce=true 
--xla_gpu_enable_pipelined_collectives=false 
--xla_gpu_enable_while_loop_double_buffering=true 
--xla_gpu_enable_highest_priority_async_stream=true 
--xla_gpu_disable_async_collectives=collectivebroadcast,alltoall,collectivepermute
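For reference, here is a minimal sketch of the usual way to pass these flags from Python. XLA reads them from the XLA_FLAGS environment variable as one space-separated string, so the variable should be set before JAX initialises its backend (commonly, before import jax). The helper set_xla_flags is hypothetical. The flag values are copied verbatim from the list above.

```python
import os

# Flags from the issue, one per entry (values copied verbatim).
GPU_FLAGS = [
    "--xla_gpu_graph_level=0",
    "--xla_gpu_enable_triton_gemm=false",
    "--xla_gpu_enable_command_buffer=",
    "--xla_gpu_enable_latency_hiding_scheduler=true",
    "--xla_gpu_enable_all_gather_combine_by_dim=false",
    "--xla_gpu_enable_reduce_scatter_combine_by_dim=false",
    "--xla_gpu_enable_pipelined_all_gather=true",
    "--xla_gpu_enable_pipelined_reduce_scatter=true",
    "--xla_gpu_enable_pipelined_all_reduce=true",
    "--xla_gpu_enable_pipelined_collectives=false",
    "--xla_gpu_enable_while_loop_double_buffering=true",
    "--xla_gpu_enable_highest_priority_async_stream=true",
    "--xla_gpu_disable_async_collectives=collectivebroadcast,alltoall,collectivepermute",
]


def set_xla_flags(flags):
    """Hypothetical helper: store the flags in XLA_FLAGS as one space-separated string."""
    os.environ["XLA_FLAGS"] = " ".join(flags)
    return os.environ["XLA_FLAGS"]


set_xla_flags(GPU_FLAGS)
# Import JAX only after XLA_FLAGS is set, e.g.:
# import jax
```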

This works correctly: the all-gathers of the layers' weights and the reduce-scatters of the gradients are indeed hidden behind computation.

Problems start to arise when I try to use gradient accumulation in this setup. It is implemented like this:

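    import jax
    import jax.numpy as jnp

    # Zero-initialised accumulator with the same pytree structure as the params.
    grads_sum = jax.tree_util.tree_map(jnp.zeros_like, train_state.params)
    # _loop_body (not shown) computes minibatch i's gradients and adds them to grads_sum.
    train_state, grads_sum = jax.lax.fori_loop(
        lower=0,
        upper=num_minibatches_in_batch,
        body_fun=_loop_body,
        init_val=(train_state, grads_sum),
        unroll=False,  # keep the loop rolled, i.e. lowered as a while loop
    )

    # Average the accumulated gradients over the minibatches.
    mean_grads = jax.tree_util.tree_map(lambda x: x / num_minibatches_in_batch, grads_sum)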
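The following pure-Python sketch shows what this loop computes, and it runs without JAX or a GPU. The local fori_loop mirrors the reference semantics given in the jax.lax.fori_loop documentation. The scalar linear model, the data, and the helper names are hypothetical, and the minibatches are assumed to be equal-sized. The sketch shows why dividing the accumulated sum by num_minibatches_in_batch recovers the full-batch gradient of a mean-reduced loss, while the loop only carries a single running sum.

```python
def fori_loop(lower, upper, body_fun, init_val):
    """Reference semantics of jax.lax.fori_loop as a plain Python loop."""
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val


def grad_mse(w, xs, ys):
    """Gradient of mean((w*x - y)**2) with respect to scalar w over one (mini)batch."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_mean_grad(w, xs, ys, num_minibatches):
    """Gradient accumulation: sum per-minibatch gradients in a loop, then average."""
    size = len(xs) // num_minibatches  # assumes equal-sized minibatches

    def loop_body(i, grads_sum):
        lo, hi = i * size, (i + 1) * size
        return grads_sum + grad_mse(w, xs[lo:hi], ys[lo:hi])

    grads_sum = fori_loop(0, num_minibatches, loop_body, 0.0)
    return grads_sum / num_minibatches


xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 5.0, 9.0]
w = 1.5
full = grad_mse(w, xs, ys)                                  # full-batch gradient
acc = accumulated_mean_grad(w, xs, ys, num_minibatches=2)   # accumulated, then averaged
```

With unequal minibatch sizes, the plain average no longer matches the full-batch gradient. In that case, each minibatch's contribution would need to be weighted by its size.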

When I set the gradient accumulation factor (num_minibatches_in_batch in this snippet) to a value greater than 1, I get the following error during compilation:

2024-07-01 12:57:35.488299: F external/xla/xla/service/collective_pipeliner.cc:675] Check failed: last_cloned != nullptr (0 vs. nullptr)

Here is the --xla_dump_to output: xla_dump.tgz

One important detail: if I set unroll=True in jax.lax.fori_loop, there is no compilation error and everything works. However, this obviously increases memory usage in proportion to the gradient accumulation factor, so this workaround doesn't seem viable.

My hypothesis is that, when --xla_gpu_enable_while_loop_double_buffering=true is combined with pipelined collectives and the latency hiding scheduler, the XLA compiler tries to double-buffer this fori_loop, which is undesired here.
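Roughly speaking, while-loop double buffering unrolls the loop body twice per iteration so that consecutive iterations can alternate buffers. This changes the schedule but not the result. The pure-Python sketch below illustrates that kind of transformation. It is not XLA's implementation, and the function names are hypothetical. Peeling one iteration when the trip count is odd is one simple way to handle that case, and it is an assumption here rather than a description of XLA's code.

```python
def run_rolled(body, trip_count, state):
    """Plain loop: one body call per iteration."""
    for i in range(trip_count):
        state = body(i, state)
    return state


def run_double_buffered(body, trip_count, state):
    """Same computation with the body unrolled twice per loop iteration.

    An odd trip count is handled by peeling one iteration off the front,
    so the remaining iterations can be processed in pairs.
    """
    start = 0
    if trip_count % 2 == 1:
        state = body(0, state)  # peeled iteration
        start = 1
    for i in range(start, trip_count, 2):
        state = body(i, state)      # first copy of the body
        state = body(i + 1, state)  # second copy of the body
    return state


def accumulate(i, total):
    """Toy body standing in for one gradient-accumulation step."""
    return total + (i + 1) ** 2
```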

Basically, there are two problems:

I've tested this on JAX 0.4.29 and 0.4.30.

qGentry commented 3 months ago

Related JAX issue: https://github.com/google/jax/issues/22210

qGentry commented 3 months ago

Actually, this problem persists even with --xla_gpu_enable_while_loop_double_buffering=false, so double buffering may not be the source of the problem.

Tixxx commented 2 months ago

For the compilation error in collective_pipeliner, can you try with --xla_gpu_run_post_layout_collective_pipeliner=false?
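The following minimal sketch tries the suggested flag from Python without discarding any flags already set. It assumes that XLA_FLAGS is the mechanism used to pass the other flags. The helper append_xla_flag is hypothetical. As noted later in the thread, the flag only exists in sufficiently recent XLA builds, so older JAX releases may not recognise it.

```python
import os


def append_xla_flag(flag):
    """Hypothetical helper: add one flag to whatever XLA_FLAGS already contains."""
    existing = os.environ.get("XLA_FLAGS", "").split()
    if flag not in existing:  # avoid duplicating the flag on repeated calls
        existing.append(flag)
    os.environ["XLA_FLAGS"] = " ".join(existing)
    return os.environ["XLA_FLAGS"]


# Suggested workaround for the collective_pipeliner check failure.
append_xla_flag("--xla_gpu_run_post_layout_collective_pipeliner=false")
# Set this before importing JAX so XLA picks it up, e.g.:
# import jax
```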

rosiezou commented 2 months ago

Hi Filipp, could you try TJ's suggestion and update this issue with the results and any errors if applicable?

qGentry commented 2 months ago

Hi guys, it looks like this flag was added in a very recent commit and is not yet in the latest JAX release (0.4.30). I'll wait for JAX 0.4.31 to test it. Thank you!

qGentry commented 3 weeks ago

It looks like JAX 0.4.31 has broken GPU support in Docker containers, so I'll wait for 0.4.32.

nouiz commented 1 week ago

JAX 0.4.33 is released. Does it fix your issue?