
[DT] Next steps for data tiling fusion improvements #18513

Open Max191 opened 2 months ago

Max191 commented 2 months ago

Overview

This issue tracks the next steps for enabling the data tiling fusion improvements, which are currently behind the --iree-dispatch-creation-experimental-data-tiling flag. Once the steps here are complete, there should be reasonable support for general grouped quantized matmul operations through data tiling.

The main steps are as follows:

  1. Allow SetEncoding to set encodings on contraction ops that have more than one batch, M, N, or K dimension.
  2. Support multi batch/M/N/K contractions in MaterializeEncoding.
  3. Create a lowering path for the new data tiled multi batch/M/N/K contractions.
  4. Add a new matching pattern to lower the new data tiled multi batch/M/N/K contractions to UKernel ops.

1. SetEncoding

This change should be very simple. SetEncoding should not behave any differently with multiple batch, M, N, or K dimensions; we just need to remove the restriction in the pass that limits encodings to a single batch/M/N/K dimension.
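
For illustration, here is a minimal sketch of the kind of single-dimension guard to drop, assuming the restriction can be phrased in terms of linalg::inferContractionDims; the helper name is made up and the actual check in SetEncoding.cpp may look different.

// Hypothetical helper (not existing IREE code) showing the kind of
// single-dimension restriction to remove in SetEncoding.
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"

static bool hasSingleBatchMNKDims(mlir::linalg::LinalgOp op) {
  mlir::FailureOr<mlir::linalg::ContractionDimensions> cDims =
      mlir::linalg::inferContractionDims(op);
  if (mlir::failed(cDims))
    return false;
  // Today encodings are only set when there is at most one dimension of each
  // kind; dropping this requirement is the change described above.
  return cDims->batch.size() <= 1 && cDims->m.size() == 1 &&
         cDims->n.size() == 1 && cDims->k.size() == 1;
}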

2. MaterializeEncoding

This is a more involved step, since materializing a contraction with multiple batch/M/N/K dimensions cannot create an MMT4D op. Instead, we need to create a linalg.generic op with the appropriate iterator_types and indexing_maps. There are useful functions here for checking whether a generic op is a contraction and for getting the mapping from its loop dimensions to batch, M, N, and K dimensions:

https://github.com/llvm/llvm-project/blob/030c6da7af826b641db005be925b20f956c3a6bb/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h#L50-L71

Here is an example of how this is used to find the narrow M and N sizes of a matmul:

https://github.com/iree-org/iree/blob/febe0edba1ea4b29135d23d64d84417b7a106be8/compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp#L56-L67
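
As a concrete (hypothetical) illustration of how that API exposes the mapping MaterializeEncoding needs when building the new linalg.generic, the helper below tags each loop of a contraction as a batch, M, N, or K dimension; the enum and function are not existing IREE code.

// Hypothetical helper (not existing IREE code): classify each loop of a
// contraction as batch/M/N/K using linalg::inferContractionDims. This is the
// dimension mapping needed to build the iterator_types and indexing_maps of
// the materialized linalg.generic.
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "llvm/ADT/DenseMap.h"

enum class ContractionDimKind { Batch, M, N, K };

static mlir::FailureOr<llvm::DenseMap<unsigned, ContractionDimKind>>
classifyContractionLoops(mlir::linalg::LinalgOp op) {
  mlir::FailureOr<mlir::linalg::ContractionDimensions> cDims =
      mlir::linalg::inferContractionDims(op);
  if (mlir::failed(cDims))
    return mlir::failure();
  llvm::DenseMap<unsigned, ContractionDimKind> kinds;
  for (unsigned d : cDims->batch) kinds[d] = ContractionDimKind::Batch;
  for (unsigned d : cDims->m) kinds[d] = ContractionDimKind::M;
  for (unsigned d : cDims->n) kinds[d] = ContractionDimKind::N;
  for (unsigned d : cDims->k) kinds[d] = ContractionDimKind::K;
  return kinds;
}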

I have written some IR below for what I would expect the final transformation to look like. Note that the only materialization pattern that should be needed is the one for the final contraction operation. The patterns for the set_encoding, unset_encoding, and dequantization ops should already be supported; I only included them in this example to make it clear where everything comes from.

module {
  func.func @grouped_quantized_matmul(%lhs: tensor<1024x32x128xf32>, %rhs_quant: tensor<11008x32x128xi4>,
                                      %scales: tensor<11008x32xf32>, %zps: tensor<11008x32xf32>) -> tensor<1024x11008xf32> {
    %cst = arith.constant 0.000000e+00 : f32
    %dq_init = tensor.empty() : tensor<11008x32x128xf32>
    %rhs = linalg.generic {
        indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                         affine_map<(d0, d1, d2) -> (d0, d1)>,
                         affine_map<(d0, d1, d2) -> (d0, d1)>,
                         affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
        iterator_types = ["parallel", "parallel", "parallel"]}
        ins(%rhs_quant, %scales, %zps : tensor<11008x32x128xi4>, tensor<11008x32xf32>, tensor<11008x32xf32>) outs(%dq_init : tensor<11008x32x128xf32>) {
    ^bb0(%in: i4, %in_0: f32, %in_1: f32, %out: f32):
      %5 = arith.extui %in : i4 to i32
      %6 = arith.uitofp %5 : i32 to f32
      %7 = arith.subf %6, %in_1 : f32
      %8 = arith.mulf %7, %in_0 : f32
      linalg.yield %8 : f32
    } -> tensor<11008x32x128xf32>
    %empty = tensor.empty() : tensor<1024x11008xf32>
    %mm_init = linalg.fill ins(%cst : f32) outs(%empty : tensor<1024x11008xf32>) -> tensor<1024x11008xf32>
    %mm = linalg.generic {
        indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
                         affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>,
                         affine_map<(d0, d1, d2, d3) -> (d0, d1)>],
        iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
        ins(%lhs, %rhs : tensor<1024x32x128xf32>, tensor<11008x32x128xf32>) outs(%mm_init : tensor<1024x11008xf32>) {
    ^bb0(%in: f32, %in_0: f32, %out: f32):
      %5 = arith.mulf %in, %in_0 : f32
      %6 = arith.addf %5, %out : f32
      linalg.yield %6 : f32
    } -> tensor<1024x11008xf32>
    return %mm : tensor<1024x11008xf32>
  }

  // ======================================================================== //
  // ======= @grouped_quantized_matmul after materializing encodings: ======= //
  // ======================================================================== //

  func.func @packed_grouped_quantized_matmul(%lhs: tensor<1024x32x128xf32>, %rhs_quant: tensor<11008x32x128xi4>,
                                      %scales: tensor<11008x32xf32>, %zps: tensor<11008x32xf32>) -> tensor<1024x11008xf32> {
    %cst = arith.constant 0.000000e+00 : f32
    %cst_i4 = arith.constant 0 : i4
    // LHS pack ( [M x K0 x K1] -> [M x K0 x K1 x m x k1] )
    %pack_lhs_init = tensor.empty() : tensor<256x32x8x4x16xf32>
    %pack_lhs = tensor.pack %lhs padding_value(%cst : f32)
                                 outer_dims_perm = [0, 1, 2]
                                 inner_dims_pos = [0, 2]
                                 inner_tiles = [4, 16] into %pack_lhs_init
        : tensor<1024x32x128xf32> -> tensor<256x32x8x4x16xf32>
    // Scales pack ( [N x K0] -> [N x K0 x n] )
    %pack_scales_init = tensor.empty() : tensor<1376x32x8xf32>
    %pack_scales = tensor.pack %scales padding_value(%cst : f32)
                                 outer_dims_perm = [0, 1]
                                 inner_dims_pos = [0]
                                 inner_tiles = [8] into %pack_scales_init
        : tensor<11008x32xf32> -> tensor<1376x32x8xf32>
    // Zero points pack ( [N x K0] -> [N x K0 x n] )
    %pack_zps_init = tensor.empty() : tensor<1376x32x8xf32>
    %pack_zps = tensor.pack %zps padding_value(%cst : f32)
                                 outer_dims_perm = [0, 1]
                                 inner_dims_pos = [0]
                                 inner_tiles = [8] into %pack_zps_init
        : tensor<11008x32xf32> -> tensor<1376x32x8xf32>
    // RHS quant pack ( [N x K0 x K1] -> [N x K0 x K1 x n x k1] )
    %pack_rhs_init = tensor.empty() : tensor<1376x32x8x8x16xi4>
    %pack_rhs_quant = tensor.pack %rhs_quant padding_value(%cst_i4 : i4)
                                 outer_dims_perm = [0, 1, 2]
                                 inner_dims_pos = [0, 2]
                                 inner_tiles = [8, 16] into %pack_rhs_init
        : tensor<11008x32x128xi4> -> tensor<1376x32x8x8x16xi4>
    %dq_init = tensor.empty() : tensor<1376x32x8x8x16xf32>
    %pack_rhs = linalg.generic {
        indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>,
                         affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>,
                         affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>,
                         affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>],
        iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
        ins(%pack_rhs_quant, %pack_scales, %pack_zps : tensor<1376x32x8x8x16xi4>, tensor<1376x32x8xf32>, tensor<1376x32x8xf32>)
        outs(%dq_init : tensor<1376x32x8x8x16xf32>) {
    ^bb0(%in: i4, %in_0: f32, %in_1: f32, %out: f32):
      %5 = arith.extui %in : i4 to i32
      %6 = arith.uitofp %5 : i32 to f32
      %7 = arith.subf %6, %in_1 : f32
      %8 = arith.mulf %7, %in_0 : f32
      linalg.yield %8 : f32
    } -> tensor<1376x32x8x8x16xf32>
    // Matmul init ( [M x N x m x n] )
    %empty = tensor.empty() : tensor<256x1376x4x8xf32>
    %mm_init = linalg.fill ins(%cst : f32) outs(%empty : tensor<256x1376x4x8xf32>) -> tensor<256x1376x4x8xf32>
    %mm = linalg.generic {
        indexing_maps = [affine_map<(M, N, K0, K1, m, n, k1) -> (M, K0, K1, m, k1)>,
                         affine_map<(M, N, K0, K1, m, n, k1) -> (N, K0, K1, n, k1)>,
                         affine_map<(M, N, K0, K1, m, n, k1) -> (M, N, m, n)>],
        iterator_types = ["parallel", "parallel", "reduction", "reduction", "parallel", "parallel", "reduction"]}
        ins(%pack_lhs, %pack_rhs : tensor<256x32x8x4x16xf32>, tensor<1376x32x8x8x16xf32>) outs(%mm_init : tensor<256x1376x4x8xf32>) {
    ^bb0(%in: f32, %in_0: f32, %out: f32):
      %5 = arith.mulf %in, %in_0 : f32
      %6 = arith.addf %5, %out : f32
      linalg.yield %6 : f32
    } -> tensor<256x1376x4x8xf32>
    %unpack_init = tensor.empty() : tensor<1024x11008xf32>
    %unpack = tensor.unpack %mm outer_dims_perm = [0, 1]
                                inner_dims_pos = [0, 1]
                                inner_tiles = [4, 8] into %unpack_init
        : tensor<256x1376x4x8xf32> -> tensor<1024x11008xf32>
    return %unpack : tensor<1024x11008xf32>
  }
}

3. Lowering Path

The next thing we need is a lowering path for the resulting data tiled matmul:

%mm = linalg.generic {
    indexing_maps = [affine_map<(M, N, K0, K1, m, n, k1) -> (M, K0, K1, m, k1)>,
                     affine_map<(M, N, K0, K1, m, n, k1) -> (N, K0, K1, n, k1)>,
                     affine_map<(M, N, K0, K1, m, n, k1) -> (M, N, m, n)>],
    iterator_types = ["parallel", "parallel", "reduction", "reduction", "parallel", "parallel", "reduction"]}
    ins(%pack_lhs, %pack_rhs : tensor<256x32x8x4x16xf32>, tensor<1376x32x8x8x16xf32>) outs(%mm_init : tensor<256x1376x4x8xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
  %5 = arith.mulf %in, %in_0 : f32
  %6 = arith.addf %5, %out : f32
  linalg.yield %6 : f32
} -> tensor<256x1376x4x8xf32>

This primarily means setting an appropriate lowering configuration for the op, but there will likely be issues to solve in whatever lowering pipeline this goes down. There are a few changes to be made for lowering configurations:

  1. Right now, generic op contractions get a default GenericOp configuration: https://github.com/iree-org/iree/blob/febe0edba1ea4b29135d23d64d84417b7a106be8/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp#L2245-L2250 Even when a contraction is written as a generic op, it should be handled as a contraction op. This means we need new logic for linalg ops on which linalg::isaContractionOpInterface returns true. Something similar is done in the GPU configurations: https://github.com/iree-org/iree/blob/febe0edba1ea4b29135d23d64d84417b7a106be8/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp#L854-L864 In general, I think contraction ops with multiple M, N, or K dimensions should have all outer M/N/K dimensions tiled to 1 as an initial strategy (see the sketch after this list).
  2. Once there is logic in place to set configurations based on contraction dimensions, we should refactor the existing matmul-like op config logic to use it. This is mostly a code clean-up task, so that everything lives in one place.
  3. The pipeline to use will be the new TileRootFuseConsumerProducer pipeline. There may be some issues along the way to solve as more dispatches are tested on the new pipeline.
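
A rough sketch of the initial tiling strategy from item 1 is below. It assumes the innermost m/n/k tile sizes are still filled in by the existing target-specific heuristics; the helper name and structure are made up for illustration and are not actual KernelDispatch.cpp code.

// Hypothetical sketch (not actual KernelDispatch.cpp code): tile all batch
// dimensions and all but the innermost M/N/K dimension of each kind to 1.
// The innermost m/n/k loops keep tile size 0 here and are expected to be
// filled in by the existing target-specific register-tile heuristics.
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"

static mlir::FailureOr<llvm::SmallVector<int64_t>>
getOuterUnitContractionTileSizes(mlir::linalg::LinalgOp op) {
  mlir::FailureOr<mlir::linalg::ContractionDimensions> cDims =
      mlir::linalg::inferContractionDims(op);
  if (mlir::failed(cDims))
    return mlir::failure();
  llvm::SmallVector<int64_t> tileSizes(op.getNumLoops(), 0);
  for (unsigned d : cDims->batch)
    tileSizes[d] = 1;
  for (llvm::ArrayRef<unsigned> dims :
       {llvm::ArrayRef<unsigned>(cDims->m), llvm::ArrayRef<unsigned>(cDims->n),
        llvm::ArrayRef<unsigned>(cDims->k)}) {
    if (dims.empty())
      continue;
    for (unsigned d : dims.drop_back())  // all but the innermost of each kind
      tileSizes[d] = 1;
  }
  return tileSizes;
}

For the packed matmul above, this would tile M, N, K0, and K1 to 1 and leave the inner m, n, and k1 tile sizes to the target heuristics.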

4. Lower to UKernels

UKernels are rooted on MMT4D ops right now, and we don't want to lose the benefits of ukernels when we materialize generic ops instead of mmt4d ops. The CPULowerToUKernels pass converts a set of named operations into UKernel ops based on matching patterns called here: https://github.com/iree-org/iree/blob/febe0edba1ea4b29135d23d64d84417b7a106be8/compiler/src/iree/compiler/Codegen/Common/CPU/CPULowerToUKernels.cpp#L584-L585 A new matchDAGForUKernel matcher will be needed for linalg.generic ops. The new matcher will need to check the shapes of the generic op inputs against the supported ukernel shapes, since an arbitrary generic op is not guaranteed to match the shape of a ukernel. It would be good to get @bjacob's input on this, since matching against all possible ukernel shapes seems a little cumbersome.

EDIT: See Benoit's comment below. There should be no need to match against specific shapes for now. The worst case is bad performance, which can be looked into later.

Another detail is that we will most likely want to do some unit dim folding after tiling to get rid of the outer unit dimensions on the contractions. This will make matching ukernels much simpler.
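
For reference, a minimal sketch of how that folding could be wired up with the upstream Linalg unit-dim folding patterns is below; the pass placement, options, and helper name are assumptions rather than a description of an existing IREE pass.

// Sketch only: fold away unit dims with the upstream Linalg patterns so the
// ukernel matcher sees the contraction without outer unit loops.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static mlir::LogicalResult foldUnitDimsAfterTiling(mlir::Operation *funcOp) {
  mlir::RewritePatternSet patterns(funcOp->getContext());
  mlir::linalg::ControlDropUnitDims options;
  mlir::linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
  // Greedily apply the folding patterns to the tiled function body.
  return mlir::applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}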

Final Comments

The steps in the overview are listed in a suggested order, but that is not a strict dependency order. In particular, steps 2 and 3 could be done in either order, since I have already provided example IR after materialization; doing 3 first could provide some lessons about the exact form we want 2 to produce in general. Although 3 depends on 2 for e2e enablement, they could be worked on in parallel.

bjacob commented 2 months ago

Regarding matching generic to mmt4d ukernels:

This can be viewed as 2 separate steps:

  1. Match linalg.generic to linalg.mmt4d. This is the new code that needs to be written. It can be a separate pass or pattern, independent of ukernels.
  2. From there, the existing CPULowerToUKernels pass can run, unmodified.

So all you need to do here is step 1.
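
A minimal sketch of what that step 1 could look like as a rewrite pattern is below, assuming the standard mmt4d indexing maps. The pattern name is made up and the legality checks are intentionally minimal (element-type and body-cast checks are omitted), so treat it as an illustration rather than the actual implementation.

// Sketch of step 1 (not existing IREE code): raise a linalg.generic with
// mmt4d semantics back to the named linalg.mmt4d op, which the existing
// CPULowerToUKernels pass already understands.
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/PatternMatch.h"

struct RaiseGenericToMmt4D final
    : public mlir::OpRewritePattern<mlir::linalg::GenericOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::linalg::GenericOp op,
                  mlir::PatternRewriter &rewriter) const override {
    auto linalgOp = mlir::cast<mlir::linalg::LinalgOp>(op.getOperation());
    if (!mlir::linalg::isaContractionOpInterface(linalgOp))
      return rewriter.notifyMatchFailure(op, "not a contraction");
    // Expected mmt4d iteration space: (M, N, K, m, n, k) = (d0..d5).
    mlir::MLIRContext *ctx = op.getContext();
    auto dim = [&](unsigned i) { return mlir::getAffineDimExpr(i, ctx); };
    auto map = [&](llvm::ArrayRef<mlir::AffineExpr> exprs) {
      return mlir::AffineMap::get(/*dimCount=*/6, /*symbolCount=*/0, exprs, ctx);
    };
    llvm::SmallVector<mlir::AffineMap> expected = {
        map({dim(0), dim(2), dim(3), dim(5)}),   // LHS:  [M, K, m, k]
        map({dim(1), dim(2), dim(4), dim(5)}),   // RHS:  [N, K, n, k]
        map({dim(0), dim(1), dim(3), dim(4)})};  // Init: [M, N, m, n]
    if (op.getIndexingMapsArray() != expected)
      return rewriter.notifyMatchFailure(op, "maps do not match mmt4d");
    // With exactly these maps on a valid contraction, d2/d5 are the
    // reductions, so the op can be rebuilt as the named mmt4d op.
    // NOTE: element-type / extension-op checks are omitted for brevity.
    rewriter.replaceOpWithNewOp<mlir::linalg::Mmt4DOp>(
        op, op->getResultTypes(), op.getDpsInputs(), op.getDpsInits());
    return mlir::success();
  }
};

With a pattern like this in place, the existing CPULowerToUKernels matching on linalg.mmt4d can run unmodified, as described in step 2.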

There is a caveat about what "support" means. The mmt4d ukernel supports arbitrary linalg.mmt4d op shapes, for the supported combinations of element types, in the sense that the lowering always succeeds and there is always an actual ukernel to handle any case. The caveat is about performance: that wide support is achieved thanks to a slow fallback (which is slower than any generic codegen), so there is a performance cliff that one can silently fall off of.

The worst-case scenario, then, is that implementing the linalg.generic to linalg.mmt4d rewrite causes a workload to hit the slow ukernel path. However, that problem already exists for workloads that use linalg.mmt4d ops. The reason we haven't prioritized solving it is a combination of two things: it hasn't been a big deal (because we have added ukernel fast paths as needed), and it isn't trivial to fix because it requires enlarging the contact surface between ukernels and the compiler. If we are willing to do that, then the solution is a simple enough PR: https://github.com/iree-org/iree/pull/16880 . GitHub shows an outdated diff base for that PR, but most of that diff is already in main since https://github.com/iree-org/iree/pull/16879 was merged: we have already structured the ukernels code around a declarative table enumerating the fast code paths, so what is left in that PR is a simple matter of exposing that table to the compiler.

Max191 commented 2 months ago

There is a caveat about what "support" means. The mmt4d ukernel supports arbitrary linalg.mmt4d op shapes, for the supported combinations of element types, in the sense that the lowering always succeeds and there is always an actual ukernel to handle any case. The caveat is about performance: that wide support is achieved thanks to a slow fallback (which is slower than any generic codegen), so there is a performance cliff that one can silently fall off of.

I see, thanks for the context. I think it makes sense to start by unconditionally converting linalg.generic to mmt4d. It would be best to complicate the current next steps as little as possible. We can leave shape-specific matching as a TODO in case we hit any performance pitfalls in the wild.