iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

[DT][CPU] Improve codegen pipeline to handle mmt4d fusion #16025

Open hanhanW opened 10 months ago

hanhanW commented 10 months ago

We have opportunities to fuse the mmt4d op with its consumers, but there are performance issues. This happens when we simplify 1D pack/unpack ops to expand_shape/collapse_shape ops. Generally that simplification is good because they become metadata ops. This issue describes what happens in mmt4d fusion. Below is a snippet from the GPT-2 model.

func.func @mmt4d_fusion(%arg0: tensor<1x768x1x1xf32>, %arg1: tensor<192x768x16x1xf32>, %arg2: tensor<1x192x1x16xf32>) -> tensor<1x192x1x16xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<1x192x1x16xf32>
  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x192x1x16xf32>) -> tensor<1x192x1x16xf32>
  %2 = linalg.mmt4d ins(%arg0, %arg1 : tensor<1x768x1x1xf32>, tensor<192x768x16x1xf32>) outs(%1 : tensor<1x192x1x16xf32>) -> tensor<1x192x1x16xf32>
  %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%2, %arg2 : tensor<1x192x1x16xf32>, tensor<1x192x1x16xf32>) outs(%0 : tensor<1x192x1x16xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %4 = arith.addf %in, %in_0 : f32
    linalg.yield %4 : f32
  } -> tensor<1x192x1x16xf32>
  return %3 : tensor<1x192x1x16xf32>
}

There are a couple of issues in the current pipeline:

  1. The Mmt4dTilingExpert pipeline does not work well with the current TilingConfig. It only considers three tiling levels. We should make it more compatible with TilingConfig.
  2. The codegen assumes that only leading parallel dims are shared. This is not the case in the example because the result indexing_map (in the mmt4d op) is (d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4) while the indexing_map (in the generic op) is the identity (i.e., (d0, d1, d2, d3) -> (d0, d1, d2, d3)). This leads to a bug in multi lowering_config: the mapping is not taken into account in the method.
  3. LLVMCPUTileAndFuse should ignore the iree_codegen.ukernel.generic op. The iteration domain information is gone after converting a linalg op to a ukernel.generic op, so we are not able to tile and fuse the ukernel op.
  4. If we want to enable ukernels in fusion cases, we need to revisit when to convert the mmt4d op to a ukernel op. It can happen either after distribution or after the first level of TileAndFuse. The former could introduce a big stack buffer because the parallel dimensions are only tiled for distribution. The latter needs more investigation.
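
Item 2 can be illustrated with a small sketch. This is not IREE code: `project_tile_sizes` and the tile sizes are hypothetical, chosen only to show why the mmt4d result map has to be applied before tile sizes are handed to the consumer.

```python
# Hypothetical illustration, not IREE's implementation.
# linalg.mmt4d iterates over (d0..d5) = (M, N, K, M0, N0, K0) and its result
# indexing map is (d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4).
MMT4D_RESULT_DIMS = (0, 1, 3, 4)


def project_tile_sizes(producer_tiles, result_dims):
    """Map tile sizes from the producer's loops onto its result dimensions."""
    return [producer_tiles[d] for d in result_dims]


# Assumed tile sizes for the six mmt4d loops (0 means "do not tile").
mmt4d_tiles = [1, 2, 0, 1, 16, 0]

# Taking only the leading dims wrongly hands the K tile size to the consumer.
leading_only = mmt4d_tiles[:4]

# Applying the result map gives tile sizes in the consumer's iteration space,
# because the generic op uses the identity map over the 4-D result.
projected = project_tile_sizes(mmt4d_tiles, MMT4D_RESULT_DIMS)

print(leading_only)  # [1, 2, 0, 1]
print(projected)     # [1, 2, 1, 16]
```

With the leading-dims assumption the consumer would see a tile size of 0 on M0 and 1 on N0, which matches the scalar behavior described later in this thread.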

(The other option is to consider specialization: if there is fusion, we go with codegen; otherwise, we go with the ukernel path. I haven't explored this path, so no further comments.)

hanhanW commented 10 months ago

The performance issue comes from bad configurations on the generic op, which generate scalar code. One potential solution is to set the vector-level tile sizes to zeros when ukernels are enabled, and let setLoweringConfigForComputeOps set the vector tile sizes for generic ops.

MaheshRavishankar commented 10 months ago

Thanks Hanhan, I am digesting this slowly.

  2. The codegen assumes that only leading parallel dims are shared. This is not the case in the example because the result indexing_map (in mmt4d op) is (d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4) while the indexing_map (in generic op) is identical (i.e., (d0, d1, d2, d3) -> (d0, d1, d2, d3)). This leads to a bug in multi lowering_config. The mapping is not taken into account in the method.

Ok, this does need better matching of producer iteration space to consumer iteration space. This is a problem with multi lowering config... Something to just note.

  3. LLVMCPUTileAndFuse should ignore iree_codegen.ukernel.generic op. The iteration domain information is gone after converting a linalg op to ukernel.generic op. Thus, we are not able to tile and fuse the ukernel op.
  4. If we want to enable ukernel in fusion cases, we need to revisit when to convert the mmt4d op to ukernel op. It can either happen after distribution or first level of TileAndFuse. The former one could introduce a big stack buffer because the parallel dimensions are only tiled for distribution. The latter one needs more investigation.

I think ukernel lowering should happen after all tiling is done. It is basically on par with the vectorization stage.

hanhanW commented 10 months ago

With recent fixes, I'm able to codegen mmt4d fusion with and without ukernels. The result is still not right because the lowering_config is not set correctly on generic ops, so the consumer is lowered to scalar code:

    1606:   41 0f 18 8a 14 02 00    prefetcht0 BYTE PTR [r10+0x214]
    160d:   00 
    160e:   49 83 c1 60             add    r9,0x60
    1612:   49 83 c2 18             add    r10,0x18
    1616:   49 81 f9 00 60 00 00    cmp    r9,0x6000
    161d:   0f 85 2d ff ff ff       jne    1550 <iree_hal_executable_library_query-0x240>
    1623:   62 d1 7c 48 11 40 01    vmovups ZMMWORD PTR [r8+0x40],zmm0
    162a:   49 83 e8 80             sub    r8,0xffffffffffffff80
    162e:   48 81 c6 00 80 01 00    add    rsi,0x18000
    1635:   83 c7 02                add    edi,0x2
    1638:   81 ff c0 00 00 00       cmp    edi,0xc0
    163e:   0f 85 fc fd ff ff       jne    1440 <iree_hal_executable_library_query-0x350>
    1644:   ba 3c 00 00 00          mov    edx,0x3c
    1649:   0f 1f 80 00 00 00 00    nop    DWORD PTR [rax+0x0]
    1650:   c5 fa 10 44 14 c4       vmovss xmm0,DWORD PTR [rsp+rdx*1-0x3c]
    1656:   c5 fa 58 44 10 c4       vaddss xmm0,xmm0,DWORD PTR [rax+rdx*1-0x3c]
    165c:   c5 fa 11 44 11 c4       vmovss DWORD PTR [rcx+rdx*1-0x3c],xmm0
    1662:   c5 fa 10 44 14 c8       vmovss xmm0,DWORD PTR [rsp+rdx*1-0x38]
    1668:   c5 fa 58 44 10 c8       vaddss xmm0,xmm0,DWORD PTR [rax+rdx*1-0x38]
    166e:   c5 fa 11 44 11 c8       vmovss DWORD PTR [rcx+rdx*1-0x38],xmm0
    1674:   c5 fa 10 44 14 cc       vmovss xmm0,DWORD PTR [rsp+rdx*1-0x34]
    167a:   c5 fa 58 44 10 cc       vaddss xmm0,xmm0,DWORD PTR [rax+rdx*1-0x34]
    1680:   c5 fa 11 44 11 cc       vmovss DWORD PTR [rcx+rdx*1-0x34],xmm0
    1686:   c5 fa 10 44 14 d0       vmovss xmm0,DWORD PTR [rsp+rdx*1-0x30]
    168c:   c5 fa 58 44 10 d0       vaddss xmm0,xmm0,DWORD PTR [rax+rdx*1-0x30]
    1692:   c5 fa 11 44 11 d0       vmovss DWORD PTR [rcx+rdx*1-0x30],xmm0
    1698:   c5 fa 10 44 14 d4       vmovss xmm0,DWORD PTR [rsp+rdx*1-0x2c]
    169e:   c5 fa 58 44 10 d4       vaddss xmm0,xmm0,DWORD PTR [rax+rdx*1-0x2c]
    16a4:   c5 fa 11 44 11 d4       vmovss DWORD PTR [rcx+rdx*1-0x2c],xmm0
    16aa:   c5 fa 10 44 14 d8       vmovss xmm0,DWORD PTR [rsp+rdx*1-0x28]
    16b0:   c5 fa 58 44 10 d8       vaddss xmm0,xmm0,DWORD PTR [rax+rdx*1-0x28]
    16b6:   c5 fa 11 44 11 d8       vmovss DWORD PTR [rcx+rdx*1-0x28],xmm0
    16bc:   c5 fa 10 44 14 dc       vmovss xmm0,DWORD PTR [rsp+rdx*1-0x24]
    16c2:   c5 fa 58 44 10 dc       vaddss xmm0,xmm0,DWORD PTR [rax+rdx*1-0x24]
    16c8:   c5 fa 11 44 11 dc       vmovss DWORD PTR [rcx+rdx*1-0x24],xmm0
    16ce:   c5 fa 10 44 14 e0       vmovss xmm0,DWORD PTR [rsp+rdx*1-0x20]
    16d4:   c5 fa 58 44 10 e0       vaddss xmm0,xmm0,DWORD PTR [rax+rdx*1-0x20]
    16da:   c5 fa 11 44 11 e0       vmovss DWORD PTR [rcx+rdx*1-0x20],xmm0
    16e0:   c5 fa 10 44 14 e4       vmovss xmm0,DWORD PTR [rsp+rdx*1-0x1c]
    16e6:   c5 fa 58 44 10 e4       vaddss xmm0,xmm0,DWORD PTR [rax+rdx*1-0x1c]
    16ec:   c5 fa 11 44 11 e4       vmovss DWORD PTR [rcx+rdx*1-0x1c],xmm0
    16f2:   c5 fa 10 44 14 e8       vmovss xmm0,DWORD PTR [rsp+rdx*1-0x18]
    16f8:   c5 fa 58 44 10 e8       vaddss xmm0,xmm0,DWORD PTR [rax+rdx*1-0x18]
    16fe:   c5 fa 11 44 11 e8       vmovss DWORD PTR [rcx+rdx*1-0x18],xmm0
    1704:   c5 fa 10 44 14 ec       vmovss xmm0,DWORD PTR [rsp+rdx*1-0x14]
    170a:   c5 fa 58 44 10 ec       vaddss xmm0,xmm0,DWORD PTR [rax+rdx*1-0x14]
    1710:   c5 fa 11 44 11 ec       vmovss DWORD PTR [rcx+rdx*1-0x14],xmm0
    1716:   c5 fa 10 44 14 f0       vmovss xmm0,DWORD PTR [rsp+rdx*1-0x10]
    171c:   c5 fa 58 44 10 f0       vaddss xmm0,xmm0,DWORD PTR [rax+rdx*1-0x10]
    1722:   c5 fa 11 44 11 f0       vmovss DWORD PTR [rcx+rdx*1-0x10],xmm0
    1728:   c5 fa 10 44 14 f4       vmovss xmm0,DWORD PTR [rsp+rdx*1-0xc]
    172e:   c5 fa 58 44 10 f4       vaddss xmm0,xmm0,DWORD PTR [rax+rdx*1-0xc]
    1734:   c5 fa 11 44 11 f4       vmovss DWORD PTR [rcx+rdx*1-0xc],xmm0
    173a:   c5 fa 10 44 14 f8       vmovss xmm0,DWORD PTR [rsp+rdx*1-0x8]
    1740:   c5 fa 58 44 10 f8       vaddss xmm0,xmm0,DWORD PTR [rax+rdx*1-0x8]
    1746:   c5 fa 11 44 11 f8       vmovss DWORD PTR [rcx+rdx*1-0x8],xmm0
    174c:   c5 fa 10 44 14 fc       vmovss xmm0,DWORD PTR [rsp+rdx*1-0x4]
    1752:   c5 fa 58 44 10 fc       vaddss xmm0,xmm0,DWORD PTR [rax+rdx*1-0x4]
    1758:   c5 fa 11 44 11 fc       vmovss DWORD PTR [rcx+rdx*1-0x4],xmm0
    175e:   c5 fa 10 04 14          vmovss xmm0,DWORD PTR [rsp+rdx*1]
    1763:   c5 fa 58 04 10          vaddss xmm0,xmm0,DWORD PTR [rax+rdx*1]
    1768:   c5 fa 11 04 11          vmovss DWORD PTR [rcx+rdx*1],xmm0
    176d:   48 83 c2 40             add    rdx,0x40

Here is the IR after vectorization and bufferization:

func.func @mmt4d_fusion_dispatch_0_mmt4d_1x192x768x1x16x1_f32() {
  %c0 = arith.constant 0 : index
  %c192 = arith.constant 192 : index
  %c1025_i32 = arith.constant 1025 : i32
  %c1_i32 = arith.constant 1 : i32
  %c16_i32 = arith.constant 16 : i32
  %c768 = arith.constant 768 : index
  %c2 = arith.constant 2 : index
  %c1 = arith.constant 1 : index
  %c16 = arith.constant 16 : index
  %cst = arith.constant 0.000000e+00 : f32
  %alloca = memref.alloca() {alignment = 64 : i64} : memref<1x2x1x16xf32>
  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<1x768x1x1xf32, #hal.descriptor_type<storage_buffer>>
  memref.assume_alignment %0, 64 : memref<1x768x1x1xf32, #hal.descriptor_type<storage_buffer>>
  %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<192x768x16x1xf32, #hal.descriptor_type<storage_buffer>>
  memref.assume_alignment %1, 64 : memref<192x768x16x1xf32, #hal.descriptor_type<storage_buffer>>
  %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<1x192x1x16xf32, #hal.descriptor_type<storage_buffer>>
  memref.assume_alignment %2, 64 : memref<1x192x1x16xf32, #hal.descriptor_type<storage_buffer>>
  %3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) alignment(64) offset(%c0) : memref<1x192x1x16xf32, #hal.descriptor_type<storage_buffer>>
  memref.assume_alignment %3, 64 : memref<1x192x1x16xf32, #hal.descriptor_type<storage_buffer>>
  %workgroup_id_x = hal.interface.workgroup.id[0] : index
  %workgroup_count_x = hal.interface.workgroup.count[0] : index
  %4 = affine.apply affine_map<()[s0] -> (s0 * 2)>()[%workgroup_id_x]
  %5 = affine.apply affine_map<()[s0] -> (s0 * 2)>()[%workgroup_count_x]
  scf.for %arg0 = %4 to %c192 step %5 {
    %subview = memref.subview %3[0, %arg0, 0, 0] [1, 2, 1, 16] [1, 1, 1, 1] : memref<1x192x1x16xf32, #hal.descriptor_type<storage_buffer>> to memref<1x2x1x16xf32, strided<[3072, 16, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
    %subview_0 = memref.subview %1[%arg0, 0, 0, 0] [2, 768, 16, 1] [1, 1, 1, 1] : memref<192x768x16x1xf32, #hal.descriptor_type<storage_buffer>> to memref<2x768x16x1xf32, strided<[12288, 16, 1, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
    iree_codegen.ukernel.generic "iree_uk_mmt4d" ins(%0, %subview_0 : memref<1x768x1x1xf32, #hal.descriptor_type<storage_buffer>>, memref<2x768x16x1xf32, strided<[12288, 16, 1, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) outs(%alloca : memref<1x2x1x16xf32>) (%c1, %c2, %c768, %c1_i32, %c16_i32, %c1_i32, %c1025_i32 : index, index, index, i32, i32, i32, i32) fn_def_attrs {hal.import.bitcode = true, hal.import.cconv = 1 : i32, hal.import.fields = ["processor_data"]} strided_outer_dims(1)
    %subview_1 = memref.subview %2[0, %arg0, 0, 0] [1, 2, 1, 16] [1, 1, 1, 1] : memref<1x192x1x16xf32, #hal.descriptor_type<storage_buffer>> to memref<1x2x1x16xf32, strided<[3072, 16, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
    scf.for %arg1 = %c0 to %c2 step %c1 {
      scf.for %arg2 = %c0 to %c16 step %c1 {
        %6 = vector.transfer_read %alloca[%c0, %arg1, %c0, %arg2], %cst {in_bounds = [true, true, true, true]} : memref<1x2x1x16xf32>, vector<1x1x1x1xf32>
        %7 = vector.transfer_read %subview_1[%c0, %arg1, %c0, %arg2], %cst {in_bounds = [true, true, true, true]} : memref<1x2x1x16xf32, strided<[3072, 16, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x1x1x1xf32>
        %8 = arith.addf %6, %7 : vector<1x1x1x1xf32>
        vector.transfer_write %8, %subview[%c0, %arg1, %c0, %arg2] {in_bounds = [true, true, true, true]} : vector<1x1x1x1xf32>, memref<1x2x1x16xf32, strided<[3072, 16, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
      }
    }
  }
  return
}
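
The sizes in the IR above can be checked with a little arithmetic. The sketch below only restates numbers that appear in the IR (the alloca shape, the bound of 192, the distribution tile of 2, and the two inner loops); the "untiled" case at the end is a hypothetical comparison, not something the compiler produced.

```python
from math import prod

F32_BYTES = 4

# memref.alloca() : memref<1x2x1x16xf32> holds the ukernel result tile.
alloca_shape = (1, 2, 1, 16)
stack_bytes = prod(alloca_shape) * F32_BYTES

# The outer scf.for walks 192 outer-N tiles in steps of 2 per workgroup.
n_outer, distribution_tile = 192, 2
num_tiles = n_outer // distribution_tile

# The two inner scf.for loops (0..2 and 0..16, step 1) each perform one
# vector<1x1x1x1xf32> add, i.e. one scalar add per iteration.
scalar_adds_per_tile = 2 * 16

# Hypothetical: with no tiling of the outer-N dim the buffer would cover
# the whole 1x192x1x16 result.
untiled_bytes = prod((1, n_outer, 1, 16)) * F32_BYTES

print(stack_bytes, num_tiles, scalar_adds_per_tile, untiled_bytes)
```

The stack buffer is small here only because the distribution tile size happens to be 2; it grows linearly with that tile size.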

A few issues here:

  1. We allocate a 1x2x1x16xf32 stack buffer. The outer dims are limited by the distribution tile sizes. This is the tradeoff for mixing codegen and ukernels. It will be bounded by smaller sizes if cache-level tiling kicks in. Furthermore, it won't be an issue if the pure codegen approach is applied.
  2. We need to set the generic op lowering_config to [1, 1, 1, 16]. What we need is
    • Better setup for multi-lowering config.
    • Let the generic op decide the vector-level tile sizes if ukernels are involved.
  3. We need to bring the vector lowering in Mmt4dTilingExpert up to date.
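
A rough sketch of item 2, under stated assumptions: the helper names, the three-level layout of `mmt4d_levels`, and the choice of 16 lanes (one 512-bit register of f32, consistent with the zmm register in the disassembly) are all illustrative and are not taken from IREE.

```python
# Hypothetical helpers, not IREE APIs.

def clear_level(tiling_levels, level):
    """Return a copy of the tiling levels with one level zeroed out."""
    return [
        [0] * len(sizes) if i == level else list(sizes)
        for i, sizes in enumerate(tiling_levels)
    ]


def elementwise_vector_tiles(shape, vector_lanes):
    """Tile every dim by 1 except the innermost, which takes a full vector."""
    tiles = [1] * len(shape)
    if shape:
        tiles[-1] = min(shape[-1], vector_lanes)
    return tiles


# Assumed layout: [distribution, vector parallel, vector reduction].
mmt4d_levels = [[1, 2, 0, 0, 0, 0], [1, 1, 0, 1, 16, 0], [0, 0, 1, 0, 0, 1]]

# When the mmt4d op is going to become a ukernel, drop its vector tile sizes.
ukernel_levels = clear_level(mmt4d_levels, 1)

# The generic op then picks its own vector tile sizes from its tile shape.
generic_tiles = elementwise_vector_tiles((1, 2, 1, 16), vector_lanes=16)

print(ukernel_levels[1])  # [0, 0, 0, 0, 0, 0]
print(generic_tiles)      # [1, 1, 1, 16]
```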

@pzread and I are working on a draft for (2). And I will work on (3), which should refresh LLVMCPUMmt4dVectorLowering pass and leverage some features to LLVMCPUVectorLowering pass.

MaheshRavishankar commented 10 months ago

Ok, this looks good... Thanks for pushing on this.

Currently you are looking at cases where the pack becomes a reshape. We can also try cases where the pack is not changed to a reshape. All you would need is a pass that propagates the unpack past the elementwise operation. We don't need to do massive graph propagation to start with; we can just propagate it past one elementwise operation. Then the elementwise operation can fuse with the mmt4d (both being in packed layout). But everything you listed above is a precursor to having that working.
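
The reason the unpack can be moved past an elementwise op is that the two commute. The toy model below uses plain nested lists and hypothetical `pack`/`unpack` helpers (identity outer permutation, 2-D input only); it is a sketch of the idea, not the semantics of tensor.pack in full.

```python
def pack(matrix, m0, n0, pad=0.0):
    """Pack a 2-D list into [ceil(M/m0)][ceil(N/n0)][m0][n0] tiles."""
    rows, cols = len(matrix), len(matrix[0])
    outer_m, outer_n = -(-rows // m0), -(-cols // n0)
    packed = [
        [[[pad] * n0 for _ in range(m0)] for _ in range(outer_n)]
        for _ in range(outer_m)
    ]
    for i in range(rows):
        for j in range(cols):
            packed[i // m0][j // n0][i % m0][j % n0] = matrix[i][j]
    return packed


def unpack(packed, rows, cols):
    """Inverse of pack; padding elements are dropped."""
    m0, n0 = len(packed[0][0]), len(packed[0][0][0])
    return [
        [packed[i // m0][j // n0][i % m0][j % n0] for j in range(cols)]
        for i in range(rows)
    ]


def relu_2d(matrix):
    return [[max(v, 0.0) for v in row] for row in matrix]


def relu_packed(packed):
    return [[[[max(v, 0.0) for v in r] for r in tile] for tile in row] for row in packed]


source = [[float((i * 5 + j) % 7 - 3) for j in range(5)] for i in range(3)]
packed = pack(source, 2, 4)

unpack_then_relu = relu_2d(unpack(packed, 3, 5))      # original order
relu_then_unpack = unpack(relu_packed(packed), 3, 5)  # unpack propagated down

print(unpack_then_relu == relu_then_unpack)  # True
```

In the second ordering the elementwise op works directly on the packed layout, which is the layout the mmt4d result is already in.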

hanhanW commented 4 months ago

Prioritize the issue and flesh out more details. The issue depends on https://github.com/iree-org/iree/issues/17718. We will need to support the codegen in the new data-tiling pipeline, because it enables fusion. Today, we have TilingInterface APIs to tile the root op and fuse its consumers and producers. We no longer need lowering_config propagation for mmt4d fusion. Instead, we can create a TileAndFuse v2 pass. In the new pass, we don't start tiling from the last operation. We can directly tile the root op and fuse both producers and consumers into the scf loops.

We don't need to wait for the parent issue. We can start some work using the above example (it needs some hacks). I can provide some codegen examples when someone picks this up.
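
As a conceptual sketch of "tile the root op and fuse both producers and consumers into the loops", the toy below tiles a transposed-B matmul and computes the producer and consumer per tile. Everything in it is hypothetical (the function names, the scaling producer standing in for a broadcast or dequant op, and the tile sizes); it does not model the TilingInterface API.

```python
def matmul_transpose_b(lhs, rhs):
    """lhs is MxK, rhs is NxK, result is MxN."""
    return [[sum(a * b for a, b in zip(row, col)) for col in rhs] for row in lhs]


def unfused(lhs, rhs_raw, scale):
    rhs = [[v * scale for v in row] for row in rhs_raw]  # producer
    acc = matmul_transpose_b(lhs, rhs)                   # root op
    return [[max(v, 0.0) for v in row] for row in acc]   # consumer (ReLU)


def tiled_and_fused(lhs, rhs_raw, scale, tile_m, tile_n):
    m, n = len(lhs), len(rhs_raw)
    out = [[0.0] * n for _ in range(m)]
    # The loops come from tiling the root op's parallel dims.
    for i0 in range(0, m, tile_m):
        for j0 in range(0, n, tile_n):
            lhs_tile = lhs[i0:i0 + tile_m]
            # Producer fused into the loop: only the needed slice is computed.
            rhs_tile = [[v * scale for v in row] for row in rhs_raw[j0:j0 + tile_n]]
            acc = matmul_transpose_b(lhs_tile, rhs_tile)
            # Consumer fused into the loop: applied to the tile before storing.
            for di, row in enumerate(acc):
                for dj, v in enumerate(row):
                    out[i0 + di][j0 + dj] = max(v, 0.0)
    return out


lhs = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
rhs_raw = [[2, -4, 6], [-2, 0, 2], [4, 4, -8], [0, 2, 0], [6, -6, 2]]

reference = unfused(lhs, rhs_raw, 0.5)
fused = tiled_and_fused(lhs, rhs_raw, 0.5, tile_m=2, tile_n=2)
print(fused == reference)  # True
```

No tile sizes are propagated to the producer or consumer here; their tiles are derived from the slices the root op's loops ask for.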

hanhanW commented 4 months ago

@pashu123 here is an example that you can start with: https://gist.github.com/hanhanW/075f4881664ce095d4af49a29842a6ba#file-z-mlir

You can follow https://github.com/iree-org/iree/commit/3d23684f2510264e2793aeb49e53269fd168e4f3 to create a new strategy and add pipeline tests like this. The example presets the lowering_config and strategy and runs the LLVMGPUTileAndFuse lowering pipeline. You should be able to do the same thing in your prototype.

Also, please help create other test cases like broadcast + batch_mmt4d -> elementwise and other dequant ops. That will be the input of the new pipeline.

===

side note: to enable mmt4d fusion at flow level, remove the check: https://github.com/iree-org/iree/blob/e41e71c522e5bfdeb0220a2c24d34c4a70dee44f/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp#L613-L617

pashu123 commented 4 months ago

Thanks, @hanhanW, for the description. I will start working on this.

pashu123 commented 4 months ago

I will try to give a picture of what the new data-tiling infra would look like, after a discussion with @hanhanW. Feedback is welcome. I will start with the IR in the dispatch (I am omitting flow-level details and keeping just the meat of the computation). The IR consists of broadcast + matmul + ReLU (I have removed the bias add for clarity).

Phase 0: The IR below is present inside a dispatch.

#map = affine_map<(d0, d1, d2) -> (d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
module {
  util.func public @broadcast_matmul_relu_dispatch_0(%arg0: tensor<?x?x3200xf32>, %arg1: tensor<8640x3200xf16>) -> tensor<?x?x8640xf32> {
    %cst = arith.constant 0.000000e+00 : f32
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %dim = tensor.dim %arg0, %c0 : tensor<?x?x3200xf32>
    %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?x3200xf32>
    %0 = tensor.empty(%dim) : tensor<?x8640x3200xf16>
    %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg1 : tensor<8640x3200xf16>) outs(%0 : tensor<?x8640x3200xf16>) {
    ^bb0(%in: f16, %out: f16):
      linalg.yield %in : f16
    } -> tensor<?x8640x3200xf16>
    %2 = tensor.empty(%dim, %dim_0) : tensor<?x?x8640xf32>
    %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<?x?x8640xf32>) -> tensor<?x?x8640xf32>
    %4 = linalg.batch_matmul_transpose_b ins(%arg0, %1 : tensor<?x?x3200xf32>, tensor<?x8640x3200xf16>) outs(%3 : tensor<?x?x8640xf32>) -> tensor<?x?x8640xf32>
    %5 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%4 : tensor<?x?x8640xf32>) outs(%2 : tensor<?x?x8640xf32>) {
    ^bb0(%in: f32, %out: f32):
      %6 = arith.maximumf %in, %cst : f32
      linalg.yield %6 : f32
    } -> tensor<?x?x8640xf32>
    util.return %5 : tensor<?x?x8640xf32>
  }
}

Next, the deferred materialization passes kick in (previously, materialization happened much earlier; now it occurs after dispatch formation), and the IR looks like the one below. I will not go over the details involved in deferring the materialization until after dispatch formation; for those, see: https://github.com/iree-org/iree/issues/17718

Phase 2: (Deferred Materialization)

util.func public @broadcast_matmul_relu_dispatch_1(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub, iree.reflection = {iree.abi.declaration = "sync func @matmul_broad(%input0: tensor<?x?x3200xf32>, %input1: tensor<8640x3200xf16>) -> (%output0: tensor<?x?x8640xf32>)"}} {
  %cst = arith.constant 0.000000e+00 : f16
  %c1 = arith.constant 1 : index
  %c0 = arith.constant 0 : index
  %cst_0 = arith.constant 0.000000e+00 : f32
  %0 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[0] : index
  %1 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[1] : index
  %2 = hal.tensor.import %arg0 "input0" : !hal.buffer_view -> tensor<?x?x3200xf32>{%0, %1}
  %3 = hal.tensor.import %arg1 "input1" : !hal.buffer_view -> tensor<8640x3200xf16>
  %4 = tensor.empty(%0) : tensor<?x8640x3200xf16>
  %5 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%3 : tensor<8640x3200xf16>) outs(%4 : tensor<?x8640x3200xf16>) {
  ^bb0(%in: f16, %out: f16):
    linalg.yield %in : f16
  } -> tensor<?x8640x3200xf16>
  %6 = tensor.empty(%0, %1) : tensor<?x?x8640xf32>
  %dim = tensor.dim %2, %c0 : tensor<?x?x3200xf32>
  %dim_1 = tensor.dim %2, %c1 : tensor<?x?x3200xf32>
  %7 = affine.apply affine_map<()[s0] -> (s0 ceildiv 16)>()[%dim_1]
  %8 = tensor.empty(%dim, %7) : tensor<?x?x3200x16x1xf32>
  %pack = tensor.pack %2 padding_value(%cst_0 : f32) outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 1] into %8 : tensor<?x?x3200xf32> -> tensor<?x?x3200x16x1xf32>
  %9 = tensor.empty(%0) : tensor<?x540x3200x16x1xf16>
  %pack_2 = tensor.pack %5 padding_value(%cst : f16) outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 1] into %9 : tensor<?x8640x3200xf16> -> tensor<?x540x3200x16x1xf16>
  %10 = affine.apply affine_map<()[s0, s1, s2] -> (-s1 + s2 + (s1 ceildiv s0) * s0)>()[%c1, %0, %0]
  %11 = affine.apply affine_map<()[s0] -> (s0 ceildiv 16)>()[%1]
  %12 = tensor.empty(%10, %11) : tensor<?x?x540x16x16xf32>
  %13 = linalg.fill ins(%cst_0 : f32) outs(%12 : tensor<?x?x540x16x16xf32>) -> tensor<?x?x540x16x16xf32>
  %14 = linalg.batch_mmt4d ins(%pack, %pack_2 : tensor<?x?x3200x16x1xf32>, tensor<?x540x3200x16x1xf16>) outs(%13 : tensor<?x?x540x16x16xf32>) -> tensor<?x?x540x16x16xf32>
  %15 = tensor.empty(%0, %1) : tensor<?x?x8640xf32>
  %unpack = tensor.unpack %14 outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 16] into %15 : tensor<?x?x540x16x16xf32> -> tensor<?x?x8640xf32>
  %16 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%unpack : tensor<?x?x8640xf32>) outs(%6 : tensor<?x?x8640xf32>) {
  ^bb0(%in: f32, %out: f32):
    %18 = arith.maximumf %in, %cst_0 : f32
    linalg.yield %18 : f32
  } -> tensor<?x?x8640xf32>
  %17 = hal.tensor.export %16 "output0" : tensor<?x?x8640xf32>{%0, %1} -> !hal.buffer_view
  util.return %17 : !hal.buffer_view
}
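
The packed shapes in the Phase 2 IR follow from simple arithmetic. The sketch below assumes an identity outer_dims_perm (as in the IR above); `packed_shape` is a hypothetical helper, and the concrete batch and M sizes stand in for the dynamic `?` dims.

```python
def ceildiv(a, b):
    return -(-a // b)


def packed_shape(shape, inner_dims_pos, inner_tiles):
    """Shape produced by a pack with identity outer_dims_perm."""
    outer = list(shape)
    for pos, tile in zip(inner_dims_pos, inner_tiles):
        outer[pos] = ceildiv(outer[pos], tile)
    return outer + list(inner_tiles)


batch, m = 2, 100  # hypothetical values for the dynamic dims

# LHS: tensor<?x?x3200xf32> -> tensor<?x?x3200x16x1xf32>
lhs_packed = packed_shape((batch, m, 3200), [1, 2], [16, 1])
# RHS: tensor<?x8640x3200xf16> -> tensor<?x540x3200x16x1xf16>
rhs_packed = packed_shape((batch, 8640, 3200), [1, 2], [16, 1])
# Result: tensor<?x?x8640xf32> <-> tensor<?x?x540x16x16xf32>
out_packed = packed_shape((batch, m, 8640), [1, 2], [16, 16])

print(lhs_packed)  # [2, 7, 3200, 16, 1]
print(rhs_packed)  # [2, 540, 3200, 16, 1]
print(out_packed)  # [2, 7, 540, 16, 16]
```

The 540 in the static types is 8640 / 16, and the `ceildiv 16` affine maps in the IR compute the dynamic outer M dim the same way.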

Phase 3: (Data Layout Propagation: To push packs and unpacks to the boundaries of the dispatch)

util.func public @broadcast_matmul_relu_dispatch_1(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub, iree.reflection = {iree.abi.declaration = "sync func @matmul_broad(%input0: tensor<?x?x3200xf32>, %input1: tensor<8640x3200xf16>) -> (%output0: tensor<?x?x8640xf32>)"}} {
  %c1 = arith.constant 1 : index
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f32
  %0 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[0] : index
  %1 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[1] : index
  %2 = hal.tensor.import %arg0 "input0" : !hal.buffer_view -> tensor<?x?x3200xf32>{%0, %1}
  %3 = hal.tensor.import %arg1 "input1" : !hal.buffer_view -> tensor<8640x3200xf16>
  %4 = tensor.empty(%0) : tensor<?x540x3200x16x1xf16>
  %5 = tensor.empty() : tensor<540x3200x16x1xf16>
  %pack = tensor.pack %3 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 1] into %5 : tensor<8640x3200xf16> -> tensor<540x3200x16x1xf16>
  %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d1, d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%pack : tensor<540x3200x16x1xf16>) outs(%4 : tensor<?x540x3200x16x1xf16>) {
  ^bb0(%in: f16, %out: f16):
    linalg.yield %in : f16
  } -> tensor<?x540x3200x16x1xf16>
  %7 = tensor.empty(%0, %1) : tensor<?x?x8640xf32>
  %8 = affine.apply affine_map<()[s0] -> (s0 ceildiv 16)>()[%1]
  %9 = tensor.empty(%0, %8) : tensor<?x?x3200x16x1xf32>
  %pack_0 = tensor.pack %2 padding_value(%cst : f32) outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 1] into %9 : tensor<?x?x3200xf32> -> tensor<?x?x3200x16x1xf32>
  %10 = tensor.empty(%0, %8) : tensor<?x?x540x16x16xf32>
  %11 = linalg.fill ins(%cst : f32) outs(%10 : tensor<?x?x540x16x16xf32>) -> tensor<?x?x540x16x16xf32>
  %12 = linalg.batch_mmt4d ins(%pack_0, %6 : tensor<?x?x3200x16x1xf32>, tensor<?x540x3200x16x1xf16>) outs(%11 : tensor<?x?x540x16x16xf32>) -> tensor<?x?x540x16x16xf32>
  %unpack = tensor.unpack %12 outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 16] into %7 : tensor<?x?x540x16x16xf32> -> tensor<?x?x8640xf32>
  %dim = tensor.dim %7, %c0 : tensor<?x?x8640xf32>
  %dim_1 = tensor.dim %7, %c1 : tensor<?x?x8640xf32>
  %13 = affine.apply affine_map<()[s0] -> (s0 ceildiv 16)>()[%dim_1]
  %14 = tensor.empty(%dim, %13) : tensor<?x?x540x16x16xf32>
  %dim_2 = tensor.dim %unpack, %c0 : tensor<?x?x8640xf32>
  %dim_3 = tensor.dim %unpack, %c1 : tensor<?x?x8640xf32>
  %15 = affine.apply affine_map<()[s0] -> (s0 ceildiv 16)>()[%dim_3]
  %16 = tensor.empty(%dim_2, %15) : tensor<?x?x540x16x16xf32>
  %pack_4 = tensor.pack %unpack outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 16] into %16 : tensor<?x?x8640xf32> -> tensor<?x?x540x16x16xf32>
  %17 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%pack_4 : tensor<?x?x540x16x16xf32>) outs(%14 : tensor<?x?x540x16x16xf32>) {
  ^bb0(%in: f32, %out: f32):
    %21 = arith.maximumf %in, %cst : f32
    linalg.yield %21 : f32
  } -> tensor<?x?x540x16x16xf32>
  %dim_5 = tensor.dim %17, %c0 : tensor<?x?x540x16x16xf32>
  %dim_6 = tensor.dim %17, %c1 : tensor<?x?x540x16x16xf32>
  %18 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%dim_6]
  %19 = tensor.empty(%dim_5, %18) : tensor<?x?x8640xf32>
  %unpack_7 = tensor.unpack %17 outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 16] into %19 : tensor<?x?x540x16x16xf32> -> tensor<?x?x8640xf32>
  %20 = hal.tensor.export %unpack_7 "output0" : tensor<?x?x8640xf32>{%0, %1} -> !hal.buffer_view
  util.return %20 : !hal.buffer_view
}

Phase 4: (Tile and Fuse Consumer & Producer. The IR is omitted; only the details are described.) I have been omitting the lowering_config attribute attached to the root op (the matmul in the above case). Suppose the lowering config attached to the batch_mmt4d op is

%46 = linalg.batch_mmt4d {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 1, 1, 0, 0, 0, 0], [1, 1, 1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 0, 16, 16, 0], [0, 0, 0, 1, 0, 0, 1], [0, 0, 0, 0, 0, 0, 0]]>} ins(%41, %42 : tensor<?x?x3200x16x1xf32>, tensor<?x540x3200x16x1xf16>) outs(%45 : tensor<?x?x540x16x16xf32>) -> tensor<?x?x540x16x16xf32>

Based on the lowering config attribute, we tile the root op (mmt4d in this case) and do consumer (relu) + producer (broadcast) fusion. This is possible thanks to https://github.com/llvm/llvm-project/commit/2b2ce50fe843b5b550806a0ab15b06cd5c405d48
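
The lowering config above can be read level by level. The sketch below is only a reading aid: it refers to levels by index because the thread does not name them, and it assumes the usual batch_mmt4d loop order (B, M, N, K, M0, N0, K0) with K and K0 as the reduction dims.

```python
# Loop order assumed for linalg.batch_mmt4d: B, M, N, K, M0, N0, K0.
ITERATORS = ["parallel", "parallel", "parallel", "reduction",
             "parallel", "parallel", "reduction"]

TILE_SIZES = [
    [1, 1, 1, 0, 0, 0, 0],
    [1, 1, 1, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0],
    [1, 1, 1, 0, 16, 16, 0],
    [0, 0, 0, 1, 0, 0, 1],
    [0, 0, 0, 0, 0, 0, 0],
]


def tiled_kinds(level):
    """Iterator kinds of the dims that a tiling level actually tiles."""
    return sorted({ITERATORS[d] for d, size in enumerate(level) if size != 0})


kinds_per_level = [tiled_kinds(level) for level in TILE_SIZES]

# The result of batch_mmt4d is indexed by (B, M, N, M0, N0), so the tile that
# a fused elementwise consumer sees at level 3 is the projection onto those dims.
RESULT_DIMS = (0, 1, 2, 4, 5)
consumer_tile = [TILE_SIZES[3][d] for d in RESULT_DIMS]

print(kinds_per_level)
print(consumer_tile)  # [1, 1, 1, 16, 16]
```

Each non-empty level tiles either parallel dims or reduction dims, never both, and the fused ReLU ends up working on one 16x16 inner tile at a time.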

Phase 5: (Remove unnecessary buffer allocation) It has been seen in the past that fusing a broadcast with its consumer (batch_mmt4d in the above case) might allocate unnecessary buffers, which can be removed. More details along this line will be added.

hanhanW commented 4 months ago

Phase 3: (Data Layout Propagation: To push packs and unpacks to the boundaries of the dispatch)

The result of phase 3 is not what we're looking for as codegen input, because some pack ops will be formed into other dispatches. E.g., %pack and %pack_0 will go with their producers. So the input of codegen would be something like:

  %rhs = linalg.generic ... // broadcast
  %13 = linalg.fill ins(%cst_0 : f32) outs(%12 : tensor<?x?x540x16x16xf32>) -> tensor<?x?x540x16x16xf32>
  %14 = linalg.batch_mmt4d ins(%lhs, %rhs : tensor<?x?x3200x16x1xf32>, tensor<?x540x3200x16x1xf16>) outs(%13 : tensor<?x?x540x16x16xf32>) -> tensor<?x?x540x16x16xf32>
  %15 = tensor.empty(%0, %1) : tensor<?x?x8640xf32>
  %unpack = tensor.unpack %14 outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 16] into %15 : tensor<?x?x540x16x16xf32> -> tensor<?x?x8640xf32>
  %16 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%unpack : tensor<?x?x8640xf32>) outs(%6 : tensor<?x?x8640xf32>) {
  ^bb0(%in: f32, %out: f32):
    %18 = arith.maximumf %in, %cst_0 : f32
    linalg.yield %18 : f32
  } -> tensor<?x?x8640xf32>

pashu123 commented 4 months ago

I see; after materialization within dispatches, you will be propagating the packs and unpacks across dispatches.

hanhanW commented 4 months ago

after materialization within dispatches, you will be propagating the packs and unpacks across dispatches.

The materialization happens at codegen level. And the propagation also happens at codegen level. Below is where the materialization happens. And we will add the propagation to LLVMCPU/Passes.cpp.

https://github.com/iree-org/iree/blob/695e1932dd6cf91f2de5fc1415f10fe85fd269f0/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp#L739-L755

After thinking about it for a while, I don't think we need to use the most complicated example to start the work. We can build the pipeline incrementally. How about starting with the https://gist.github.com/hanhanW/075f4881664ce095d4af49a29842a6ba#file-z-mlir example? In this example, we need to implement the Phase 4 details. It is a good start. In the meantime, I'm working on the fusion flow. Then I can provide other concrete examples later.

pashu123 commented 4 months ago

I see, sounds good.

Max191 commented 3 months ago

@pashu123 Posting some test cases here after our discussion.

  1. Complex case: Batch matmul with full dequantization ops on both the LHS and RHS. Not super realistic in terms of what we would find in a model, but it is close enough to give an idea of what needs to be supported. https://gist.github.com/Max191/eab0aca68f7c3354967310d5df6e819a

That's all I have for right now. I rebuilt IREE and I forgot to generate more cases, so I will edit this comment with more cases once I have them.