iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/

[GPU] Clustered Subgroup Reduction #18142

Open Groverkss opened 1 month ago

Groverkss commented 1 month ago

Request description

Motivation

A pattern we notice in flash attention kernels is:

A: tensor<16x16xf16>
B: tensor<16x16xf16>
C: tensor<16x16xf16>

D : tensor<16x16xf16> = matmul(A, B, C)
E : tensor<16x1xf16>  = reduce(D, dim=1)
F : tensor<16x16xf16> = broadcast(E, dim=1)
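
For concreteness, here is a rough numpy sketch of the same pattern (illustration only; sum stands in for whatever reduction the kernel actually does):

import numpy as np

# Sketch of the pattern above. The reduction kind (sum) is chosen purely for
# illustration.
A = np.random.rand(16, 16).astype(np.float16)
B = np.random.rand(16, 16).astype(np.float16)
C = np.random.rand(16, 16).astype(np.float16)

D = A @ B + C                      # D = matmul(A, B, C)
E = D.sum(axis=1, keepdims=True)   # E = reduce(D, dim=1), shape 16x1
F = np.broadcast_to(E, (16, 16))   # F = broadcast(E, dim=1), shape 16x16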

When optimizing for matmul intrinsics, we prefer to keep this computation entirely in registers for performance, which means the distribution of data across threads cannot change between these ops. Since we primarily optimize for matmul intrinsics on GPUs, the thread distribution is dictated by the intrinsic.

Accordingly, the thread distribution for tensor D follows that of the output of a matmul intrinsic. An example of such a thread distribution for a 16x16 shape is each thread carrying a vector<1x4xf16>, distributed over a 16x4 thread grid.

The data can be thought of as being distributed across threads as follows (the numbers represent thread ids, and the matrix they are distributed over is tensor D):

[ 0  0  0  0  1  1  1  1  2  2  2  2  3  3  3  3]
[ 4  4  4  4  5  5  5  5  6  6  6  6  7  7  7  7]
[ 8  8  8  8  9  9  9  9 10 10 10 10 11 11 11 11]
.
.
.
[60 60 60 60 61 61 61 61 62 62 62 62 63 63 63 63]
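
As a rough model of this layout (a Python sketch, not IREE's actual layout code), the lane owning element (row, col) of D is:

# Sketch: owning lane for element (row, col) of the 16x16 tensor D, assuming
# each lane holds a contiguous vector<1x4xf16> slice of one row and lanes are
# arranged on a 16x4 grid (64 lanes per subgroup).
def owning_lane(row: int, col: int) -> int:
    return row * 4 + col // 4

# Row 0 is held by lanes 0-3, row 1 by lanes 4-7, ..., row 15 by lanes 60-63.
assert [owning_lane(0, c) for c in range(16)] == [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]
assert owning_lane(15, 15) == 63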

When reducing this along dimension 1, we want to do multiple reductions in parallel in "clusters" of 4 threads. (An element carried by multiple thread ids is represented here by a comma-separated list of those ids.)

[0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3] --> [0, 1, 2, 3]
[4 4 4 4 5 5 5 5 6 6 6 6 7 7 7 7] --> [4, 5, 6, 7]
...

We will call such a reduction a "clustered" reduction.
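
The intended semantics can be sketched in plain Python, with a list standing in for the per-lane register values (a model of the semantics, not the lowering):

from functools import reduce as fold

# Sketch: clustered reduction semantics. Consecutive groups of `cluster_size`
# lanes reduce together, and every lane in a cluster ends up holding that
# cluster's result.
def clustered_reduce(lane_values, cluster_size, op=lambda a, b: a + b):
    out = []
    for start in range(0, len(lane_values), cluster_size):
        cluster = lane_values[start:start + cluster_size]
        out.extend([fold(op, cluster)] * len(cluster))
    return out

# 8 lanes, clusters of 4: lanes 0-3 reduce together, lanes 4-7 reduce together.
assert clustered_reduce([1, 2, 3, 4, 10, 20, 30, 40], 4) == [10, 10, 10, 10, 100, 100, 100, 100]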

How we do it today

Today, we directly emit a bunch of gpu.shuffle ops and do the entire reduction lowering in one shot:

https://github.com/iree-org/iree/blob/main/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp#L402
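
A typical way to do this with shuffles is a butterfly pattern: each lane repeatedly combines its value with the value held by the lane at lane ^ offset, doubling the offset until it reaches the cluster size. A simplified Python model of that pattern (a sketch for contiguous, power-of-two clusters; not the actual C++ code linked above) is:

# Simplified model of a shuffle-based clustered reduction. `shuffle_xor`
# mimics a butterfly shuffle: every lane reads the value held by lane
# (lane ^ offset).
def shuffle_xor(lane_values, offset):
    return [lane_values[lane ^ offset] for lane in range(len(lane_values))]

def clustered_reduce_via_shuffles(lane_values, cluster_size):
    values = list(lane_values)
    offset = 1
    while offset < cluster_size:  # log2(cluster_size) shuffle + add steps
        values = [a + b for a, b in zip(values, shuffle_xor(values, offset))]
        offset *= 2
    return values

assert clustered_reduce_via_shuffles([1, 2, 3, 4, 10, 20, 30, 40], 4) == [10, 10, 10, 10, 100, 100, 100, 100]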

Better lowering for clustered reductions

There is an existing operation in MLIR's gpu dialect that can represent a reduction across the threads of a subgroup:

https://mlir.llvm.org/docs/Dialects/GPU/#gpusubgroup_reduce-gpusubgroupreduceop

The limitation of this operation is that it always uses all available threads in the subgroup to do the reduction, which means we cannot use it for a clustered reduction.

We would like to add support to this operation to do such clustered reductions.

Tasks

Useful Links

Current subgroup_reduce lowering: https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
GPU Dialect documentation: https://mlir.llvm.org/docs/Dialects/GPU/#gpusubgroup_reduce-gpusubgroupreduceop
IR definition for gpu.subgroup_reduce: https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td#L1196

What component(s) does this issue relate to?

MLIR, Compiler

Additional context

No response

andfau-amd commented 1 month ago

I'm thinking of adding this part to the description of gpu.subgroup_reduce:

If a cluster_size is provided, the subgroup is divided into clusters of cluster_size lanes each, a reduction is done for all lanes of each cluster (in parallel), and the result is equal for all lanes in a cluster.

The operation of this is kind of difficult to succinctly explain. Maybe I would need to specify the particular way the clusters are laid out across the threads?

Groverkss commented 1 month ago

I'm thinking of adding this part to the description of gpu.subgroup_reduce:

If a cluster_size is provided, the subgroup is divided into clusters of cluster_size lanes each, a reduction is done for all lanes of each cluster (in parallel), and the result is equal for all lanes in a cluster.

The operation of this is kind of difficult to succinctly explain. Maybe I would need to specify the particular way the clusters are laid out across the threads?

I think the description sounds good for now. You can just mention that the lanes are divided into clusters in a contiguous manner.

andfau-amd commented 3 weeks ago

I think I have a finished implementation of the MLIR part now: https://github.com/llvm/llvm-project/pull/104851

andfau-amd commented 3 weeks ago

When I got to trying to update GPUNestedLayoutDistributionPatterns.cpp, @Groverkss and I realised that being able to specify cluster sizes isn't enough; we also need to be able to specify cluster strides (so, non-contiguous clusters). So I guess we will have to add a second attribute to MLIR. I have a good idea of what this should look like, so it shouldn't be a problem, just a shame this was missed earlier.
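
For illustration, a stride lets a cluster be formed from non-adjacent lanes. The sketch below shows one way to describe such a grouping (the names and exact rule here are my own illustration, not the final MLIR attribute semantics):

# Sketch: which lanes form a cluster when clusters have both a size and a
# stride. Illustrative only; see the MLIR PR for the actual definition.
def cluster_members(lane, cluster_size, cluster_stride):
    span = cluster_size * cluster_stride
    base = lane % cluster_stride + (lane // span) * span
    return [base + k * cluster_stride for k in range(cluster_size)]

assert cluster_members(0, 4, 16) == [0, 16, 32, 48]   # non-contiguous cluster
assert cluster_members(17, 4, 16) == [1, 17, 33, 49]
assert cluster_members(5, 4, 1) == [4, 5, 6, 7]       # stride 1 = contiguous clusters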

andfau-amd commented 1 week ago

I made a patch to add "cluster strides" to MLIR (https://github.com/llvm/llvm-project/pull/107142), and I've successfully prototyped using this in GPUNestedLayoutDistributionPatterns.cpp downstream; I was able to get the existing test to pass unmodified when I hacked IREE to apply the upstream expansion. So I think this is going well. I've also made https://github.com/llvm/llvm-project/pull/107134 as I'll need that downstream.

andfau-amd commented 1 week ago

Second round of MLIR changes is now merged.

Using this in IREE will require a few different commits (will update this list as I go along):

andfau-amd commented 1 day ago

Main PR is up at https://github.com/iree-org/iree/pull/18515 and I made an issue for tracking potential follow-up cleanup work (https://github.com/iree-org/iree/issues/18516); I realised I shouldn't block this issue on it.