iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

Prototype collective operations and a NCCL integration for CUDA. #9580

Closed by benvanik 1 year ago

benvanik commented 2 years ago

(this is a brain dump for discussion and referencing code as we start to plumb things; design will change as we learn, and we should iterate/file more specific issues/etc as we go)

The goal is to support collective operations from frontends (e.g. mhlo.all_reduce) all the way through to runtime backends without carrying along some of the bad assumptions from those frontends (like mhlo requiring all the parameters as static attributes instead of SSA values we can compute). We'll be more general and can then push on those layers to generalize, rather than just carrying along their limitations.

As a starting point, NCCL is a good example of a minimalistic API for these kinds of operations: most of the MPI operations can be implemented using the NCCL primitives. That aligns well with our philosophy, and it is something we can use both as design inspiration and as an initial prototype. Supporting the collectives efficiently on other targets (CPU in particular) should end up being constructed similarly, with target-specific specializations of the common ops that we can do as follow-ons.
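
To illustrate the claim that richer MPI-style operations can be composed from NCCL's primitives: NCCL has no dedicated all-to-all call, and its documentation instead builds one from ncclSend/ncclRecv pairs issued inside a single ncclGroupStart/ncclGroupEnd. The following is a minimal in-process simulation of that decomposition in plain C, assuming every rank's buffers live in one address space; sim_send_recv and sim_all_to_all are hypothetical helpers standing in for the real NCCL calls, not part of any library.

```c
#include <stddef.h>
#include <string.h>

/* Hypothetical stand-in for one matched ncclSend/ncclRecv pair: with all
 * ranks in one address space, "sending" is just a copy into the peer's
 * receive buffer. */
static void sim_send_recv(const float *send_chunk, float *recv_chunk,
                          size_t count) {
  memcpy(recv_chunk, send_chunk, count * sizeof(float));
}

/* All-to-all built only from point-to-point transfers, mirroring a loop of
 * ncclSend/ncclRecv calls between ncclGroupStart() and ncclGroupEnd().
 * send[r] and recv[r] each hold nranks * chunk floats; chunk p of rank r's
 * send buffer lands in chunk r of rank p's receive buffer. */
void sim_all_to_all(float *send[], float *recv[], int nranks, size_t chunk) {
  for (int r = 0; r < nranks; ++r)
    for (int p = 0; p < nranks; ++p)
      sim_send_recv(send[r] + (size_t)p * chunk,
                    recv[p] + (size_t)r * chunk, chunk);
}
```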

The basic execution model we're going for is classic SPMD, where each IREE invocation runs on some subset of the data and performs collective ops as part of a team of invocations. We also have the ability to perform an additional level of SPMD across queues within a single invocation, giving us the flexibility of high-level SPMD across nodes and low-level SPMD across device queues on different devices. For the initial work we can assume the high-level approach (the user is doing distribution with mpirun-style tooling), though the low-level queue approach will let us better utilize NUMA systems without requiring significant infrastructure from users (multi-GPU and multi-CPU systems would appear as a single user invocation, but we'd run the collectives across all devices). In practice, low-level SPMD is a matter of stream formation/partitioning, and whatever we build here can be extended to support it in a fairly orthogonal way.
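
A minimal sketch of the SPMD model described above, assuming N invocations that each own an equal shard of the input: each invocation reduces its shard locally, then a sum all-reduce gives every invocation the same global result. The all-reduce here is a hypothetical in-process stand-in rather than a NCCL or MPI call, and all function names are illustrative.

```c
#include <stddef.h>

/* Each hypothetical "invocation" (rank) owns one contiguous shard of the
 * global input and computes a local partial sum: the "SP" part of SPMD. */
static double local_partial_sum(const double *shard, size_t n) {
  double acc = 0.0;
  for (size_t i = 0; i < n; ++i) acc += shard[i];
  return acc;
}

/* In-process stand-in for an all-reduce(sum): every rank contributes its
 * partial and every rank receives the same combined value. */
static void sim_all_reduce_sum(double *partials, int nranks) {
  double total = 0.0;
  for (int r = 0; r < nranks; ++r) total += partials[r];
  for (int r = 0; r < nranks; ++r) partials[r] = total;
}

/* Runs the whole SPMD step over `nranks` equal shards of `data`
 * (n must be divisible by nranks); out[r] is rank r's view of the result. */
void spmd_global_sum(const double *data, size_t n, int nranks, double *out) {
  size_t shard = n / (size_t)nranks;
  for (int r = 0; r < nranks; ++r)
    out[r] = local_partial_sum(data + (size_t)r * shard, shard);
  sim_all_reduce_sum(out, nranks);
}
```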

There are a few gotchas in the execution model that our lowerings will have to handle, mainly around deterministic ordering. If we schedule a fork of two streams on the same queue and either one contains collective ops, we'd need to strictly order them. Today we don't schedule multiple overlapping streams, so this should be fine; when we extend support to that, we can handle it by adding exclusive resource constraints and such.


NCCL's design is slightly different from MPI in that it has explicit barriers due to its asynchronous enqueuing behavior; this works well for us, as we too are doing asynchronous enqueuing and modeling barriers. For our purposes we can assume that we will always guard collective operations with a begin/end pair, thus creating atomic groups of operations that look like a single operation from the host's perspective:

// <other operations>
ncclGroupStart();
ncclBroadcast(sendbuff1, recvbuff1, count1, datatype, root, comm, stream);
ncclAllReduce(sendbuff2, recvbuff2, count2, datatype, ncclSum, comm, stream);  // the reduction op (here ncclSum) is a required argument
ncclGroupEnd();
// <other operations>
ncclGroupStart();
// <more collective ops>
ncclGroupEnd();
// <other operations>

We could map each of these as distinct operations down to command buffers, but since the only legal operations in a group are collectives and their exact sequence is critical to correctness, we wouldn't actually gain anything from that decomposition. Instead, we can conceptually treat all of the collectives as a single atomic dispatch:

<other operations>
%0:2 = flow.dispatch.collectives[%count](%input0, %input1, %output0, ....) -> (%output0, tensor<...>) {
  %t0 = some_collective_dialect.broadcast %input0 -> %output0 // in-place in output0
  %t1 = some_collective_dialect.all_reduce %input1 -> tensor<...> // new result allocation
  flow.return %t0, %t1
}
<other operations>

We could then verify that the flow.dispatch.collectives op only contains compatible collective operations and avoid inlining constants/etc. that would be difficult to handle. Capturing the workload (%count here, and possibly other values too) will allow the collectives to be partitioned again if we need to (across streams, etc.) as well as lowered to normal parallelized workgroup dispatches. The tensor I/O ensures that the required buffers are allocated independently, as with any other tensor storage. We may want to specially control the attributes of the storage buffers, but that can be an extension to the stream.resource->hal.buffer lowering when we need it.
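
The verification step could look roughly like this sketch, which assumes the region body has been summarized as a list of op kinds; the op_kind_t enum and verify_collectives_region are hypothetical and not IREE's actual op definitions or verifier. It rejects anything that is not a collective (such as an inlined constant) and requires a single terminator at the end.

```c
#include <stdbool.h>
#include <stddef.h>

/* Hypothetical op kinds that might appear in a flow.dispatch.collectives
 * region body; the names are illustrative only. */
typedef enum {
  OP_BROADCAST,
  OP_ALL_REDUCE,
  OP_ALL_GATHER,
  OP_REDUCE_SCATTER,
  OP_SEND,
  OP_RECV,
  OP_RETURN,   /* region terminator (flow.return) */
  OP_CONSTANT, /* e.g. an inlined constant: rejected */
  OP_GENERIC,  /* any other compute op: rejected */
} op_kind_t;

static bool is_collective(op_kind_t k) {
  return k == OP_BROADCAST || k == OP_ALL_REDUCE || k == OP_ALL_GATHER ||
         k == OP_REDUCE_SCATTER || k == OP_SEND || k == OP_RECV;
}

/* Returns true if the region holds at least one collective, nothing but
 * collectives, and exactly one terminator as its last op. */
bool verify_collectives_region(const op_kind_t *ops, size_t n) {
  if (n < 2 || ops[n - 1] != OP_RETURN) return false;
  for (size_t i = 0; i + 1 < n; ++i)
    if (!is_collective(ops[i])) return false;
  return true;
}
```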

There are some details to work out here around workloads/fusion, but the high-level goal is to not harm fusion when multi-device collectives aren't used (tiny bare-metal deployments) by lowering them to fusable linalg ops early, while still allowing multi-device collectives to be optionally supported at runtime. That is to say, we will likely have a target configuration attribute defining whether collectives are worth modeling like this: when disabled, all of them would lower to normal broadcast/reduction/etc. linalg ops as today, and otherwise they would go down this path. When enabled, the backends still have the option to generate both the multi-device-aware variants and fallback variants, allowing us to support mixed execution when targeting multiple devices or to respond to runtime situations (only 1 GPU present, NCCL not supported, etc.). This also helps our testing story by allowing us to run the same compiled program across all devices with no other structural changes.

This goal impacts how we layer in the collective calls: we don't want dedicated command buffer calls or host-side extensions, as those presuppose that such calls exist and are valid at runtime. Instead we can rely on the fact that command buffer dispatches are just fancy ways of saying "run this extern function with these buffers and arguments" and swap out that function's implementation: we just need buffers and integer arguments to call all the collective APIs we'd want to use.
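
A sketch of the "swap the function implementation" idea, assuming a dispatch is just a function pointer called with buffers and integer arguments; all names are hypothetical. The selector falls back to a plain kernel when the executable carries no collective definition, the runtime lacks NCCL, or only one device is present. For a sum all-reduce with a single participant the result equals the input, which is why the fallback here is a copy.

```c
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

/* A dispatch modeled as "call this function with these buffers and these
 * integer arguments". All names here are hypothetical. */
typedef int (*dispatch_fn_t)(void **buffers, const uint32_t *args);

/* Fallback path: a plain single-device kernel. With one participant a sum
 * all-reduce is just a copy of the input (args[0] = element count). */
static int all_reduce_fallback(void **buffers, const uint32_t *args) {
  const float *in = buffers[0];
  float *out = buffers[1];
  for (uint32_t i = 0; i < args[0]; ++i) out[i] = in[i];
  return 0; /* 0 = ran the fallback kernel */
}

/* Stand-in for the multi-device path that would call into NCCL. */
static int all_reduce_collective(void **buffers, const uint32_t *args) {
  (void)buffers;
  (void)args;
  return 1; /* 1 = would have called the collective library */
}

/* Chooses an implementation at dispatch time from what the executable
 * carries and what the runtime actually has available. */
dispatch_fn_t select_dispatch(bool has_collective_def, bool runtime_has_nccl,
                              int device_count) {
  if (has_collective_def && runtime_has_nccl && device_count > 1)
    return all_reduce_collective;
  return all_reduce_fallback;
}
```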


For backends like CUDA+NCCL or CPU+MPI that do support collectives, we can declare the operations in a backend-specific representation. For CUDA this would likely mean something like a NCCLGroupDef containing a list (up to 2048 entries, according to the NCCL docs) of NCCLOpDef entries. The CUDAExecutableDef could then have both a PTX blob and collective definitions, and it's up to the command buffer at dispatch time to decide whether to launch a kernel or call into NCCL, literally by checking whether the dispatch target has NCCL data and, if so, running a basic interpreter over it. The NCCLOpDef may be specialized (e.g. NCCLSendOpDef) and should carry the metadata needed to gather the mix of buffers (pulled from bound descriptors) and integers (which may be push constants or indirect values loaded from buffers) required to make the calls:

// Reference to a bound buffer, defined by the hal.interface.
struct NCCLDescriptorDef {
  set:uint32;
  binding:uint32;
}
// Byte offset into push constants set on the command buffer, defined by the hal.interface.
table NCCLConstantParameterDef {
  byte_offset:uint32;
}
// Indirect load - maybe not something we can support?
table NCCLIndirectParameterDef {
  source:NCCLDescriptorDef;
  offset:NCCLParameterDef;
}
union NCCLParameterDef {
  NCCLConstantParameterDef,
  NCCLIndirectParameterDef,
}
// ncclSend(const void* sendbuff, size_t count, ncclDataType_t datatype, int peer, ncclComm_t comm, cudaStream_t stream)
table NCCLSendOpDef {
  send_buffer:NCCLDescriptorDef;
  count:NCCLParameterDef;
  data_type:uint32;
}
union NCCLOpDef {
  NCCLSendOpDef,
  ...
}
// ncclGroupStart/End scope with ops inside of it
table NCCLGroupDef {
  ops:[NCCLOpDef];
}

So an entry point may have both a PTX blob (fallback, no NCCL) and a group def with the parameters extracted from the IR hal.interface.binding.subspan and hal.interface.constant.load ops:

NCCLGroupDef:
  NCCLSendOpDef:
    send_buffer: {set=0, binding=1}
    count: NCCLConstantParameterDef {byte_offset=24}
    data_type: 4
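
A minimal sketch of the interpreter step for the NCCLSendOpDef example above, assuming plain C mirrors of the flatbuffer tables, a 32-bit count in the push constants, and only the constant-parameter case; the struct and function names are hypothetical. It resolves send_buffer {set=0, binding=1} from a bound descriptor table and reads count from the push constants at byte_offset 24, producing the arguments a real implementation would pass to ncclSend (peer, communicator, and stream omitted).

```c
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#define MAX_SETS 4
#define MAX_BINDINGS 8

/* Hypothetical plain-C mirrors of the flatbuffer tables above; only
 * NCCLConstantParameterDef is handled (no indirect parameters). */
typedef struct { uint32_t set, binding; } nccl_descriptor_t;
typedef struct { uint32_t byte_offset; } nccl_constant_param_t;
typedef struct {
  nccl_descriptor_t send_buffer;
  nccl_constant_param_t count;
  uint32_t data_type;
} nccl_send_op_t;

/* The resolved arguments a real implementation would forward to ncclSend
 * (peer, communicator, and stream are omitted from this sketch). */
typedef struct {
  const void *sendbuff;
  size_t count;
  uint32_t data_type;
} resolved_send_t;

/* Returns 0 on success, -1 if a descriptor or push constant reference is
 * out of range. */
int resolve_send(const nccl_send_op_t *op,
                 void *bindings[MAX_SETS][MAX_BINDINGS],
                 const uint8_t *push_constants, size_t push_constants_size,
                 resolved_send_t *out) {
  if (op->send_buffer.set >= MAX_SETS ||
      op->send_buffer.binding >= MAX_BINDINGS)
    return -1;
  if ((size_t)op->count.byte_offset + sizeof(uint32_t) > push_constants_size)
    return -1;
  uint32_t count; /* memcpy avoids unaligned reads from the constant bytes */
  memcpy(&count, push_constants + op->count.byte_offset, sizeof count);
  out->sendbuff = bindings[op->send_buffer.set][op->send_buffer.binding];
  out->count = count;
  out->data_type = op->data_type;
  return 0;
}
```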

We could look at abstracting some of these and sharing them across other GPU backends, but for now keeping this per-backend will let us iterate and handle them appropriately until we have more experience with how to compile for them. On the CPU we can use executable library imports for this, injecting MPI-like API symbols at runtime that the LLVM-generated executables can call into; in that case we don't need flatbuffers and can make any calls we want, but we would want to define the abstract import API so that multiple platforms can run the same binaries even if they have different underlying implementations. It's all follow-on work that would share the dialect/dispatch formation/etc. parts common to the NCCL path.
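
On the CPU path, the abstract import API might be bound by symbol name when the executable is loaded. This sketch assumes a hypothetical import named ccl_all_reduce_sum_f32 and a table supplied by the platform; it is not IREE's actual executable import mechanism. A single-process platform can satisfy the import trivially, since a sum over one participant leaves the buffer unchanged, while an MPI-backed platform would register a different function under the same name.

```c
#include <stddef.h>
#include <string.h>

/* Hypothetical abstract import: in-place sum all-reduce over `count` floats
 * across all participants. Every platform exports the same name/signature. */
typedef int (*ccl_all_reduce_f32_fn)(float *buf, size_t count);

/* Type-erased function pointer; callers cast back to the real signature. */
typedef void (*generic_fn_t)(void);

typedef struct {
  const char *name;
  generic_fn_t fn;
} import_entry_t;

/* Single-process platform: the sum over one participant is the local
 * buffer itself, so the all-reduce leaves `buf` unchanged. */
static int single_process_all_reduce(float *buf, size_t count) {
  (void)buf;
  (void)count;
  return 0;
}

static const import_entry_t kSingleProcessImports[] = {
    {"ccl_all_reduce_sum_f32", (generic_fn_t)single_process_all_reduce},
};

/* Binds an import by name, as a loader might when linking an executable's
 * imports; NULL means the platform does not provide it. */
generic_fn_t resolve_import(const import_entry_t *table, size_t n,
                            const char *name) {
  for (size_t i = 0; i < n; ++i)
    if (strcmp(table[i].name, name) == 0) return table[i].fn;
  return NULL;
}
```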


Sketching a workflow for the initial prototype:

We should be able to start with the mhlo->collective dialect lowering and then add the flow op. From there, a generic pass shared across backends that lowers the collective ops to tiled/distributed/mumble linalg/vectors/etc. that the codegen process can handle would let us run end-to-end even before integrating NCCL. After that, plumbing through the NCCL flatbuffer and runtime pieces is separable work.

powderluv commented 2 years ago

@julianwa can I add this to the NOD Model Coverage Milestone? (I can't seem to add it to JAX training too.)

sogartar commented 2 years ago

I am taking a stab at this problem. It is a prototype at the moment: I defined an incomplete CCL dialect and want to implement the NCCL backend. When implementing the SPMD execution model, the scheduling of invocations across devices must be synchronized so that the order of invocations is the same across all VMs. If the order differs when there are multiple MPI communicators, the VMs may deadlock. For example, suppose there are 2 VMs with independent schedulers. At some point each of them schedules a different invocation, and then neither schedules anything else because there is no more memory. Let's say the 2 invocations use 2 different communicators that share the same 2 ranks. The 2 invocations will deadlock in this scenario. @benvanik, could you share a bit more of your view on how this synchronization could be achieved? I can see a few solutions here.

  1. Use a deterministic scheduler. All IREE runtimes across all nodes use the same configuration and run the same MLIR program, which should guarantee the same order of invocations. Is there a deterministic scheduler in IREE?
  2. I added a chain type to the CCL dialect, analogous to TFRT's. Each operation accepts a chain and returns a chain, so you can strictly order the ops and work around the stateful nature of the communicators. This technique could be used to put all CCL calls on a single chain, which would enforce a strict order of all CCL dispatch regions and solve the deadlock problem. The drawback is that it restricts operation reordering and hurts parallelism.
  3. Have a scheduler that coordinates the execution order between runtimes. For example, a root runtime decides the schedule and sends it to the rest. The communication latency of sending the schedule may stall execution if not enough invocations are scheduled in advance.
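
The deadlock described above can be reproduced with a small simulation, assuming blocking collectives where a call on a communicator completes only once every member rank has reached that same call. The function and its encoding (communicator membership as a bitmask, at most 8 ranks) are hypothetical. With two communicators that both span ranks 0 and 1, issuing them in opposite orders on the two ranks deadlocks, while any identical order (which options 1 and 2 both guarantee) drains.

```c
#include <stdbool.h>
#include <stddef.h>

#define SIM_MAX_RANKS 8 /* nranks must not exceed this */

/* Simulates blocking collectives. queues[r] lists, in issue order, the
 * communicator id of each collective rank r calls; members[c] is a bitmask
 * of the ranks in communicator c. A call on c completes only when every
 * member of c has that call at the head of its queue.
 * Returns true if every queue drains, false on deadlock. */
bool simulate_collectives(int nranks, const int *const queues[],
                          const size_t lens[], const unsigned members[]) {
  size_t head[SIM_MAX_RANKS] = {0};
  for (;;) {
    bool all_done = true, progressed = false;
    for (int r = 0; r < nranks; ++r) {
      if (head[r] >= lens[r]) continue; /* rank r has finished */
      all_done = false;
      int c = queues[r][head[r]];
      /* r must itself be a member, and every member must be waiting on c. */
      bool ready = (members[c] >> r) & 1u;
      for (int m = 0; ready && m < nranks; ++m) {
        if (!((members[c] >> m) & 1u)) continue;
        if (head[m] >= lens[m] || queues[m][head[m]] != c) ready = false;
      }
      if (!ready) continue;
      for (int m = 0; m < nranks; ++m) /* complete it on all members */
        if ((members[c] >> m) & 1u) head[m]++;
      progressed = true;
    }
    if (all_done) return true;
    if (!progressed) return false;
  }
}
```
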
benvanik commented 2 years ago

Hi @sogartar! This likely needs a higher-bandwidth discussion than an issue allows, and I know there are others looking into this as well who may want to participate. I can try to provide some background here on our approach, how it fits into the larger roadmap, and the concerns we have around modeling collectives at higher levels.

The approach you mention doesn't quite apply to the layer IREE operates at, because (today) IREE is designed to be used within some executor rather than to be the executor itself. An IREE program is the "SP" and the hosting application provides the "MD" :) This is similar to the relationship between TF/JAX and XLA.

There are really three distinct problems: how to annotate a user program for automatic SPMD sharding, how to SPMD-ize an annotated program, and how to compile/execute SPMD programs. IREE focuses on the last of these today but would like to grow to handle the higher layers over time (hopefully with community involvement).

For what we (likely) want to use with IREE to annotate partitioning, see "GShard": https://arxiv.org/abs/2006.16668. Section 3.2 discusses its replicate, split, and shard ops, which annotate tensors and would allow program transformations into SPMD form in a way consistent with how MLIR/IREE work. There's a follow-on paper, "GSPMD", that talks more about how it operates/generalizes: https://arxiv.org/abs/2105.04663. The critical thing here is that ordering is based on data dependencies and not on additional side-channel information (async tokens), and that holds all the way to runtime, as execution order is consistently derived from those data dependencies at each lowering step. Fusion, grouping, and slicing happen at several layers in the compiler, and they become very difficult the earlier such side-channel constraints are added. Our near-term plan is to see if we can reuse GSPMD as it exists in XLA to get automatically generated SPMD programs we can use (though of course we'd like to have an MLIR implementation that works for all frontends over time!).

Moving one step lower in the stack (SPMD scheduling): IREE programs are defined imperatively rather than as a graph, and as such IREE has no concept of global schedules or a global scheduler, only local schedules controlled entirely by IREE within an invocation. Hosting applications/framework layers decide how to schedule function invocations, and if MPI/collectives are used at that layer they must schedule the invocations correctly, just as if a human were authoring the program. That is to say, IREE has no unique execution behavior that is not also present in normal code: a malformed program such as if x then do_collectives(), run on multiple nodes where x is non-uniform, will deadlock.

And finally, at the lowest level (execution): we do have the potential to introduce local non-determinism within an SPMD IREE program when we repartition work into multiple concurrently-executable streams or group concurrently-executable work within a single stream. Collective dispatches will need to be marked as barriers so that partitioning avoids this, though today we're not partitioning into multiple streams, and within a single stream the ordering is allowed to differ (such as in NCCL when using either CUDA stream multiplexing with events or CUDA graphs). When we add partitioning into multiple logical streams, we will order the stream execution to preserve this. I think the exact design is going to need some iteration, but it's akin to GSPMD's nested parallelism: multiple collectives may be in flight across different sets of nodes within the same program, and global barriers are too coarse for that.
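
The ordering constraint could be checked on a proposed schedule along these lines. This is a hypothetical checker, not a partitioner: it assumes each collective's node set is known as a bitmask, lets non-collective ops and collectives over disjoint node sets move freely (the nested-parallelism case), and rejects any schedule that inverts two collectives whose node sets overlap.

```c
#include <stdbool.h>
#include <stddef.h>

/* An op in program order. comm_mask == 0 means "not a collective";
 * otherwise it is a bitmask of the participating nodes (hypothetical). */
typedef struct {
  unsigned comm_mask;
} sched_op_t;

/* schedule[k] is the index (into ops) of the k-th op a partitioner chose to
 * issue. Non-collective ops may move freely; two collectives whose node sets
 * overlap must keep their program order, while collectives over disjoint
 * node sets may be reordered or run concurrently. */
bool collective_order_preserved(const sched_op_t *ops, const size_t *schedule,
                                size_t n) {
  for (size_t a = 0; a < n; ++a) {
    for (size_t b = a + 1; b < n; ++b) {
      size_t i = schedule[a], j = schedule[b]; /* i is issued before j */
      unsigned shared = ops[i].comm_mask & ops[j].comm_mask;
      if (shared != 0 && j < i) return false; /* overlapping and inverted */
    }
  }
  return true;
}
```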

The design discussed above is compatible with IREE and with this roadmap as more layers get built. Our plan to get started is to use JAX (jax.pmap and such: https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html) as the higher-level layer that controls SPMD scheduling, lowering from mhlo collective operations. Working forward from that, we would ideally have some community-supported primitives upstream (possibly in TCP) that would allow multiple frontends to lower into them. These primitives would look like the mhlo ones (tensors only, no async handles). Finally, it'd be great to see an upstream MLIR dialect that provides expressiveness similar to GSPMD (replicate/split/shard) along with transformations to lower that into SPMD form. Hopefully all of this lives outside of IREE, apart from our actual SPMD execution primitives; since that lowest execution layer is always going to be required, that's where we're looking to start with this issue.

HTH, and if such an approach is interesting to you we'd be interested to collaborate (I know @mattwalsh and @okkwon have been ramping up on this specifically but am not sure how far they've gotten).

benvanik commented 2 years ago

All of that said, the original design outlined in the issue had the goal of being quick and easy to implement and then delete as we learned more about how things fit together. I still think it's worth doing in order to get everything wired up end-to-end and I fully expect there will be significant iteration (possibly more flow and stream support ops) - the primary goal was to get to something concrete and avoid design paralysis :)

sogartar commented 2 years ago

When I was talking about the deadlock problem in my previous comment, I didn't use the correct terminology. What I meant was a scenario where you have 2 collective communication dispatch regions in one MLIR program. They have no dependency on each other, so they may execute in parallel, but at runtime, due to resource exhaustion, they cannot actually run in parallel. It seems there are 2 options.

  1. Global barriers surrounding each collectives dispatch region to guarantee that all replicas are in sync at the point of communication. This approach is agnostic to scheduling behavior.
  2. Make sure that the order of execution of dispatch regions is always the same. For example, given the code below:
    vm.call @hal.command_buffer.execution_barrier(...) : ...
    vm.call @hal.command_buffer.dispatch(...) : ...
    vm.call @hal.command_buffer.dispatch(...) : ...
    vm.call @hal.command_buffer.execution_barrier(...) : ...

    Can we have scheduling behavior such that the first dispatch always starts executing first? Also, can there be nested dispatch regions?

Regarding the chain/token approach: if tokens are dropped, then every op must use its own communicator. We may end up with thousands of communicators in models with many ops, and I don't know how this will affect performance. In HLO, collective ops have attributes that specify replica groups, which are lists of integer replica IDs. They also have channel handles as an attribute (see, for example, AllReduce). Their combination is like a set of communicators. Dropping tokens is only applicable to SPMD; in the case of MPMD, the order of collective ops must be matched across all compilations. I was hoping that the CCL dialect would not be restricted to SPMD only; the concern regarding SPMD versus MPMD should lie in another part of the stack. In HLO, most collective communication operations don't use tokens; the exceptions are send and receive.
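
To make the replica-group point concrete: given HLO-style replica_groups such as {{0,1},{2,3}}, each replica needs to find which group it belongs to, its rank within that group, and the group size, which are the values a ncclCommInitRank-style setup takes (nranks and rank). This sketch assumes the groups are flattened into an id array plus offsets; the types and function are hypothetical. A runtime might then key a cache of communicators by channel handle plus group index, though that too is an assumption.

```c
#include <stdbool.h>
#include <stddef.h>

/* HLO-style replica_groups flattened into one array: group g spans
 * ids[offsets[g]] .. ids[offsets[g + 1] - 1]. E.g. {{0,1},{2,3}} becomes
 * ids = {0,1,2,3}, offsets = {0,2,4}. */
typedef struct {
  const int *ids;
  const size_t *offsets;
  size_t num_groups;
} replica_groups_t;

/* The communicator slot a replica joins: which group, its rank inside that
 * group, and the group size. */
typedef struct {
  size_t group;
  int rank;
  int size;
} comm_slot_t;

bool find_comm_slot(const replica_groups_t *rg, int replica_id,
                    comm_slot_t *out) {
  for (size_t g = 0; g < rg->num_groups; ++g) {
    for (size_t k = rg->offsets[g]; k < rg->offsets[g + 1]; ++k) {
      if (rg->ids[k] == replica_id) {
        out->group = g;
        out->rank = (int)(k - rg->offsets[g]);
        out->size = (int)(rg->offsets[g + 1] - rg->offsets[g]);
        return true;
      }
    }
  }
  return false; /* replica does not participate in this collective */
}
```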

My intuition is that the CCL dialect should not be concerned with sharding specs like the ones in HLO, where all ops have an attribute that specifies the sharding of the output tensor(s). That should be the concern of a layer above CCL, where these sharding specs can be assigned by auto-sharding algorithms like GSPMD. After sharding specs have been assigned, a separate pass translates each op into something that uses CCL ops plus other compute ops. Unfortunately, this translation is a cross-cutting concern that needs to know how to lower each sharded op that can appear at that level.

There is also the question of how to handle user IO to the distributed model. This involves sharding and distributing the model parameters and inputs, then collecting the output. If the compiler does the auto-sharding, it should describe how the IO is sharded.

sogartar commented 2 years ago

Quoting the issue description: "That is to say we will likely have a target configuration attribute defining whether collectives are worth modeling like this and when disabled all would lower to normal broadcast/reduction/etc linalg ops like today and otherwise go down this path."

This would require knowing the replicas in a communicator at compile time. It would probably be most convenient to have the replicas as operation attributes instead of passing a communicator as an operand.

How important is the ability to run on one device? Usually someone at a higher level would manually or automatically shard the model. Can't they just compile the clean version before it is polluted with CC (collective communication) ops? I am also not sure how well the sharded version of the model would be optimized. For example, look at how the GShard paper (appendix A.4) decomposes a convolution with a halo exchange; there is a lot going on there. Even if you are able to easily substitute the CC ops, would the subsequent compilation produce fast code?

okkwon commented 1 year ago

I am investigating how JAX generates MLIR with pmap. There are two ways to get the MLIR:

  1. Using jit() and lower()
  2. Using lower() directly.

Here is the link for the python scripts and their result: https://gist.github.com/okkwon/d57e657b2a939d1356b754029ba2e895

Method 1 uses the whole input and output, and the program uses replica_id to access its portion of the partitioned data, while Method 2 provides MLIR that is already partitioned because PJRT maintains the partitioned data, which is more scalable.
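
Method 1's replica_id-based access amounts to simple offset arithmetic. This sketch assumes a row-major input of shape [dim0, row_elems] split evenly along dimension 0; the slice_t type and replica_slice function are hypothetical. For pmap the mapped leading axis typically has one entry per replica, so each replica selects a single row.

```c
#include <stddef.h>

/* Method 1 sketch: every replica receives the full row-major input of shape
 * [dim0, row_elems] and uses its replica_id to pick an equal slice along
 * dimension 0 (assumes dim0 is divisible by num_replicas). */
typedef struct {
  size_t first_elem; /* flat offset of the slice's first element */
  size_t num_elems;  /* number of elements in the slice */
} slice_t;

slice_t replica_slice(size_t dim0, size_t row_elems, size_t num_replicas,
                      size_t replica_id) {
  size_t rows = dim0 / num_replicas;
  slice_t s = {replica_id * rows * row_elems, rows * row_elems};
  return s;
}
```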

Anyway, with the very small example, we could build a tiny end-to-end test by only supporting mhlo.all_reduce.

allieculp commented 1 year ago

@okkwon Is this still active?

okkwon commented 1 year ago

@allieculp sorry for my late reply. This was the very first issue discussing the NCCL integration; much of the work has since been done and there are other GitHub issues for the rest, so I think we can close this issue.