openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0

[xla: GPU] Multiple ppermutes are using same CUDA stream #10640

Open DwaraknathT opened 6 months ago

DwaraknathT commented 6 months ago

Hey all, I am working on a custom collective matmul implementation in JAX to overlap an all-gather with a matmul. I noticed that my bidirectional send/recv calls are running on the same stream and therefore end up being executed serially (please see the attached profile image). Is it possible to manually assign different ppermute calls to different CUDA streams?

[profile image: both ppermute send/recv pairs executing serially on a single CUDA stream]
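
A minimal sketch of the kind of bidirectional ppermute step being described, with assumed, illustrative names and shapes rather than the actual implementation:

```python
# Illustrative sketch: one step of a bidirectional ring collective matmul.
# Two ppermutes ship chunks in opposite directions while the local matmul
# runs; these are the send/recv pairs that land on a single stream above.
import jax

AXIS = "x"  # pmap/shard_map axis name (assumption)

def bidirectional_step(chunk_fwd, chunk_bwd, rhs, acc, axis_size):
    # Forward-direction ring send/recv.
    next_fwd = jax.lax.ppermute(
        chunk_fwd, axis_name=AXIS,
        perm=[(i, (i + 1) % axis_size) for i in range(axis_size)])
    # Backward-direction ring send/recv, which ideally overlaps with the one above.
    next_bwd = jax.lax.ppermute(
        chunk_bwd, axis_name=AXIS,
        perm=[(i, (i - 1) % axis_size) for i in range(axis_size)])
    # Matmul on the chunks already resident, intended to overlap the transfers.
    acc = acc + chunk_fwd @ rhs + chunk_bwd @ rhs
    return next_fwd, next_bwd, acc
```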

Thank you!

cheshire commented 6 months ago

I'd guess it's actually a bug against JAX, to provide an API for such control? Or both?

DwaraknathT commented 6 months ago

I'd guess it's actually a bug against JAX, to provide an API for such control? Or both?

My apologies, I don't quite follow. Are you saying JAX shouldn't provide an API for such cases? How do you suggest I achieve what I'm trying to do? The ppermute delay basically leaves me with no benefit from overlapping compute with communication.

cheshire commented 6 months ago

If you'd like JAX to provide such an API, then this should be a feature request on JAX?

DwaraknathT commented 6 months ago

If you'd like JAX to provide such an API, then this should be a feature request on JAX?

Ah, I don't necessarily want this as a JAX API. XLA should already know that the forward and backward ppermutes can be done in parallel on different streams, no?

My broader question is: how can I reduce the time taken by the serial ppermutes?

hawkinsp commented 6 months ago

A couple of notes: you can get a multi-streamed collective matmul via the SPMD partitioner if you set --xla_gpu_multi_streamed_einsum_window and --xla_gpu_threshold_for_windowed_einsum_mib=<small_value>.
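
For context, a minimal sketch (illustrative mesh, shapes, and shardings, not a prescribed recipe) of the kind of jitted, sharded matmul that the SPMD partitioner can rewrite into a windowed einsum once those flags are set:

```python
# Assumes the windowed-einsum flags above are already set via XLA_FLAGS.
# Mesh axis name, shapes, and shardings here are illustrative.
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("x",))
# lhs sharded along its contracting dimension, rhs replicated.
lhs = jax.device_put(jnp.ones((8192, 4096)), NamedSharding(mesh, P(None, "x")))
rhs = jax.device_put(jnp.ones((4096, 4096)), NamedSharding(mesh, P(None, None)))

@jax.jit
def matmul(a, b):
    # The collectives XLA inserts here depend on the shardings; the
    # windowed-einsum pass is what overlaps them with the matmul.
    return a @ b

out = matmul(lhs, rhs)
```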

For the lower-level version, there are two asks:

I suspect that starting with "most manual" is the right thing.

cheshire commented 5 months ago

@hawkinsp Makes sense to me to go with the manual approach first! For someone relatively new to scaling, could you sketch out roughly how those would propagate through the SPMD partitioner? Should they go together with shardings?

frgossen commented 4 months ago

I guess sharding propagation could just carry the stream attributes, which we already have in XLA. That would need to be added to JAX, and we'd have to make sure the annotation doesn't get lost in XLA.

For the more automatic approach, yes, XLA should really do that. @golechwierowicz

bixia1 commented 1 month ago

We currently expose only one GPU stream for running asynchronous collective-permutes, so the two pairs of send/recv running sequentially is the current design. NVIDIA has implemented a collective matmul in XLA:GPU; see the discussion thread.

Here are some notes about the flags that enable the optimization. Is this something you would like to try out and give feedback on, instead of implementing your own custom collective matmul?

--xla_gpu_threshold_for_windowed_einsum_mib=<threshold>: controls the GEMM sizes for which collective matmul (CM) is enabled
--xla_gpu_multi_streamed_windowed_einsum: controls whether the CM loop is unrolled
Some other features that would improve performance:
PGLE will definitely help
--xla_gpu_use_memcpy_local_p2p: uses cudaMemcpy peer-to-peer copies instead of NCCL p2p calls for intra-node communication
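
A hedged example of passing the flags listed above to a JAX program via the XLA_FLAGS environment variable (the 16 MiB threshold is an arbitrary illustrative value, and PGLE setup is separate and not shown):

```python
# Illustrative only: set XLA_FLAGS before JAX initializes its GPU backend.
import os

os.environ["XLA_FLAGS"] = " ".join([
    "--xla_gpu_threshold_for_windowed_einsum_mib=16",
    "--xla_gpu_multi_streamed_windowed_einsum=true",
    "--xla_gpu_use_memcpy_local_p2p=true",
])

import jax  # imported after setting XLA_FLAGS so the flags take effect
```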