iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

Experimental Convolution and Matmul GPU Codegen #17148

Open: qedawkins opened this issue 2 months ago

qedawkins commented 2 months ago

Overview

Currently there are four main ways to generate code for matmuls across the LLVMGPU and SPIR-V backends.

  1. GPUMatmulSimt
    • Tiles K, creates copies + allocations with bufferization.alloc_tensor, and distributes copies introduced by bufferization with GPUDistributeSharedMemoryCopy.
    • Struggles with producer fusion due to using custom patterns for distribution of the shared memory copies instead of tile + fuse.
    • Achieves good access patterns for the loads from global -> shared by controlling the distribution of the shared memory copies explicitly.
  2. GPUMatmulTensorCore(MmaSync)
    • Tiles K, creates copies + allocations with bufferization.alloc_tensor, does early bufferization, and hard-codes the thread layouts for mma.sync.
    • Suffers from problems similar to the SIMT path, while also being inflexible when targeting the mma intrinsics of other hardware.
  3. LLVMGPUVectorDistribute
    • Tiles K, creates copies + allocations with bufferization.alloc_tensor, but accesses the copies with vector transfers, then distributes block-level vector code directly to threads.
    • Has trouble reconciling producer fusions and good access patterns as vectorization breaks up producers into atoms; requires additional analysis to determine how to distribute global -> shared.
    • Distribution patterns are quite involved and reinvent a lot of wheels (in the current implementation).
    • Is good at mapping to unusual layouts required by mma intrinsics.
    • Is good at consumer fusions.
  4. TransformDialectMatmulStrategy
    • Tiles K, creates copies + allocations with tensor.pad, and distributes the copies and matmul with their own scf.forall.
    • Has trouble with consumer fusions because the thread level loops are locked inside the reduction loop.
    • Relies heavily on a number of cleanup/canonicalization patterns working properly to get consumers to distribute (and for bufferization to behave well).
    • Does well with producer fusions by giving access to the full producer op when making tiling decisions. This unlinks the distribution requirements of the matmul from the distribution considerations of producers.
    • Hard-coded, since it is based on transform dialect scripts.

(Note that this issue excludes cooperative matrix on SPIR-V due to its specific requirements.)

All of these strategies have notable pros/cons with respect to how well they handle fusions and target different accelerated instructions. The focus of this issue is the ability to do producer fusions. Out of the above list, the only approach that handles producer fusions robustly is 4), because the tiling + distribution of the producers is planned completely separately from the tiling + distribution of the matmul. The shared memory allocation between the two is what bridges the difference in distribution across the workgroup. This is especially relevant when trying to do something like implicit GEMM for NCHW convolutions: the distribution of im2col needs to be decided based on the way the kernel accesses data, not on how that data is used in the core computation.
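For readers less familiar with implicit GEMM, here is a minimal pure-Python sketch (stride 1, no padding, one GEMM per image) of how an NCHW convolution becomes a matmul: the filter is flattened to an F x (C*KH*KW) matrix and multiplied by an im2col matrix of shape (C*KH*KW) x (OH*OW). Each im2col column gathers an overlapping input window, so how that producer should read its input is dictated by the convolution's indexing rather than by the matmul tile that consumes it. The helper names (`conv2d_nchw`, `im2col_nchw`, `conv2d_implicit_gemm`) are hypothetical and not IREE APIs.

```python
from itertools import product

def conv2d_nchw(x, w):
    """Direct convolution. x: [N][C][H][W], w: [F][C][KH][KW].
    Stride 1, no padding. Returns y: [N][F][OH][OW]."""
    n_, c_, h_, w_ = len(x), len(x[0]), len(x[0][0]), len(x[0][0][0])
    f_, kh_, kw_ = len(w), len(w[0][0]), len(w[0][0][0])
    oh, ow = h_ - kh_ + 1, w_ - kw_ + 1
    return [[[[sum(x[n][c][i + p][j + q] * w[f][c][p][q]
                   for c, p, q in product(range(c_), range(kh_), range(kw_)))
               for j in range(ow)] for i in range(oh)]
             for f in range(f_)] for n in range(n_)]

def im2col_nchw(x_n, kh_, kw_):
    """im2col for one image x_n: [C][H][W]. Returns a K x M matrix with
    K = C*KH*KW (the matmul reduction dim) and M = OH*OW (output pixels)."""
    c_, h_, w_ = len(x_n), len(x_n[0]), len(x_n[0][0])
    oh, ow = h_ - kh_ + 1, w_ - kw_ + 1
    return [[x_n[c][i + p][j + q] for i, j in product(range(oh), range(ow))]
            for c, p, q in product(range(c_), range(kh_), range(kw_))]

def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def conv2d_implicit_gemm(x, w):
    """Per image: y[n] = W_mat (F x K) @ im2col(x[n]) (K x M), reshaped to F x OH x OW."""
    f_, c_, kh_, kw_ = len(w), len(w[0]), len(w[0][0]), len(w[0][0][0])
    # Flatten each filter in the same (c, p, q) order that im2col uses for its rows.
    w_mat = [[w[f][c][p][q] for c, p, q in product(range(c_), range(kh_), range(kw_))]
             for f in range(f_)]
    out = []
    for x_n in x:
        ow = len(x_n[0][0]) - kw_ + 1
        y_mat = matmul(w_mat, im2col_nchw(x_n, kh_, kw_))  # F x (OH*OW)
        # Split each length-OH*OW row back into OH rows of OW pixels.
        out.append([[row[s:s + ow] for s in range(0, len(row), ow)] for row in y_mat])
    return out

# Non-square input and filter (H=6, W=5, KH=2, KW=3) to catch index mix-ups.
x = [[[[(n + 2 * c + 5 * h + wi) % 7 for wi in range(5)] for h in range(6)]
      for c in range(3)] for n in range(2)]
wt = [[[[(3 * f + c - p + q) % 5 for q in range(3)] for p in range(2)]
       for c in range(3)] for f in range(4)]
assert conv2d_implicit_gemm(x, wt) == conv2d_nchw(x, wt)
```

Keeping F as the outer GEMM dimension is what makes the result land directly in NCHW order. Other orderings of the K and M dimensions are equally valid; this one was chosen only for readability.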

Proof of Concept

The following branch is a proof of concept for a way to organize convolution and matmul codegen for GPU targets based partially on 4) but with changes to fix some of the cons listed above: https://github.com/qedawkins/iree/tree/igemm. An accompanying script can be found here: https://gist.github.com/qedawkins/ee0ca928634b5533b591ce804fa5e080

The experimental strategy here also tiles the producers of the matmul on their own (in this case, manually introduced copies). However, instead of waiting until after bufferization and loop distribution to "fix up" the iterator type of the reduction loop, it fuses all of the parallel loops within the body of the scf.for into one. That single loop can then be hoisted out of the scf.for, and consumers can be fused into it.

#map = affine_map<(d0) -> (d0 * 2)>
#map1 = affine_map<(d0) -> (d0 * 4)>
#map2 = affine_map<(d0)[s0] -> (d0 * 4 + s0)>
#map3 = affine_map<(d0)[s0] -> (d0 * 2 + s0)>
#map4 = affine_map<(d0) -> (d0 * 16)>
module {
  func.func @main(%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>) -> tensor<128x128xf32> {
    %c0 = arith.constant 0 : index
    %c128 = arith.constant 128 : index
    %c4 = arith.constant 4 : index
    %0 = tensor.empty() : tensor<128x4xf32>
    %1 = tensor.empty() : tensor<4x128xf32>
    %2 = scf.for %arg3 = %c0 to %c128 step %c4 iter_args(%arg4 = %arg2) -> (tensor<128x128xf32>) {
      %3 = scf.forall (%arg5, %arg6) in (64, 1) shared_outs(%arg7 = %0) -> (tensor<128x4xf32>) {
        %6 = affine.apply #map(%arg5)
        %7 = affine.apply #map1(%arg6)
        %8 = affine.apply #map(%arg5)
        %9 = affine.apply #map2(%arg6)[%arg3]
        %extracted_slice = tensor.extract_slice %arg0[%8, %9] [2, 4] [1, 1] : tensor<128x128xf32> to tensor<2x4xf32>
        %extracted_slice_0 = tensor.extract_slice %arg7[%6, %7] [2, 4] [1, 1] : tensor<128x4xf32> to tensor<2x4xf32>
        %10 = linalg.copy ins(%extracted_slice : tensor<2x4xf32>) outs(%extracted_slice_0 : tensor<2x4xf32>) -> tensor<2x4xf32>
        scf.forall.in_parallel {
          tensor.parallel_insert_slice %10 into %arg7[%6, %7] [2, 4] [1, 1] : tensor<2x4xf32> into tensor<128x4xf32>
        }
      } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
      %4 = scf.forall (%arg5, %arg6) in (2, 32) shared_outs(%arg7 = %1) -> (tensor<4x128xf32>) {
        %6 = affine.apply #map(%arg5)
        %7 = affine.apply #map1(%arg6)
        %8 = affine.apply #map3(%arg5)[%arg3]
        %9 = affine.apply #map1(%arg6)
        %extracted_slice = tensor.extract_slice %arg1[%8, %9] [2, 4] [1, 1] : tensor<128x128xf32> to tensor<2x4xf32>
        %extracted_slice_0 = tensor.extract_slice %arg7[%6, %7] [2, 4] [1, 1] : tensor<4x128xf32> to tensor<2x4xf32>
        %10 = linalg.copy ins(%extracted_slice : tensor<2x4xf32>) outs(%extracted_slice_0 : tensor<2x4xf32>) -> tensor<2x4xf32>
        scf.forall.in_parallel {
          tensor.parallel_insert_slice %10 into %arg7[%6, %7] [2, 4] [1, 1] : tensor<2x4xf32> into tensor<4x128xf32>
        }
      } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
      %5 = scf.forall (%arg5, %arg6) in (8, 8) shared_outs(%arg7 = %arg4) -> (tensor<128x128xf32>) {
        %6 = affine.apply #map4(%arg5)
        %7 = affine.apply #map4(%arg6)
        %extracted_slice = tensor.extract_slice %3[%6, 0] [16, 4] [1, 1] : tensor<128x4xf32> to tensor<16x4xf32>
        %extracted_slice_0 = tensor.extract_slice %4[0, %7] [4, 16] [1, 1] : tensor<4x128xf32> to tensor<4x16xf32>
        %extracted_slice_1 = tensor.extract_slice %arg7[%6, %7] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
        %8 = linalg.matmul ins(%extracted_slice, %extracted_slice_0 : tensor<16x4xf32>, tensor<4x16xf32>) outs(%extracted_slice_1 : tensor<16x16xf32>) -> tensor<16x16xf32>
        scf.forall.in_parallel {
          tensor.parallel_insert_slice %8 into %arg7[%6, %7] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<128x128xf32>
        }
      } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
      scf.yield %5 : tensor<128x128xf32>
    }
    return %2 : tensor<128x128xf32>
  }
}
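Each of the three scf.forall ops above distributes a different shape over the same 64 threads: (64, 1), (2, 32), and (8, 8). The quick Python check below is not part of the branch. It uses the destination offsets from #map, #map1 and #map4 and confirms that each grid writes its destination tile exactly once. The K offset %arg3 only shifts the source slices, so it is left out. Because all three loops use the same thread count, merging them into one 8x8 grid is straightforward in this example. `covers_exactly_once` and `LOOPS` are hypothetical names.

```python
def covers_exactly_once(grid, thread_tile, dest_shape, offset_fn):
    """Return True if a 2-D scf.forall-style grid, where thread (y, x) writes a
    thread_tile-sized slice at offset_fn(y, x), writes every element of
    dest_shape exactly once."""
    hits = [[0] * dest_shape[1] for _ in range(dest_shape[0])]
    for y in range(grid[0]):
        for x in range(grid[1]):
            r0, c0 = offset_fn(y, x)
            for r in range(r0, r0 + thread_tile[0]):
                for c in range(c0, c0 + thread_tile[1]):
                    hits[r][c] += 1
    return all(h == 1 for row in hits for h in row)

# (thread grid, per-thread slice, destination shape, offset map) for each forall,
# taken from the parallel_insert_slice offsets in the IR above.
LOOPS = {
    "lhs_copy": ((64, 1), (2, 4), (128, 4), lambda y, x: (2 * y, 4 * x)),
    "rhs_copy": ((2, 32), (2, 4), (4, 128), lambda y, x: (2 * y, 4 * x)),
    "matmul": ((8, 8), (16, 16), (128, 128), lambda y, x: (16 * y, 16 * x)),
}

thread_counts = {name: g[0] * g[1] for name, (g, _, _, _) in LOOPS.items()}
exact_cover = {name: covers_exactly_once(*spec) for name, spec in LOOPS.items()}
print(thread_counts)  # {'lhs_copy': 64, 'rhs_copy': 64, 'matmul': 64}
print(exact_cover)    # {'lhs_copy': True, 'rhs_copy': True, 'matmul': True}
```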

After fusion of parallel loops

#map = affine_map<(d0) -> (d0 * 16)>
#map1 = affine_map<(d0, d1) -> (d0 * 8 + d1)>
#map2 = affine_map<(d0) -> (d0 * 2)>
#map3 = affine_map<(d0) -> (d0 * 4)>
#map4 = affine_map<(d0)[s0] -> (d0 * 4 + s0)>
#map5 = affine_map<(d0)[s0] -> (d0 * 2 + s0)>
module {
  func.func @main(%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>) -> tensor<128x128xf32> {
    %c32 = arith.constant 32 : index
    %c2 = arith.constant 2 : index
    %c1 = arith.constant 1 : index
    %c64 = arith.constant 64 : index
    %c0 = arith.constant 0 : index
    %c128 = arith.constant 128 : index
    %c4 = arith.constant 4 : index
    %0 = tensor.empty() : tensor<128x4xf32>
    %1 = tensor.empty() : tensor<4x128xf32>
    %2 = scf.for %arg3 = %c0 to %c128 step %c4 iter_args(%arg4 = %arg2) -> (tensor<128x128xf32>) {
      %3 = scf.forall (%arg5, %arg6) in (8, 8) shared_outs(%arg7 = %arg4) -> (tensor<128x128xf32>) {
        %4 = affine.apply #map(%arg5)
        %5 = affine.apply #map(%arg6)
        %6 = affine.apply #map1(%arg5, %arg6)
        %7:2 = affine.delinearize_index %6 into (%c64, %c1) : index, index
        %8 = affine.apply #map2(%7#0)
        %9 = affine.apply #map3(%7#1)
        %10 = affine.apply #map4(%7#1)[%arg3]
        %extracted_slice = tensor.extract_slice %arg0[%8, %10] [2, 4] [1, 1] : tensor<128x128xf32> to tensor<2x4xf32>
        %extracted_slice_0 = tensor.extract_slice %0[%8, %9] [2, 4] [1, 1] : tensor<128x4xf32> to tensor<2x4xf32>
        %11 = linalg.copy ins(%extracted_slice : tensor<2x4xf32>) outs(%extracted_slice_0 : tensor<2x4xf32>) -> tensor<2x4xf32>
        %12 = iree_gpu.shuffle_tensor %11[%8, %9] [2, 4] [1, 1] to %0 [%4, 0] [16, 4] [1, 1] : tensor<2x4xf32> -> tensor<128x4xf32> -> tensor<16x4xf32>
        %13:2 = affine.delinearize_index %6 into (%c2, %c32) : index, index
        %14 = affine.apply #map2(%13#0)
        %15 = affine.apply #map3(%13#1)
        %16 = affine.apply #map5(%13#0)[%arg3]
        %extracted_slice_1 = tensor.extract_slice %arg1[%16, %15] [2, 4] [1, 1] : tensor<128x128xf32> to tensor<2x4xf32>
        %extracted_slice_2 = tensor.extract_slice %1[%14, %15] [2, 4] [1, 1] : tensor<4x128xf32> to tensor<2x4xf32>
        %17 = linalg.copy ins(%extracted_slice_1 : tensor<2x4xf32>) outs(%extracted_slice_2 : tensor<2x4xf32>) -> tensor<2x4xf32>
        %18 = iree_gpu.shuffle_tensor %17[%14, %15] [2, 4] [1, 1] to %1 [0, %5] [4, 16] [1, 1] : tensor<2x4xf32> -> tensor<4x128xf32> -> tensor<4x16xf32>
        %extracted_slice_3 = tensor.extract_slice %arg7[%4, %5] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
        %19 = linalg.matmul ins(%12, %18 : tensor<16x4xf32>, tensor<4x16xf32>) outs(%extracted_slice_3 : tensor<16x16xf32>) -> tensor<16x16xf32>
        scf.forall.in_parallel {
          tensor.parallel_insert_slice %19 into %arg7[%4, %5] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<128x128xf32>
        }
      } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
      scf.yield %3 : tensor<128x128xf32>
    }
    return %2 : tensor<128x128xf32>
  }
}
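The only new index arithmetic in the fused loop is the linearize/delinearize pair. Each thread of the 8x8 grid computes a linear id (#map1) and re-splits it into the shape of each former copy loop. The small Python model below reproduces that remapping. `delinearize` is an assumed model of what affine.delinearize_index computes for a static basis (row-major, last entry varying fastest). `fused_thread_offsets` is a hypothetical helper.

```python
def delinearize(index, basis):
    """Row-major delinearization (last basis entry varies fastest); an assumed
    model of affine.delinearize_index with a static basis."""
    coords = []
    for b in reversed(basis):
        coords.append(index % b)
        index //= b
    return tuple(reversed(coords))

def fused_thread_offsets(y, x):
    """Offsets used by thread (y, x) of the fused 8x8 forall."""
    linear = y * 8 + x                      # #map1: d0 * 8 + d1
    a0, a1 = delinearize(linear, (64, 1))   # shape of the former lhs-copy grid
    b0, b1 = delinearize(linear, (2, 32))   # shape of the former rhs-copy grid
    return {
        "lhs_copy": (2 * a0, 4 * a1),  # #map2 / #map3 applied to %7#0, %7#1
        "rhs_copy": (2 * b0, 4 * b1),  # #map2 / #map3 applied to %13#0, %13#1
        "matmul": (16 * y, 16 * x),    # #map
    }

threads = [fused_thread_offsets(y, x) for y in range(8) for x in range(8)]
lhs = {t["lhs_copy"] for t in threads}
rhs = {t["rhs_copy"] for t in threads}
# Every thread gets a distinct copy slice, and together the slices reproduce
# exactly the offsets of the original (64, 1) and (2, 32) loops.
assert lhs == {(2 * i, 0) for i in range(64)}
assert rhs == {(2 * i, 4 * j) for i in range(2) for j in range(32)}
print(delinearize(37, (2, 32)))  # (1, 5)
```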

The current rough outline for the pipeline is as follows:

  1. Tile + fuse to workgroups
  2. Tile the reduction loop of the matmul and greedily fuse producers of the inputs.
    • There is an argument for tiling to threads first; however, in practice this does not compose well with trying to distribute producers. When tiling fused producers of the matmul, there is an implicit requirement that the distributed producer does not break up the K tile of the matmul. Put another way, the data the producer has to produce is a tile for a single iteration of the matmul loop. Tiling the reduction dimension of the matmul first avoids implicit links like that.
  3. Tile the inner matmul and fused producers to threads. If there is no fused producer for the matmul, we insert a copy to root the tiling on.
    • There are two possible ways to handle special matrix intrinsics at this point. For something like cooperative matrix on SPIR-V, there is no choice but to tile the matmul only to warps. For something like MFMA, WMMA, or mma.sync in LLVMGPU, we could consider adding special operations to represent them before vectorization, or use some kind of tensor encoding to abstract out the specifics of the intrinsic while tiling. The other path is to keep using vector distribution for mapping to intrinsics, but this mixing of warp and thread semantics might be difficult to manage.
  4. Fuse the parallel loops of the matmul and producers into one.
    • This will likely require a step to map the trip counts of the parallel loops to a fixed common iteration count once unaligned and dynamic cases are considered. This would include handling cases like over- and under-provisioning the loop (forming an scf.for inside the body of the loop).
    • This fusion currently inserts a special op to manage the former boundary between the loops. The simplest lowering of the op is to an allocation, two linalg.copy ops that copy each worker's slices into and out of the shared allocation, and a barrier to synchronize the workers. However, the lack of support for mixed buffer-tensor semantics makes this unwieldy today. In the future we will likely want to add different lowerings for this "tensor shuffle" to allow generating something like a wait token for mapping to asynchronous copies, or to simplify the multibuffering + pipelining transformations.
  5. Hoist the inner parallel loop out of the scf.for and greedily fuse all consumers of the loop.
  6. Multibuffer + pipeline (first level of scheduling)
  7. Vectorize
  8. Bufferize
  9. Lower the scf.forall ops and late stage target specific lowerings.
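For step 4, when the trip counts of the loops being fused do not match the common thread count, one simple scheme is a cyclic distribution. This is an assumption for illustration, not necessarily what the branch does. Each thread visits iterations tid, tid + T, tid + 2T, and so on. When there are more iterations than threads, this becomes the inner scf.for mentioned above. When there are fewer, the extra threads get an empty range, which acts as a guard. `distribute` and `check_exact_cover` are hypothetical helpers.

```python
def distribute(num_iters, num_threads):
    """Cyclically assign the iterations of a parallel loop with num_iters trips
    to a fixed pool of num_threads workers. Each thread's list is what an inner
    serial loop (an scf.for with step num_threads) would visit; it is empty for
    threads that have no work."""
    return [list(range(tid, num_iters, num_threads)) for tid in range(num_threads)]

def check_exact_cover(num_iters, num_threads):
    """Every iteration is executed by exactly one thread."""
    seen = sorted(i for per_thread in distribute(num_iters, num_threads) for i in per_thread)
    return seen == list(range(num_iters))

# Fewer iterations than threads: the trailing threads idle.
print(distribute(3, 4))   # [[0], [1], [2], []]
# More iterations than threads: each thread loops over several iterations.
print(distribute(10, 4))  # [[0, 4, 8], [1, 5, 9], [2, 6], [3, 7]]
```

A blocked distribution (contiguous chunks per thread) would also be valid. The cyclic form is shown because it keeps neighbouring threads on neighbouring iterations, which usually matters for coalesced global loads.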

With this approach, the only difference with convolution is the presence of an im2col operation introduced between steps 1 and 2. The im2col operation will just need to implement the tiling interface and have a way to decompose to linalg/vector ops once tiled to threads.
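What tiling needs from im2col, roughly, is the ability to compute any tile of its output directly from the input without materializing the whole matrix. That is the tile one iteration of the K loop (step 2) would pull in. The pure-Python sketch below (stride 1, no padding, single NCHW image) illustrates that property. It is not IREE's im2col op or its TilingInterface implementation, and `im2col_full` / `im2col_tile` are hypothetical names.

```python
from itertools import product

def im2col_full(x_n, kh_, kw_):
    """Reference: materialize the whole K x M im2col matrix for one NCHW image
    x_n: [C][H][W] (stride 1, no padding); K = C*KH*KW, M = OH*OW."""
    c_, h_, w_ = len(x_n), len(x_n[0]), len(x_n[0][0])
    oh, ow = h_ - kh_ + 1, w_ - kw_ + 1
    return [[x_n[c][i + p][j + q] for i, j in product(range(oh), range(ow))]
            for c, p, q in product(range(c_), range(kh_), range(kw_))]

def im2col_tile(x_n, kh_, kw_, k_off, m_off, k_size, m_size):
    """Compute only im2col[k_off:k_off+k_size, m_off:m_off+m_size], reading
    the input directly instead of slicing a materialized matrix."""
    ow = len(x_n[0][0]) - kw_ + 1
    tile = []
    for k in range(k_off, k_off + k_size):
        c, rem = divmod(k, kh_ * kw_)   # k = c*KH*KW + p*KW + q
        p, q = divmod(rem, kw_)
        row = []
        for m in range(m_off, m_off + m_size):
            i, j = divmod(m, ow)        # m = i*OW + j
            row.append(x_n[c][i + p][j + q])
        tile.append(row)
    return tile

x_n = [[[(7 * c + 3 * h + w) % 11 for w in range(6)] for h in range(5)] for c in range(2)]
full = im2col_full(x_n, 3, 2)  # K = 2*3*2 = 12, M = 3*5 = 15
# One K step's worth of rows (4) for a block of 6 output pixels:
tile = im2col_tile(x_n, 3, 2, k_off=4, m_off=5, k_size=4, m_size=6)
assert tile == [row[5:11] for row in full[4:8]]
```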

Tasks

A list of rough tasks, some of which are already done in the branch above or have been started elsewhere.

Shared

Convolution

qedawkins commented 2 months ago

I updated the branch + script to include hoisting an scf.forall out of a loop.

cc @MaheshRavishankar

MaheshRavishankar commented 2 months ago

Thanks @qedawkins. Most of this makes sense. I'll make time tomorrow to walk through more if you have the time. The only part that is a bit sketchy for me is the "fuses parallel loops" step. If they are really fusable, you should be able to tile + fuse. For example, your end-state code looks very similar to what you get from tile and fuse. Not to say we never need "fusion of loops", but more that we haven't needed it so far...

qedawkins commented 2 months ago

Here is a dump-after-all with the current branch + spec above: https://gist.github.com/qedawkins/953b4e9da86ad48c94b978323f2b39ae. The key IR we're trying to get to is the following:

func.func @main() {
  %c32 = arith.constant 32 : index
  %c2 = arith.constant 2 : index
  %c1 = arith.constant 1 : index
  %c64 = arith.constant 64 : index
  %c4 = arith.constant 4 : index
  %c128 = arith.constant 128 : index
  %c0 = arith.constant 0 : index
  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<128x128xf32>>
  %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<128x128xf32>>
  %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readwrite:tensor<128x128xf32>>
  %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<128x128xf32>> -> tensor<128x128xf32>
  %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<128x128xf32>> -> tensor<128x128xf32>
  %5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : !flow.dispatch.tensor<readwrite:tensor<128x128xf32>> -> tensor<128x128xf32>
  %6 = tensor.empty() : tensor<128x4xf32>
  %7 = tensor.empty() : tensor<4x128xf32>
  %8 = scf.forall (%arg0, %arg1) in (8, 8) shared_outs(%arg2 = %5) -> (tensor<128x128xf32>) {
    %9 = affine.apply affine_map<(d0) -> (d0 * 16)>(%arg0)
    %10 = affine.apply affine_map<(d0) -> (d0 * 16)>(%arg1)
    %extracted_slice = tensor.extract_slice %arg2[%9, %10] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
    %11 = affine.apply affine_map<(d0, d1) -> (d0 * 8 + d1)>(%arg0, %arg1)
    %12:2 = affine.delinearize_index %11 into (%c64, %c1) : index, index
    %13 = affine.apply affine_map<(d0) -> (d0 * 2)>(%12#0)
    %14 = affine.apply affine_map<(d0) -> (d0 * 4)>(%12#1)
    %extracted_slice_0 = tensor.extract_slice %6[%13, %14] [2, 4] [1, 1] : tensor<128x4xf32> to tensor<2x4xf32>
    %15:2 = affine.delinearize_index %11 into (%c2, %c32) : index, index
    %16 = affine.apply affine_map<(d0) -> (d0 * 2)>(%15#0)
    %17 = affine.apply affine_map<(d0) -> (d0 * 4)>(%15#1)
    %extracted_slice_1 = tensor.extract_slice %7[%16, %17] [2, 4] [1, 1] : tensor<4x128xf32> to tensor<2x4xf32>
    %18 = scf.for %arg3 = %c0 to %c128 step %c4 iter_args(%arg4 = %extracted_slice) -> (tensor<16x16xf32>) {
      %19 = affine.apply affine_map<(d0)[s0] -> (d0 * 4 + s0)>(%12#1)[%arg3]
      %extracted_slice_2 = tensor.extract_slice %3[%13, %19] [2, 4] [1, 1] : tensor<128x128xf32> to tensor<2x4xf32>
      %20 = linalg.copy ins(%extracted_slice_2 : tensor<2x4xf32>) outs(%extracted_slice_0 : tensor<2x4xf32>) -> tensor<2x4xf32>
      %21 = iree_gpu.shuffle_tensor %20[%13, %14] [2, 4] [1, 1] to %6 [%9, 0] [16, 4] [1, 1] : tensor<2x4xf32> -> tensor<128x4xf32> -> tensor<16x4xf32>
      %22 = affine.apply affine_map<(d0)[s0] -> (d0 * 2 + s0)>(%15#0)[%arg3]
      %extracted_slice_3 = tensor.extract_slice %4[%22, %17] [2, 4] [1, 1] : tensor<128x128xf32> to tensor<2x4xf32>
      %23 = linalg.copy ins(%extracted_slice_3 : tensor<2x4xf32>) outs(%extracted_slice_1 : tensor<2x4xf32>) -> tensor<2x4xf32>
      %24 = iree_gpu.shuffle_tensor %23[%16, %17] [2, 4] [1, 1] to %7 [0, %10] [4, 16] [1, 1] : tensor<2x4xf32> -> tensor<4x128xf32> -> tensor<4x16xf32>
      %25 = linalg.matmul ins(%21, %24 : tensor<16x4xf32>, tensor<4x16xf32>) outs(%arg4 : tensor<16x16xf32>) -> tensor<16x16xf32>
      scf.yield %25 : tensor<16x16xf32>
    }    
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %18 into %arg2[%9, %10] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<128x128xf32>
    }    
  } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
  flow.dispatch.tensor.store %8, %2, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : tensor<128x128xf32> -> !flow.dispatch.tensor<readwrite:tensor<128x128xf32>>
  return
}
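To sanity-check the data movement in this IR, here is a sequential Python model of it, written for this discussion rather than taken from the branch. The 8x8 thread grid is outermost and the K loop runs inside each thread. iree_gpu.shuffle_tensor is modeled by the simplest lowering described in step 4: all threads write their 2x4 slices into a shared buffer, a barrier, then each thread reads back its 16x4 / 4x16 slice. `delinearize` is an assumed model of affine.delinearize_index, and `simulated_kernel` is a hypothetical name. The model allocates fresh shared buffers every iteration. That sidesteps the write-after-read hazard a real lowering reusing one allocation would need to handle with a second barrier or multibuffering.

```python
def delinearize(index, basis):
    """Row-major delinearization (last basis entry fastest); assumed model of
    affine.delinearize_index with a static basis."""
    coords = []
    for b in reversed(basis):
        coords.append(index % b)
        index //= b
    return tuple(reversed(coords))

M = N = K = 128
K_TILE = 4

def simulated_kernel(A, B, C):
    """Sequential model of the IR above: C + A @ B computed by an 8x8 grid of
    threads, each owning a 16x16 accumulator and running the K loop."""
    # Per-thread accumulators, initialized from C (the extract_slice of %arg2).
    acc = {(ty, tx): [row[16 * tx:16 * tx + 16] for row in C[16 * ty:16 * ty + 16]]
           for ty in range(8) for tx in range(8)}
    for k0 in range(0, K, K_TILE):
        # Fresh "shared memory" for this iteration (%6 and %7 in the IR).
        lhs_shared = [[None] * K_TILE for _ in range(M)]
        rhs_shared = [[None] * N for _ in range(K_TILE)]
        # Phase 1: every thread performs its two 2x4 linalg.copy ops.
        for ty in range(8):
            for tx in range(8):
                lin = ty * 8 + tx
                a0, a1 = delinearize(lin, (64, 1))
                b0, b1 = delinearize(lin, (2, 32))
                for r in range(2):
                    for c in range(4):
                        lhs_shared[2 * a0 + r][4 * a1 + c] = A[2 * a0 + r][k0 + 4 * a1 + c]
                        rhs_shared[2 * b0 + r][4 * b1 + c] = B[k0 + 2 * b0 + r][4 * b1 + c]
        # Barrier: every shared element must be written before anyone reads.
        assert all(v is not None for row in lhs_shared + rhs_shared for v in row)
        # Phase 2: each thread reads its 16x4 / 4x16 slices and accumulates.
        for (ty, tx), c_tile in acc.items():
            lhs = lhs_shared[16 * ty:16 * ty + 16]
            rhs = [row[16 * tx:16 * tx + 16] for row in rhs_shared]
            for i in range(16):
                for j in range(16):
                    c_tile[i][j] += sum(lhs[i][k] * rhs[k][j] for k in range(K_TILE))
    # scf.forall.in_parallel: insert each thread's 16x16 result into the output.
    out = [[0] * N for _ in range(M)]
    for (ty, tx), c_tile in acc.items():
        for i in range(16):
            out[16 * ty + i][16 * tx:16 * tx + 16] = c_tile[i]
    return out

A = [[(3 * i + j) % 7 - 3 for j in range(K)] for i in range(M)]
B = [[(i - 2 * j) % 5 - 2 for j in range(N)] for i in range(K)]
C = [[(i + j) % 3 for j in range(N)] for i in range(M)]
result = simulated_kernel(A, B, C)
```

Comparing `result` against a plain C + A @ B reference checks that the two copy distributions and the shuffled reads line up with the matmul's thread tiles.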
qedawkins commented 2 months ago

I updated the branch to enable generating MFMA ops as well with this spec: https://gist.github.com/qedawkins/334c6bce944c6b860066ca873e1388d2

I'm going to start landing some of the transform ops used in the above spec.