hanhanW opened this issue 10 months ago
The performance issue comes from bad configurations on the generic op; it generates scalar code. One potential solution is to set the vector-level tile sizes to zeros when ukernels are enabled, and let setLoweringConfigForComputeOps set the vector tile sizes for the generic ops.
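For illustration, here is a minimal sketch of what such a config could look like on the mmt4d op from the example below. The tile-size values are hypothetical; the only point is the zeroed vector-level entries:

// Hypothetical sketch: with the ukernel enabled, keep only distribution-level
// tile sizes on the mmt4d op and leave the vector-level entries as zeros, so
// setLoweringConfigForComputeOps picks the vector sizes for the fused generic.
%mm = linalg.mmt4d {lowering_config = #iree_codegen.lowering_config<
        tile_sizes = [[1, 2, 0, 0, 0, 0],   // distribution
                      [0, 0, 0, 0, 0, 0],   // vector parallel: left to the generic op
                      [0, 0, 0, 0, 0, 0]]>  // vector reduction
      } ins(%lhs, %rhs : tensor<1x768x1x1xf32>, tensor<192x768x16x1xf32>)
        outs(%acc : tensor<1x192x1x16xf32>) -> tensor<1x192x1x16xf32>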
Thanks Hanhan, I am digesting this slowly.
- The codegen assumes that only leading parallel dims are shared. This is not the case in the example because the result indexing_map (in the mmt4d op) is (d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4), while the indexing_map (in the generic op) is the identity (i.e., (d0, d1, d2, d3) -> (d0, d1, d2, d3)). This leads to a bug in multi lowering_config: the mapping is not taken into account in the method.
Ok, this does need better matching of producer iteration space to consumer iteration space. This is a problem with multi lowering config... Something to just note.
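To spell out the mismatch with the maps from the example (dim names follow the usual mmt4d convention; this is just an annotated restatement, not new IR):

// Producer (mmt4d) loops are (d0, d1, d2, d3, d4, d5) = (M, N, K, m0, n0, k0),
// and its result is indexed by the projection below. The consumer generic
// iterates directly over that result, so its identity map lines up with
// (M, N, m0, n0): producer d0 -> consumer d0, d1 -> d1, d3 -> d2, d4 -> d3.
// Multi lowering_config needs to apply this projection instead of assuming
// that only leading parallel dims are shared.
#mmt4d_result_map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
#consumer_map     = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>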
- LLVMCPUTileAndFuse should ignore the iree_codegen.ukernel.generic op. The iteration domain information is gone after converting a linalg op to a ukernel.generic op; thus, we are not able to tile and fuse the ukernel op.
- If we want to enable ukernels in fusion cases, we need to revisit when to convert the mmt4d op to a ukernel op. It can happen either after distribution or after the first level of TileAndFuse. The former could introduce a big stack buffer because the parallel dimensions are only tiled for distribution. The latter needs more investigation.
I think ukernel lowering should happen after all tiling is done. It is basically on par with the vectorization stage.
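A rough sketch of that ordering, using the shapes from the example below (the loop structure and names are illustrative, not actual pass output): tile and fuse while the op is still linalg.mmt4d, and only afterwards rewrite the tiled op to iree_codegen.ukernel.generic.

// Illustrative only: state after distribution + tile-and-fuse, before ukernel lowering.
func.func @mmt4d_add_sketch(%lhs: tensor<1x768x1x1xf32>, %rhs: tensor<192x768x16x1xf32>,
                            %bias: tensor<1x192x1x16xf32>, %acc: tensor<1x192x1x16xf32>) -> tensor<1x192x1x16xf32> {
  %c0 = arith.constant 0 : index
  %c2 = arith.constant 2 : index
  %c192 = arith.constant 192 : index
  %res = scf.for %n = %c0 to %c192 step %c2 iter_args(%out = %acc) -> (tensor<1x192x1x16xf32>) {
    %rhs_tile = tensor.extract_slice %rhs[%n, 0, 0, 0] [2, 768, 16, 1] [1, 1, 1, 1] : tensor<192x768x16x1xf32> to tensor<2x768x16x1xf32>
    %acc_tile = tensor.extract_slice %out[0, %n, 0, 0] [1, 2, 1, 16] [1, 1, 1, 1] : tensor<1x192x1x16xf32> to tensor<1x2x1x16xf32>
    // Still a linalg op here, so TilingInterface-based tile-and-fuse works on it.
    %mm = linalg.mmt4d ins(%lhs, %rhs_tile : tensor<1x768x1x1xf32>, tensor<2x768x16x1xf32>)
                       outs(%acc_tile : tensor<1x2x1x16xf32>) -> tensor<1x2x1x16xf32>
    %bias_tile = tensor.extract_slice %bias[0, %n, 0, 0] [1, 2, 1, 16] [1, 1, 1, 1] : tensor<1x192x1x16xf32> to tensor<1x2x1x16xf32>
    // Fused elementwise consumer, computed on the same tile.
    %add = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%mm, %bias_tile : tensor<1x2x1x16xf32>, tensor<1x2x1x16xf32>) outs(%acc_tile : tensor<1x2x1x16xf32>) {
    ^bb0(%a: f32, %b: f32, %o: f32):
      %s = arith.addf %a, %b : f32
      linalg.yield %s : f32
    } -> tensor<1x2x1x16xf32>
    %r = tensor.insert_slice %add into %out[0, %n, 0, 0] [1, 2, 1, 16] [1, 1, 1, 1] : tensor<1x2x1x16xf32> into tensor<1x192x1x16xf32>
    scf.yield %r : tensor<1x192x1x16xf32>
  }
  // Only at this point (the per-tile shapes are bounded) would linalg.mmt4d be
  // rewritten to iree_codegen.ukernel.generic, i.e. on par with vectorization.
  return %res : tensor<1x192x1x16xf32>
}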
With recent fixes, I'm able to codegen mmt4d fusion with and without ukernels. However, the lowering_config is not set correctly on the generic ops, so the consumer becomes scalars. Here is the relevant disassembly:
1606: 41 0f 18 8a 14 02 00 prefetcht0 BYTE PTR [r10+0x214]
160d: 00
160e: 49 83 c1 60 add r9,0x60
1612: 49 83 c2 18 add r10,0x18
1616: 49 81 f9 00 60 00 00 cmp r9,0x6000
161d: 0f 85 2d ff ff ff jne 1550 <iree_hal_executable_library_query-0x240>
1623: 62 d1 7c 48 11 40 01 vmovups ZMMWORD PTR [r8+0x40],zmm0
162a: 49 83 e8 80 sub r8,0xffffffffffffff80
162e: 48 81 c6 00 80 01 00 add rsi,0x18000
1635: 83 c7 02 add edi,0x2
1638: 81 ff c0 00 00 00 cmp edi,0xc0
163e: 0f 85 fc fd ff ff jne 1440 <iree_hal_executable_library_query-0x350>
1644: ba 3c 00 00 00 mov edx,0x3c
1649: 0f 1f 80 00 00 00 00 nop DWORD PTR [rax+0x0]
1650: c5 fa 10 44 14 c4 vmovss xmm0,DWORD PTR [rsp+rdx*1-0x3c]
1656: c5 fa 58 44 10 c4 vaddss xmm0,xmm0,DWORD PTR [rax+rdx*1-0x3c]
165c: c5 fa 11 44 11 c4 vmovss DWORD PTR [rcx+rdx*1-0x3c],xmm0
1662: c5 fa 10 44 14 c8 vmovss xmm0,DWORD PTR [rsp+rdx*1-0x38]
1668: c5 fa 58 44 10 c8 vaddss xmm0,xmm0,DWORD PTR [rax+rdx*1-0x38]
166e: c5 fa 11 44 11 c8 vmovss DWORD PTR [rcx+rdx*1-0x38],xmm0
1674: c5 fa 10 44 14 cc vmovss xmm0,DWORD PTR [rsp+rdx*1-0x34]
167a: c5 fa 58 44 10 cc vaddss xmm0,xmm0,DWORD PTR [rax+rdx*1-0x34]
1680: c5 fa 11 44 11 cc vmovss DWORD PTR [rcx+rdx*1-0x34],xmm0
1686: c5 fa 10 44 14 d0 vmovss xmm0,DWORD PTR [rsp+rdx*1-0x30]
168c: c5 fa 58 44 10 d0 vaddss xmm0,xmm0,DWORD PTR [rax+rdx*1-0x30]
1692: c5 fa 11 44 11 d0 vmovss DWORD PTR [rcx+rdx*1-0x30],xmm0
1698: c5 fa 10 44 14 d4 vmovss xmm0,DWORD PTR [rsp+rdx*1-0x2c]
169e: c5 fa 58 44 10 d4 vaddss xmm0,xmm0,DWORD PTR [rax+rdx*1-0x2c]
16a4: c5 fa 11 44 11 d4 vmovss DWORD PTR [rcx+rdx*1-0x2c],xmm0
16aa: c5 fa 10 44 14 d8 vmovss xmm0,DWORD PTR [rsp+rdx*1-0x28]
16b0: c5 fa 58 44 10 d8 vaddss xmm0,xmm0,DWORD PTR [rax+rdx*1-0x28]
16b6: c5 fa 11 44 11 d8 vmovss DWORD PTR [rcx+rdx*1-0x28],xmm0
16bc: c5 fa 10 44 14 dc vmovss xmm0,DWORD PTR [rsp+rdx*1-0x24]
16c2: c5 fa 58 44 10 dc vaddss xmm0,xmm0,DWORD PTR [rax+rdx*1-0x24]
16c8: c5 fa 11 44 11 dc vmovss DWORD PTR [rcx+rdx*1-0x24],xmm0
16ce: c5 fa 10 44 14 e0 vmovss xmm0,DWORD PTR [rsp+rdx*1-0x20]
16d4: c5 fa 58 44 10 e0 vaddss xmm0,xmm0,DWORD PTR [rax+rdx*1-0x20]
16da: c5 fa 11 44 11 e0 vmovss DWORD PTR [rcx+rdx*1-0x20],xmm0
16e0: c5 fa 10 44 14 e4 vmovss xmm0,DWORD PTR [rsp+rdx*1-0x1c]
16e6: c5 fa 58 44 10 e4 vaddss xmm0,xmm0,DWORD PTR [rax+rdx*1-0x1c]
16ec: c5 fa 11 44 11 e4 vmovss DWORD PTR [rcx+rdx*1-0x1c],xmm0
16f2: c5 fa 10 44 14 e8 vmovss xmm0,DWORD PTR [rsp+rdx*1-0x18]
16f8: c5 fa 58 44 10 e8 vaddss xmm0,xmm0,DWORD PTR [rax+rdx*1-0x18]
16fe: c5 fa 11 44 11 e8 vmovss DWORD PTR [rcx+rdx*1-0x18],xmm0
1704: c5 fa 10 44 14 ec vmovss xmm0,DWORD PTR [rsp+rdx*1-0x14]
170a: c5 fa 58 44 10 ec vaddss xmm0,xmm0,DWORD PTR [rax+rdx*1-0x14]
1710: c5 fa 11 44 11 ec vmovss DWORD PTR [rcx+rdx*1-0x14],xmm0
1716: c5 fa 10 44 14 f0 vmovss xmm0,DWORD PTR [rsp+rdx*1-0x10]
171c: c5 fa 58 44 10 f0 vaddss xmm0,xmm0,DWORD PTR [rax+rdx*1-0x10]
1722: c5 fa 11 44 11 f0 vmovss DWORD PTR [rcx+rdx*1-0x10],xmm0
1728: c5 fa 10 44 14 f4 vmovss xmm0,DWORD PTR [rsp+rdx*1-0xc]
172e: c5 fa 58 44 10 f4 vaddss xmm0,xmm0,DWORD PTR [rax+rdx*1-0xc]
1734: c5 fa 11 44 11 f4 vmovss DWORD PTR [rcx+rdx*1-0xc],xmm0
173a: c5 fa 10 44 14 f8 vmovss xmm0,DWORD PTR [rsp+rdx*1-0x8]
1740: c5 fa 58 44 10 f8 vaddss xmm0,xmm0,DWORD PTR [rax+rdx*1-0x8]
1746: c5 fa 11 44 11 f8 vmovss DWORD PTR [rcx+rdx*1-0x8],xmm0
174c: c5 fa 10 44 14 fc vmovss xmm0,DWORD PTR [rsp+rdx*1-0x4]
1752: c5 fa 58 44 10 fc vaddss xmm0,xmm0,DWORD PTR [rax+rdx*1-0x4]
1758: c5 fa 11 44 11 fc vmovss DWORD PTR [rcx+rdx*1-0x4],xmm0
175e: c5 fa 10 04 14 vmovss xmm0,DWORD PTR [rsp+rdx*1]
1763: c5 fa 58 04 10 vaddss xmm0,xmm0,DWORD PTR [rax+rdx*1]
1768: c5 fa 11 04 11 vmovss DWORD PTR [rcx+rdx*1],xmm0
176d: 48 83 c2 40 add rdx,0x40
Here is the IR after vectorization and bufferization:
func.func @mmt4d_fusion_dispatch_0_mmt4d_1x192x768x1x16x1_f32() {
%c0 = arith.constant 0 : index
%c192 = arith.constant 192 : index
%c1025_i32 = arith.constant 1025 : i32
%c1_i32 = arith.constant 1 : i32
%c16_i32 = arith.constant 16 : i32
%c768 = arith.constant 768 : index
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%cst = arith.constant 0.000000e+00 : f32
%alloca = memref.alloca() {alignment = 64 : i64} : memref<1x2x1x16xf32>
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<1x768x1x1xf32, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %0, 64 : memref<1x768x1x1xf32, #hal.descriptor_type<storage_buffer>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<192x768x16x1xf32, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %1, 64 : memref<192x768x16x1xf32, #hal.descriptor_type<storage_buffer>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<1x192x1x16xf32, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %2, 64 : memref<1x192x1x16xf32, #hal.descriptor_type<storage_buffer>>
%3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) alignment(64) offset(%c0) : memref<1x192x1x16xf32, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %3, 64 : memref<1x192x1x16xf32, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_count_x = hal.interface.workgroup.count[0] : index
%4 = affine.apply affine_map<()[s0] -> (s0 * 2)>()[%workgroup_id_x]
%5 = affine.apply affine_map<()[s0] -> (s0 * 2)>()[%workgroup_count_x]
scf.for %arg0 = %4 to %c192 step %5 {
%subview = memref.subview %3[0, %arg0, 0, 0] [1, 2, 1, 16] [1, 1, 1, 1] : memref<1x192x1x16xf32, #hal.descriptor_type<storage_buffer>> to memref<1x2x1x16xf32, strided<[3072, 16, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_0 = memref.subview %1[%arg0, 0, 0, 0] [2, 768, 16, 1] [1, 1, 1, 1] : memref<192x768x16x1xf32, #hal.descriptor_type<storage_buffer>> to memref<2x768x16x1xf32, strided<[12288, 16, 1, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_codegen.ukernel.generic "iree_uk_mmt4d" ins(%0, %subview_0 : memref<1x768x1x1xf32, #hal.descriptor_type<storage_buffer>>, memref<2x768x16x1xf32, strided<[12288, 16, 1, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) outs(%alloca : memref<1x2x1x16xf32>) (%c1, %c2, %c768, %c1_i32, %c16_i32, %c1_i32, %c1025_i32 : index, index, index, i32, i32, i32, i32) fn_def_attrs {hal.import.bitcode = true, hal.import.cconv = 1 : i32, hal.import.fields = ["processor_data"]} strided_outer_dims(1)
%subview_1 = memref.subview %2[0, %arg0, 0, 0] [1, 2, 1, 16] [1, 1, 1, 1] : memref<1x192x1x16xf32, #hal.descriptor_type<storage_buffer>> to memref<1x2x1x16xf32, strided<[3072, 16, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
scf.for %arg1 = %c0 to %c2 step %c1 {
scf.for %arg2 = %c0 to %c16 step %c1 {
%6 = vector.transfer_read %alloca[%c0, %arg1, %c0, %arg2], %cst {in_bounds = [true, true, true, true]} : memref<1x2x1x16xf32>, vector<1x1x1x1xf32>
%7 = vector.transfer_read %subview_1[%c0, %arg1, %c0, %arg2], %cst {in_bounds = [true, true, true, true]} : memref<1x2x1x16xf32, strided<[3072, 16, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x1x1x1xf32>
%8 = arith.addf %6, %7 : vector<1x1x1x1xf32>
vector.transfer_write %8, %subview[%c0, %arg1, %c0, %arg2] {in_bounds = [true, true, true, true]} : vector<1x1x1x1xf32>, memref<1x2x1x16xf32, strided<[3072, 16, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
}
}
}
return
}
A few issues here:
1. The 1x2x1x16xf32 stack buffer. The outer dims are limited by the distribution tile sizes. This is the tradeoff for mixing codegen and ukernels. It will be bounded by smaller sizes if we kick in cache-level tiling. Furthermore, it won't be an issue if the pure codegen approach is applied.
2. The consumer generic op is lowered to scalars (1x1x1x1 vectors in the IR above). What we need is [1, 1, 1, 16].
3. The vector lowering for the fused mmt4d dispatch needs to be refreshed.

@pzread and I are working on a draft for (2). I will work on (3), which should refresh the LLVMCPUMmt4dVectorLowering pass and leverage some features from the LLVMCPUVectorLowering pass.
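For reference, the consumer vectorization we want for (2), reusing the SSA names from the IR above (sketch only), reads whole rows of 16 contiguous f32 instead of 1x1x1x1 scalars:

// Illustrative sketch of the desired vectorized consumer loop.
scf.for %i = %c0 to %c2 step %c1 {
  %a = vector.transfer_read %alloca[%c0, %i, %c0, %c0], %cst {in_bounds = [true, true, true, true]} : memref<1x2x1x16xf32>, vector<1x1x1x16xf32>
  %b = vector.transfer_read %subview_1[%c0, %i, %c0, %c0], %cst {in_bounds = [true, true, true, true]} : memref<1x2x1x16xf32, strided<[3072, 16, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x1x1x16xf32>
  %s = arith.addf %a, %b : vector<1x1x1x16xf32>
  vector.transfer_write %s, %subview[%c0, %i, %c0, %c0] {in_bounds = [true, true, true, true]} : vector<1x1x1x16xf32>, memref<1x2x1x16xf32, strided<[3072, 16, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
}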
Ok, this looks good... Thanks for pushing on this.
Currently you are looking at cases where the pack becomes a reshape. We can also try for cases where the pack is not changed to a reshape. All you would need to do is to have a pass that propagates the unpack past the elementwise operation. We don't need to do massive graph propagation to start with; we can just propagate it past one elementwise operation. Then the elementwise operation can fuse with the mmt4d (both being in packed layout). But everything you listed above is a precursor to having that working.
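To make that concrete, a hedged sketch of propagating one unpack past one elementwise op (the shapes and names are illustrative, not taken from a real dispatch):

// Before: unpack, then the elementwise op in the unpacked layout.
%u = tensor.unpack %mm outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [1, 16] into %dest : tensor<1x12x1x16xf32> -> tensor<1x192xf32>
%relu = linalg.generic ... ins(%u : tensor<1x192xf32>) ...   // elementwise

// After: the elementwise op stays in the packed layout (so it can fuse with
// the mmt4d producer), and the unpack moves below it.
%relu_packed = linalg.generic ... ins(%mm : tensor<1x12x1x16xf32>) ...   // elementwise
%u2 = tensor.unpack %relu_packed outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [1, 16] into %dest : tensor<1x12x1x16xf32> -> tensor<1x192xf32>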
Prioritizing this issue and fleshing out more details. The issue depends on https://github.com/iree-org/iree/issues/17718. We will need to support this codegen in the new data-tiling pipeline, because it enables fusion. Today, we can tile the root op and fuse consumers and producers through the TilingInterface API, so we no longer need lowering_config propagation for mmt4d fusion. Instead, we can create a TileAndFuse v2 pass. In the new pass, we don't start tiling from the last operation; we directly tile the root op and fuse both producers and consumers into the scf loops.
We don't need to wait for the parent issue. We can start some work using the above example (it needs some hacks). I can provide some codegen examples when someone picks this up.
@pashu123 here is an example that you can start with: https://gist.github.com/hanhanW/075f4881664ce095d4af49a29842a6ba#file-z-mlir
You can follow https://github.com/iree-org/iree/commit/3d23684f2510264e2793aeb49e53269fd168e4f3 to create a new strategy and add pipeline tests like this. In the example, it presets the lowering_config and strategy and runs the LLVMGPUTileAndFuse lowering pipeline. You should be able to do the same thing in your prototype.
Also, please help create other test cases like broadcast + batch_mmt4d -> elementwise, as well as cases with dequant ops. Those will be the inputs of the new pipeline.
===
side note: to enable mmt4d fusion at flow level, remove the check: https://github.com/iree-org/iree/blob/e41e71c522e5bfdeb0220a2c24d34c4a70dee44f/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp#L613-L617
Thanks, @hanhanW, for the description. I will start working on this.
I will try to give a picture of what the new data tiling infra would look like after a discussion with @hanhanW. Feedback is welcome. I will start with the IR in the dispatch (I am omitting flow-level details, just the meat of computation). The IR consists of broadcast + matmul + ReLu (I have removed the bias add just for clarity).
The below-mentioned IR is present inside a dispatch (again, I am omitting flow-level details just for clarity).

Phase 0:
#map = affine_map<(d0, d1, d2) -> (d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
module {
util.func public @broadcast_matmul_relu_dispatch_0(%arg0: tensor<?x?x3200xf32>, %arg1: tensor<8640x3200xf16>) -> tensor<?x?x8640xf32> {
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%dim = tensor.dim %arg0, %c0 : tensor<?x?x3200xf32>
%dim_0 = tensor.dim %arg0, %c1 : tensor<?x?x3200xf32>
%0 = tensor.empty(%dim) : tensor<?x8640x3200xf16>
%1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg1 : tensor<8640x3200xf16>) outs(%0 : tensor<?x8640x3200xf16>) {
^bb0(%in: f16, %out: f16):
linalg.yield %in : f16
} -> tensor<?x8640x3200xf16>
%2 = tensor.empty(%dim, %dim_0) : tensor<?x?x8640xf32>
%3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<?x?x8640xf32>) -> tensor<?x?x8640xf32>
%4 = linalg.batch_matmul_transpose_b ins(%arg0, %1 : tensor<?x?x3200xf32>, tensor<?x8640x3200xf16>) outs(%3 : tensor<?x?x8640xf32>) -> tensor<?x?x8640xf32>
%5 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%4 : tensor<?x?x8640xf32>) outs(%2 : tensor<?x?x8640xf32>) {
^bb0(%in: f32, %out: f32):
%6 = arith.maximumf %in, %cst : f32
linalg.yield %6 : f32
} -> tensor<?x?x8640xf32>
util.return %5 : tensor<?x?x8640xf32>
}
}
Next, the deferred materialization passes would kick in (previously, materialization was happening much earlier; now it occurs after dispatch formation), and the IR would look like this. I won't go over the details, but there are details involved in deferring the materialization until after dispatch formation; for more, see https://github.com/iree-org/iree/issues/17718.

Phase 2: (Deferred Materialization)
util.func public @broadcast_matmul_relu_dispatch_1(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub, iree.reflection = {iree.abi.declaration = "sync func @matmul_broad(%input0: tensor<?x?x3200xf32>, %input1: tensor<8640x3200xf16>) -> (%output0: tensor<?x?x8640xf32>)"}} {
%cst = arith.constant 0.000000e+00 : f16
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%cst_0 = arith.constant 0.000000e+00 : f32
%0 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[0] : index
%1 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[1] : index
%2 = hal.tensor.import %arg0 "input0" : !hal.buffer_view -> tensor<?x?x3200xf32>{%0, %1}
%3 = hal.tensor.import %arg1 "input1" : !hal.buffer_view -> tensor<8640x3200xf16>
%4 = tensor.empty(%0) : tensor<?x8640x3200xf16>
%5 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%3 : tensor<8640x3200xf16>) outs(%4 : tensor<?x8640x3200xf16>) {
^bb0(%in: f16, %out: f16):
linalg.yield %in : f16
} -> tensor<?x8640x3200xf16>
%6 = tensor.empty(%0, %1) : tensor<?x?x8640xf32>
%dim = tensor.dim %2, %c0 : tensor<?x?x3200xf32>
%dim_1 = tensor.dim %2, %c1 : tensor<?x?x3200xf32>
%7 = affine.apply affine_map<()[s0] -> (s0 ceildiv 16)>()[%dim_1]
%8 = tensor.empty(%dim, %7) : tensor<?x?x3200x16x1xf32>
%pack = tensor.pack %2 padding_value(%cst_0 : f32) outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 1] into %8 : tensor<?x?x3200xf32> -> tensor<?x?x3200x16x1xf32>
%9 = tensor.empty(%0) : tensor<?x540x3200x16x1xf16>
%pack_2 = tensor.pack %5 padding_value(%cst : f16) outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 1] into %9 : tensor<?x8640x3200xf16> -> tensor<?x540x3200x16x1xf16>
%10 = affine.apply affine_map<()[s0, s1, s2] -> (-s1 + s2 + (s1 ceildiv s0) * s0)>()[%c1, %0, %0]
%11 = affine.apply affine_map<()[s0] -> (s0 ceildiv 16)>()[%1]
%12 = tensor.empty(%10, %11) : tensor<?x?x540x16x16xf32>
%13 = linalg.fill ins(%cst_0 : f32) outs(%12 : tensor<?x?x540x16x16xf32>) -> tensor<?x?x540x16x16xf32>
%14 = linalg.batch_mmt4d ins(%pack, %pack_2 : tensor<?x?x3200x16x1xf32>, tensor<?x540x3200x16x1xf16>) outs(%13 : tensor<?x?x540x16x16xf32>) -> tensor<?x?x540x16x16xf32>
%15 = tensor.empty(%0, %1) : tensor<?x?x8640xf32>
%unpack = tensor.unpack %14 outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 16] into %15 : tensor<?x?x540x16x16xf32> -> tensor<?x?x8640xf32>
%16 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%unpack : tensor<?x?x8640xf32>) outs(%6 : tensor<?x?x8640xf32>) {
^bb0(%in: f32, %out: f32):
%18 = arith.maximumf %in, %cst_0 : f32
linalg.yield %18 : f32
} -> tensor<?x?x8640xf32>
%17 = hal.tensor.export %16 "output0" : tensor<?x?x8640xf32>{%0, %1} -> !hal.buffer_view
util.return %17 : !hal.buffer_view
}
Phase 3: (Data Layout Propagation: To push packs and unpacks to the boundaries of the dispatch)
util.func public @broadcast_matmul_relu_dispatch_1(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub, iree.reflection = {iree.abi.declaration = "sync func @matmul_broad(%input0: tensor<?x?x3200xf32>, %input1: tensor<8640x3200xf16>) -> (%output0: tensor<?x?x8640xf32>)"}} {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[0] : index
%1 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[1] : index
%2 = hal.tensor.import %arg0 "input0" : !hal.buffer_view -> tensor<?x?x3200xf32>{%0, %1}
%3 = hal.tensor.import %arg1 "input1" : !hal.buffer_view -> tensor<8640x3200xf16>
%4 = tensor.empty(%0) : tensor<?x540x3200x16x1xf16>
%5 = tensor.empty() : tensor<540x3200x16x1xf16>
%pack = tensor.pack %3 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 1] into %5 : tensor<8640x3200xf16> -> tensor<540x3200x16x1xf16>
%6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d1, d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%pack : tensor<540x3200x16x1xf16>) outs(%4 : tensor<?x540x3200x16x1xf16>) {
^bb0(%in: f16, %out: f16):
linalg.yield %in : f16
} -> tensor<?x540x3200x16x1xf16>
%7 = tensor.empty(%0, %1) : tensor<?x?x8640xf32>
%8 = affine.apply affine_map<()[s0] -> (s0 ceildiv 16)>()[%1]
%9 = tensor.empty(%0, %8) : tensor<?x?x3200x16x1xf32>
%pack_0 = tensor.pack %2 padding_value(%cst : f32) outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 1] into %9 : tensor<?x?x3200xf32> -> tensor<?x?x3200x16x1xf32>
%10 = tensor.empty(%0, %8) : tensor<?x?x540x16x16xf32>
%11 = linalg.fill ins(%cst : f32) outs(%10 : tensor<?x?x540x16x16xf32>) -> tensor<?x?x540x16x16xf32>
%12 = linalg.batch_mmt4d ins(%pack_0, %6 : tensor<?x?x3200x16x1xf32>, tensor<?x540x3200x16x1xf16>) outs(%11 : tensor<?x?x540x16x16xf32>) -> tensor<?x?x540x16x16xf32>
%unpack = tensor.unpack %12 outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 16] into %7 : tensor<?x?x540x16x16xf32> -> tensor<?x?x8640xf32>
%dim = tensor.dim %7, %c0 : tensor<?x?x8640xf32>
%dim_1 = tensor.dim %7, %c1 : tensor<?x?x8640xf32>
%13 = affine.apply affine_map<()[s0] -> (s0 ceildiv 16)>()[%dim_1]
%14 = tensor.empty(%dim, %13) : tensor<?x?x540x16x16xf32>
%dim_2 = tensor.dim %unpack, %c0 : tensor<?x?x8640xf32>
%dim_3 = tensor.dim %unpack, %c1 : tensor<?x?x8640xf32>
%15 = affine.apply affine_map<()[s0] -> (s0 ceildiv 16)>()[%dim_3]
%16 = tensor.empty(%dim_2, %15) : tensor<?x?x540x16x16xf32>
%pack_4 = tensor.pack %unpack outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 16] into %16 : tensor<?x?x8640xf32> -> tensor<?x?x540x16x16xf32>
%17 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%pack_4 : tensor<?x?x540x16x16xf32>) outs(%14 : tensor<?x?x540x16x16xf32>) {
^bb0(%in: f32, %out: f32):
%21 = arith.maximumf %in, %cst : f32
linalg.yield %21 : f32
} -> tensor<?x?x540x16x16xf32>
%dim_5 = tensor.dim %17, %c0 : tensor<?x?x540x16x16xf32>
%dim_6 = tensor.dim %17, %c1 : tensor<?x?x540x16x16xf32>
%18 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%dim_6]
%19 = tensor.empty(%dim_5, %18) : tensor<?x?x8640xf32>
%unpack_7 = tensor.unpack %17 outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 16] into %19 : tensor<?x?x540x16x16xf32> -> tensor<?x?x8640xf32>
%20 = hal.tensor.export %unpack_7 "output0" : tensor<?x?x8640xf32>{%0, %1} -> !hal.buffer_view
util.return %20 : !hal.buffer_view
}
Phase 4: (Tile and Fuse Consumer & Producer. IR omitted; just the details are mentioned.) I've been omitting the lowering_config attribute attached to the root op (the matmul in the above case). Suppose the lowering config attached to the batch_mmt4d op is:
%46 = linalg.batch_mmt4d {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 1, 1, 0, 0, 0, 0], [1, 1, 1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 0, 16, 16, 0], [0, 0, 0, 1, 0, 0, 1], [0, 0, 0, 0, 0, 0, 0]]>} ins(%41, %42 : tensor<?x?x3200x16x1xf32>, tensor<?x540x3200x16x1xf16>) outs(%45 : tensor<?x?x540x16x16xf32>) -> tensor<?x?x540x16x16xf32>
Based on the lowering_config attribute, we tile the root op (the mmt4d in this case) and do consumer (relu) + producer (broadcast) fusion, thanks to https://github.com/llvm/llvm-project/commit/2b2ce50fe843b5b550806a0ab15b06cd5c405d48.
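Since the Phase 4 IR is omitted above, here is a hedged sketch of what its output could look like (the loop structure and names are illustrative; dynamic-shape bookkeeping and the exact tiling loops are elided):

// Distribution/tiling loops over (batch, M, N) per the lowering_config above.
scf.forall (%b, %m, %n) in (%B, %M, %N) {
  %lhs_tile = tensor.extract_slice %pack_0[%b, %m, 0, 0, 0] [1, 1, 3200, 16, 1] [1, 1, 1, 1, 1] : ...
  // Fused producer: the broadcast generic, computed only for the RHS tile.
  %rhs_tile = linalg.generic ... -> tensor<1x1x3200x16x1xf16>
  %acc_tile = linalg.fill ... -> tensor<1x1x1x16x16xf32>
  %mm = linalg.batch_mmt4d ins(%lhs_tile, %rhs_tile : tensor<1x1x3200x16x1xf32>, tensor<1x1x3200x16x1xf16>) outs(%acc_tile : tensor<1x1x1x16x16xf32>) -> tensor<1x1x1x16x16xf32>
  // Fused consumer: the ReLU generic, on the same tile and in the packed layout.
  %relu = linalg.generic ... ins(%mm : tensor<1x1x1x16x16xf32>) ...   // arith.maximumf
  ...
}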
Phase 5: (Remove unnecessary buffer allocations) It has been seen in the past that fusing the broadcast with its consumer (the batch_mmt4d in the above case) might allocate unnecessary buffers, which can be removed. More details along this line will be added.
Phase 3: (Data Layout Propagation: To push packs and unpacks to the boundaries of the dispatch)
The result of phase 3 is not what we're looking for as codegen input, because some pack ops will be formed into other dispatches. E.g., %pack and %pack_0 will go with their producers. So the input of codegen would be something like:
%rhs = linalg.generic ... // broadcast
%13 = linalg.fill ins(%cst_0 : f32) outs(%12 : tensor<?x?x540x16x16xf32>) -> tensor<?x?x540x16x16xf32>
%14 = linalg.batch_mmt4d ins(%lhs, %rhs : tensor<?x?x3200x16x1xf32>, tensor<?x540x3200x16x1xf16>) outs(%13 : tensor<?x?x540x16x16xf32>) -> tensor<?x?x540x16x16xf32>
%15 = tensor.empty(%0, %1) : tensor<?x?x8640xf32>
%unpack = tensor.unpack %14 outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 16] into %15 : tensor<?x?x540x16x16xf32> -> tensor<?x?x8640xf32>
%16 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%unpack : tensor<?x?x8640xf32>) outs(%6 : tensor<?x?x8640xf32>) {
^bb0(%in: f32, %out: f32):
%18 = arith.maximumf %in, %cst_0 : f32
linalg.yield %18 : f32
} -> tensor<?x?x8640xf32>
I see; after materialization within dispatches, you will be propagating the packs and unpacks across dispatches.
after materialization within dispatches, you will be propagating the packs and unpacks across dispatches.
The materialization happens at the codegen level, and the propagation also happens at the codegen level. Below is where the materialization happens, and we will add the propagation to LLVMCPU/Passes.cpp.
After thinking a while, I don't think we need to start the work with the most complicated example; we can build the pipeline incrementally. How about starting with the https://gist.github.com/hanhanW/075f4881664ce095d4af49a29842a6ba#file-z-mlir example? In this example, we need to implement the Phase 4 details. It is a good start. In the meantime, I'm working on the fusion flow, and I can provide other concrete examples later.
I see, sounds good.
@pashu123 Posting some test cases here after our discussion.
That's all I have for right now. I rebuilt IREE and I forgot to generate more cases, so I will edit this comment with more cases once I have them.
We have opportunities to fuse the mmt4d op with its consumers, but there are performance issues. This happens when we simplify 1D pack/unpack ops to expand_shape/collapse_shape ops (a small sketch follows the list below); generally this is good because they become metadata ops. This issue describes what's happening in mmt4d fusion. Below is a snippet from the GPT2 model.
There are a couple of issues in the current pipeline:
- The codegen assumes that only leading parallel dims are shared. This is not the case in the example because the result indexing_map (in the mmt4d op) is (d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4), while the indexing_map (in the generic op) is the identity (i.e., (d0, d1, d2, d3) -> (d0, d1, d2, d3)). This leads to a bug in multi lowering_config: the mapping is not taken into account in the method.
- LLVMCPUTileAndFuse should ignore the iree_codegen.ukernel.generic op. The iteration domain information is gone after converting a linalg op to a ukernel.generic op; thus, we are not able to tile and fuse the ukernel op.
- If we want to enable ukernels in fusion cases, we need to revisit when to convert the mmt4d op to a ukernel op. It can happen either after distribution or after the first level of TileAndFuse. The former could introduce a big stack buffer because the parallel dimensions are only tiled for distribution. The latter needs more investigation. (The other option is to consider specialization: if there is fusion, we go with codegen; otherwise, we go with the ukernel path. I haven't explored this path, so no further comments.)
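For reference, here is a sketch of the kind of simplification mentioned above (the shapes follow the GPT2 snippet; the exact op spellings are illustrative):

// A "1D" pack whose inner tiles are all 1 only inserts unit dims, so it folds
// into a pure reshape and becomes a metadata op; the matching unpack folds to
// a collapse_shape. Once they are metadata ops, the elementwise consumer can
// land in the same dispatch as the mmt4d op.
%dest = tensor.empty() : tensor<1x768x1x1xf32>
%p = tensor.pack %lhs outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [1, 1] into %dest : tensor<1x768xf32> -> tensor<1x768x1x1xf32>
// ... folds to something like:
// %p = tensor.expand_shape %lhs [[0], [1, 2, 3]] output_shape [1, 768, 1, 1] : tensor<1x768xf32> into tensor<1x768x1x1xf32>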