jiangqucheng / torch-mlir-hwacc

Torch-MLIR to HWACC
Other
0 stars 0 forks source link

MLIR `linalg.matmul` Capture and Rewrite #2

Open jiangqucheng opened 9 months ago

jiangqucheng commented 9 months ago

Test of the pass rewriter workflow

Command Line

/home/qjiang/work/final_grad_prj/torch-mlir/build/bin/mlir-opt  -pass-pipeline='builtin.module(func.func(convert-linalg-to-hwacc))' ./my_torch_jit_app1.mlir_lowering_pipeline.before_linalg_matmul_lowering.mlir 

Log Output

Only ops from the linalg dialect were captured.

image

`linalg.yield` does not satisfy `hasBufferSemantics()`, so it is not matched either.

image

// Terminator for the regions of `linalg` generic ops: it returns its operands
// (any number, of any type) to the immediately enclosing `linalg` op.
def Linalg_YieldOp : Linalg_Op<"yield", [Pure, ReturnLike, Terminator]>,
    Arguments<(ins Variadic<AnyType>:$values)> {
  let summary = "Linalg yield operation";
  let description = [{
    `linalg.yield` is a special terminator operation for blocks inside regions
    in `linalg` generic ops. It returns values to the immediately enclosing
    `linalg` generic op.

    Example:

    ```mlir
    linalg.yield %f0, %f1 : f32, f32
    ```
  }];
  // Builder taking no extra arguments; its C++ body is intentionally empty.
  let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>];
  // Parsing/printing and verification are implemented by hand in C++.
  let hasCustomAssemblyFormat = 1;
  let hasVerifier = 1;
}


## Input MLIR used for the capture test
```mlir
// my_torch_jit_app1.mlir_lowering_pipeline.before_linalg_matmul_lowering.mlir
// Location aliases: source positions in torch's Python modules and in the user
// script, wrapped with aten op names and call-site chains further below.
#loc = loc(unknown)
#loc1 = loc("/scratch/qjiang/work/final_grad_prj/torch-mlir/mlir_venv/lib/python3.11/site-packages/torch/nn/modules/linear.py":116:15)
#loc2 = loc("/scratch/qjiang/work/final_grad_prj/torch-mlir/mlir_venv/lib/python3.11/site-packages/torch/nn/modules/container.py":217:20)
#loc3 = loc("/scratch/qjiang/work/final_grad_prj/playground/MyPyApp/my_torch_jit_app/my_torch_jit_app1.py":98:12)
#loc4 = loc("/scratch/qjiang/work/final_grad_prj/torch-mlir/mlir_venv/lib/python3.11/site-packages/torch/nn/functional.py":796:11)
#loc5 = loc("/scratch/qjiang/work/final_grad_prj/torch-mlir/mlir_venv/lib/python3.11/site-packages/torch/nn/modules/pooling.py":164:15)
#loc6 = loc("/scratch/qjiang/work/final_grad_prj/playground/MyPyApp/my_torch_jit_app/my_torch_jit_app1.py":94:12)
#loc7 = loc("/scratch/qjiang/work/final_grad_prj/torch-mlir/mlir_venv/lib/python3.11/site-packages/torch/nn/modules/conv.py":456:15)
#loc8 = loc("/scratch/qjiang/work/final_grad_prj/torch-mlir/mlir_venv/lib/python3.11/site-packages/torch/nn/modules/conv.py":460:15)
#loc9 = loc("/scratch/qjiang/work/final_grad_prj/torch-mlir/mlir_venv/lib/python3.11/site-packages/torch/nn/functional.py":1473:17)
#loc10 = loc("/scratch/qjiang/work/final_grad_prj/torch-mlir/mlir_venv/lib/python3.11/site-packages/torch/nn/modules/activation.py":101:15)
// Indexing maps used by the linalg.generic ops:
//   #map  - 4-D identity; #map1 - 4-D -> dim 1 (per-channel broadcast);
//   #map2 - 2-D identity; #map3 - 2-D -> dim 1 (per-column broadcast).
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d1)>
#map2 = affine_map<(d0, d1) -> (d0, d1)>
#map3 = affine_map<(d0, d1) -> (d1)>
// aten op names attached to the Python source locations above.
#loc12 = loc("aten::linear"(#loc1))
#loc13 = loc("aten::max_pool2d"(#loc4))
#loc14 = loc("aten::conv2d"(#loc7))
#loc15 = loc("aten::relu"(#loc9))
// Call-site chains linking each aten op back through container.py to the
// user script.
#loc17 = loc(callsite(#loc12 at #loc2))
#loc18 = loc(callsite(#loc13 at #loc5))
#loc19 = loc(callsite(#loc14 at #loc8))
#loc20 = loc(callsite(#loc15 at #loc10))
#loc21 = loc(callsite(#loc17 at #loc3))
#loc22 = loc(callsite(#loc18 at #loc2))
#loc23 = loc(callsite(#loc19 at #loc2))
#loc24 = loc(callsite(#loc20 at #loc2))
#loc25 = loc(callsite(#loc22 at #loc6))
#loc26 = loc(callsite(#loc23 at #loc6))
#loc27 = loc(callsite(#loc24 at #loc6))
#loc28 = loc(callsite(#loc24 at #loc3))
// Torch module "MyNNModel" lowered to linalg ops on memrefs. The two
// linalg.matmul ops in @forward are what convert-linalg-to-hwacc is meant to
// capture.
module attributes {torch.debug_module_name = "MyNNModel"} {
  // External function; @forward hands it the final result as an unranked memref.
  func.func private @refbackend_consume_func_return_mrf32(memref<*xf32>) attributes {llvm.emit_c_interface} loc(#loc)
  // Module-level buffers. Most contents are elided in this dump. The shapes
  // match the conv/linear operands used in @forward, so these are presumably
  // the model's weights and biases.
  memref.global "private" @global_seed : memref<i64> = dense<0> loc(#loc)
  memref.global "private" constant @__constant_16x1x3x3xf32 : memref<16x1x3x3xf32> = dense_resource<__elided__> loc(#loc)
  memref.global "private" constant @__constant_16xf32 : memref<16xf32> = dense_resource<__elided__> loc(#loc)
  memref.global "private" constant @__constant_32x16x3x3xf32 : memref<32x16x3x3xf32> = dense_resource<__elided__> loc(#loc)
  memref.global "private" constant @__constant_32xf32 : memref<32xf32> = dense_resource<__elided__> loc(#loc)
  memref.global "private" constant @__constant_1024xf32 : memref<1024xf32> = dense_resource<__elided__> loc(#loc)
  memref.global "private" constant @__constant_10xf32 : memref<10xf32> = dense<[-0.0159965977, 0.0596813858, 0.0136541566, -0.00140823587, 0.0201260522, -0.00508531276, -0.0442607887, 0.0161830094, -0.023775503, -0.033404272]> loc(#loc)
  memref.global "private" constant @__constant_6272x1024xf32 : memref<6272x1024xf32> = dense_resource<__elided__> loc(#loc21)
  memref.global "private" constant @__constant_1024x10xf32 : memref<1024x10xf32> = dense_resource<__elided__> loc(#loc21)
  // Entry point. It casts the unranked input to 64x1x28x28 and runs:
  // conv -> relu -> conv -> relu -> max-pool -> flatten -> matmul+bias+relu
  // -> matmul+bias. The result goes to the external consumer.
  func.func @forward(%arg0: memref<*xf32> loc(unknown)) attributes {llvm.emit_c_interface} {
    // 0xFF800000 is f32 -inf (max-pool init value); 0.0 is used for padding,
    // accumulator init and ReLU.
    %cst = arith.constant 0xFF800000 : f32 loc(#loc25)
    %cst_0 = arith.constant 0.000000e+00 : f32 loc(#loc26)
    %cast = memref.cast %arg0 : memref<*xf32> to memref<64x1x28x28xf32> loc(#loc)
    %0 = memref.get_global @__constant_1024x10xf32 : memref<1024x10xf32> loc(#loc21)
    %1 = memref.get_global @__constant_6272x1024xf32 : memref<6272x1024xf32> loc(#loc21)
    %2 = memref.get_global @__constant_10xf32 : memref<10xf32> loc(#loc)
    %3 = memref.get_global @__constant_1024xf32 : memref<1024xf32> loc(#loc)
    %4 = memref.get_global @__constant_32xf32 : memref<32xf32> loc(#loc)
    %5 = memref.get_global @__constant_32x16x3x3xf32 : memref<32x16x3x3xf32> loc(#loc)
    %6 = memref.get_global @__constant_16xf32 : memref<16xf32> loc(#loc)
    %7 = memref.get_global @__constant_16x1x3x3xf32 : memref<16x1x3x3xf32> loc(#loc)
    // conv1 input padding: zero-filled 64x1x30x30 buffer; the 28x28 input is
    // then copied into its interior at offset [1, 1].
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<64x1x30x30xf32> loc(#loc26)
    linalg.fill ins(%cst_0 : f32) outs(%alloc : memref<64x1x30x30xf32>) loc(#loc26)
    %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<64x1x30x30xf32> loc(#loc26)
    linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%alloc : memref<64x1x30x30xf32>) outs(%alloc_1 : memref<64x1x30x30xf32>) {
    ^bb0(%in: f32 loc(callsite(#loc23 at #loc6)), %out: f32 loc(callsite(#loc23 at #loc6))):
      linalg.yield %in : f32 loc(#loc26)
    } loc(#loc26)
    memref.dealloc %alloc : memref<64x1x30x30xf32> loc(#loc26)
    %subview = memref.subview %alloc_1[0, 0, 1, 1] [64, 1, 28, 28] [1, 1, 1, 1] : memref<64x1x30x30xf32> to memref<64x1x28x28xf32, strided<[900, 900, 30, 1], offset: 31>> loc(#loc26)
    linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cast : memref<64x1x28x28xf32>) outs(%subview : memref<64x1x28x28xf32, strided<[900, 900, 30, 1], offset: 31>>) {
    ^bb0(%in: f32 loc(unknown), %out: f32 loc(callsite(#loc23 at #loc6))):
      linalg.yield %in : f32 loc(#loc26)
    } loc(#loc26)
    // conv1 output: broadcast the 16-element bias along dim 1 (#map1), then
    // copy it into the buffer the convolution accumulates into.
    %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<64x16x28x28xf32> loc(#loc26)
    linalg.generic {indexing_maps = [#map1, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%6 : memref<16xf32>) outs(%alloc_2 : memref<64x16x28x28xf32>) {
    ^bb0(%in: f32 loc(callsite(#loc23 at #loc6)), %out: f32 loc(callsite(#loc23 at #loc6))):
      linalg.yield %in : f32 loc(#loc26)
    } loc(#loc26)
    %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<64x16x28x28xf32> loc(#loc26)
    linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%alloc_2 : memref<64x16x28x28xf32>) outs(%alloc_3 : memref<64x16x28x28xf32>) {
    ^bb0(%in: f32 loc(callsite(#loc23 at #loc6)), %out: f32 loc(callsite(#loc23 at #loc6))):
      linalg.yield %in : f32 loc(#loc26)
    } loc(#loc26)
    memref.dealloc %alloc_2 : memref<64x16x28x28xf32> loc(#loc26)
    // conv1: 3x3 filter, stride 1, dilation 1, accumulating into %alloc_3.
    linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%alloc_1, %7 : memref<64x1x30x30xf32>, memref<16x1x3x3xf32>) outs(%alloc_3 : memref<64x16x28x28xf32>) loc(#loc26)
    memref.dealloc %alloc_1 : memref<64x1x30x30xf32> loc(#loc26)
    // ReLU: keep x when `x ugt 0` (unordered compare, so NaN is kept), else 0.
    %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<64x16x28x28xf32> loc(#loc27)
    linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%alloc_3 : memref<64x16x28x28xf32>) outs(%alloc_4 : memref<64x16x28x28xf32>) {
    ^bb0(%in: f32 loc(callsite(#loc23 at #loc6)), %out: f32 loc(callsite(#loc24 at #loc6))):
      %8 = arith.cmpf ugt, %in, %cst_0 : f32 loc(#loc27)
      %9 = arith.select %8, %in, %cst_0 : f32 loc(#loc27)
      linalg.yield %9 : f32 loc(#loc27)
    } loc(#loc27)
    memref.dealloc %alloc_3 : memref<64x16x28x28xf32> loc(#loc26)
    // conv2 input padding: zero-filled 64x16x30x30 buffer with the ReLU
    // output copied into its interior at offset [1, 1].
    %alloc_5 = memref.alloc() {alignment = 64 : i64} : memref<64x16x30x30xf32> loc(#loc26)
    linalg.fill ins(%cst_0 : f32) outs(%alloc_5 : memref<64x16x30x30xf32>) loc(#loc26)
    %alloc_6 = memref.alloc() {alignment = 64 : i64} : memref<64x16x30x30xf32> loc(#loc26)
    linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%alloc_5 : memref<64x16x30x30xf32>) outs(%alloc_6 : memref<64x16x30x30xf32>) {
    ^bb0(%in: f32 loc(callsite(#loc23 at #loc6)), %out: f32 loc(callsite(#loc23 at #loc6))):
      linalg.yield %in : f32 loc(#loc26)
    } loc(#loc26)
    memref.dealloc %alloc_5 : memref<64x16x30x30xf32> loc(#loc26)
    %subview_7 = memref.subview %alloc_6[0, 0, 1, 1] [64, 16, 28, 28] [1, 1, 1, 1] : memref<64x16x30x30xf32> to memref<64x16x28x28xf32, strided<[14400, 900, 30, 1], offset: 31>> loc(#loc26)
    linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%alloc_4 : memref<64x16x28x28xf32>) outs(%subview_7 : memref<64x16x28x28xf32, strided<[14400, 900, 30, 1], offset: 31>>) {
    ^bb0(%in: f32 loc(callsite(#loc24 at #loc6)), %out: f32 loc(callsite(#loc23 at #loc6))):
      linalg.yield %in : f32 loc(#loc26)
    } loc(#loc26)
    memref.dealloc %alloc_4 : memref<64x16x28x28xf32> loc(#loc27)
    // conv2 output: broadcast the 32-element bias along dim 1, then copy it
    // into the accumulation buffer.
    %alloc_8 = memref.alloc() {alignment = 64 : i64} : memref<64x32x28x28xf32> loc(#loc26)
    linalg.generic {indexing_maps = [#map1, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%4 : memref<32xf32>) outs(%alloc_8 : memref<64x32x28x28xf32>) {
    ^bb0(%in: f32 loc(callsite(#loc23 at #loc6)), %out: f32 loc(callsite(#loc23 at #loc6))):
      linalg.yield %in : f32 loc(#loc26)
    } loc(#loc26)
    %alloc_9 = memref.alloc() {alignment = 64 : i64} : memref<64x32x28x28xf32> loc(#loc26)
    linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%alloc_8 : memref<64x32x28x28xf32>) outs(%alloc_9 : memref<64x32x28x28xf32>) {
    ^bb0(%in: f32 loc(callsite(#loc23 at #loc6)), %out: f32 loc(callsite(#loc23 at #loc6))):
      linalg.yield %in : f32 loc(#loc26)
    } loc(#loc26)
    memref.dealloc %alloc_8 : memref<64x32x28x28xf32> loc(#loc26)
    // conv2: 3x3 filter, stride 1, dilation 1, accumulating into %alloc_9.
    linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%alloc_6, %5 : memref<64x16x30x30xf32>, memref<32x16x3x3xf32>) outs(%alloc_9 : memref<64x32x28x28xf32>) loc(#loc26)
    memref.dealloc %alloc_6 : memref<64x16x30x30xf32> loc(#loc26)
    // ReLU (same ugt/select pattern as above).
    %alloc_10 = memref.alloc() {alignment = 64 : i64} : memref<64x32x28x28xf32> loc(#loc27)
    linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%alloc_9 : memref<64x32x28x28xf32>) outs(%alloc_10 : memref<64x32x28x28xf32>) {
    ^bb0(%in: f32 loc(callsite(#loc23 at #loc6)), %out: f32 loc(callsite(#loc24 at #loc6))):
      %8 = arith.cmpf ugt, %in, %cst_0 : f32 loc(#loc27)
      %9 = arith.select %8, %in, %cst_0 : f32 loc(#loc27)
      linalg.yield %9 : f32 loc(#loc27)
    } loc(#loc27)
    memref.dealloc %alloc_9 : memref<64x32x28x28xf32> loc(#loc26)
    // 2x2 max-pool, stride 2. The output is pre-filled with -inf. The
    // uninitialized 2x2 %alloc_12 is only used for its shape (the pooling
    // window).
    %alloc_11 = memref.alloc() {alignment = 64 : i64} : memref<64x32x14x14xf32> loc(#loc25)
    linalg.fill ins(%cst : f32) outs(%alloc_11 : memref<64x32x14x14xf32>) loc(#loc25)
    %alloc_12 = memref.alloc() {alignment = 64 : i64} : memref<2x2xf32> loc(#loc25)
    %alloc_13 = memref.alloc() {alignment = 64 : i64} : memref<64x32x14x14xf32> loc(#loc25)
    linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%alloc_11 : memref<64x32x14x14xf32>) outs(%alloc_13 : memref<64x32x14x14xf32>) {
    ^bb0(%in: f32 loc(callsite(#loc22 at #loc6)), %out: f32 loc(callsite(#loc22 at #loc6))):
      linalg.yield %in : f32 loc(#loc25)
    } loc(#loc25)
    memref.dealloc %alloc_11 : memref<64x32x14x14xf32> loc(#loc25)
    linalg.pooling_nchw_max {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>} ins(%alloc_10, %alloc_12 : memref<64x32x28x28xf32>, memref<2x2xf32>) outs(%alloc_13 : memref<64x32x14x14xf32>) loc(#loc25)
    memref.dealloc %alloc_12 : memref<2x2xf32> loc(#loc25)
    memref.dealloc %alloc_10 : memref<64x32x28x28xf32> loc(#loc27)
    // Flatten: 64x32x14x14 -> 64x6272 (view over %alloc_13, no copy).
    %collapse_shape = memref.collapse_shape %alloc_13 [[0], [1, 2, 3]] : memref<64x32x14x14xf32> into memref<64x6272xf32> loc(#loc16)
    // Linear 1: zero-initialized 64x1024 accumulator, then
    // (64x6272) x (6272x1024). This linalg.matmul is a capture target.
    %alloc_14 = memref.alloc() {alignment = 64 : i64} : memref<64x1024xf32> loc(#loc21)
    linalg.fill ins(%cst_0 : f32) outs(%alloc_14 : memref<64x1024xf32>) loc(#loc21)
    %alloc_15 = memref.alloc() {alignment = 64 : i64} : memref<64x1024xf32> loc(#loc21)
    linalg.generic {indexing_maps = [#map2, #map2], iterator_types = ["parallel", "parallel"]} ins(%alloc_14 : memref<64x1024xf32>) outs(%alloc_15 : memref<64x1024xf32>) {
    ^bb0(%in: f32 loc(callsite(#loc17 at #loc3)), %out: f32 loc(callsite(#loc17 at #loc3))):
      linalg.yield %in : f32 loc(#loc21)
    } loc(#loc21)
    memref.dealloc %alloc_14 : memref<64x1024xf32> loc(#loc21)
    linalg.matmul ins(%collapse_shape, %1 : memref<64x6272xf32>, memref<6272x1024xf32>) outs(%alloc_15 : memref<64x1024xf32>) loc(#loc21)
    memref.dealloc %alloc_13 : memref<64x32x14x14xf32> loc(#loc25)
    // Bias add (1024-element vector broadcast over rows via #map3) fused with
    // ReLU.
    %alloc_16 = memref.alloc() {alignment = 64 : i64} : memref<64x1024xf32> loc(#loc28)
    linalg.generic {indexing_maps = [#map2, #map3, #map2], iterator_types = ["parallel", "parallel"]} ins(%alloc_15, %3 : memref<64x1024xf32>, memref<1024xf32>) outs(%alloc_16 : memref<64x1024xf32>) {
    ^bb0(%in: f32 loc(callsite(#loc17 at #loc3)), %in_21: f32 loc(callsite(#loc17 at #loc3)), %out: f32 loc(callsite(#loc24 at #loc3))):
      %8 = arith.addf %in, %in_21 : f32 loc(#loc21)
      %9 = arith.cmpf ugt, %8, %cst_0 : f32 loc(#loc28)
      %10 = arith.select %9, %8, %cst_0 : f32 loc(#loc28)
      linalg.yield %10 : f32 loc(#loc28)
    } loc(#loc28)
    memref.dealloc %alloc_15 : memref<64x1024xf32> loc(#loc21)
    // Linear 2: zero-initialized 64x10 accumulator, then
    // (64x1024) x (1024x10). This linalg.matmul is a capture target.
    %alloc_17 = memref.alloc() {alignment = 64 : i64} : memref<64x10xf32> loc(#loc21)
    linalg.fill ins(%cst_0 : f32) outs(%alloc_17 : memref<64x10xf32>) loc(#loc21)
    %alloc_18 = memref.alloc() {alignment = 64 : i64} : memref<64x10xf32> loc(#loc21)
    linalg.generic {indexing_maps = [#map2, #map2], iterator_types = ["parallel", "parallel"]} ins(%alloc_17 : memref<64x10xf32>) outs(%alloc_18 : memref<64x10xf32>) {
    ^bb0(%in: f32 loc(callsite(#loc17 at #loc3)), %out: f32 loc(callsite(#loc17 at #loc3))):
      linalg.yield %in : f32 loc(#loc21)
    } loc(#loc21)
    memref.dealloc %alloc_17 : memref<64x10xf32> loc(#loc21)
    linalg.matmul ins(%alloc_16, %0 : memref<64x1024xf32>, memref<1024x10xf32>) outs(%alloc_18 : memref<64x10xf32>) loc(#loc21)
    memref.dealloc %alloc_16 : memref<64x1024xf32> loc(#loc28)
    // Bias add with the 10-element constant (no activation).
    %alloc_19 = memref.alloc() {alignment = 64 : i64} : memref<64x10xf32> loc(#loc21)
    linalg.generic {indexing_maps = [#map2, #map3, #map2], iterator_types = ["parallel", "parallel"]} ins(%alloc_18, %2 : memref<64x10xf32>, memref<10xf32>) outs(%alloc_19 : memref<64x10xf32>) {
    ^bb0(%in: f32 loc(callsite(#loc17 at #loc3)), %in_21: f32 loc(callsite(#loc17 at #loc3)), %out: f32 loc(callsite(#loc17 at #loc3))):
      %8 = arith.addf %in, %in_21 : f32 loc(#loc21)
      linalg.yield %8 : f32 loc(#loc21)
    } loc(#loc21)
    memref.dealloc %alloc_18 : memref<64x10xf32> loc(#loc21)
    // Pass the 64x10 result to the external consumer as an unranked memref.
    %cast_20 = memref.cast %alloc_19 : memref<64x10xf32> to memref<*xf32> loc(#loc)
    call @refbackend_consume_func_return_mrf32(%cast_20) : (memref<*xf32>) -> () loc(#loc)
    return loc(#loc)
  } loc(#loc)
} loc(#loc)
// Location aliases emitted after the module. #loc16 ("aten::view") is
// referenced by the memref.collapse_shape in @forward; location aliases may be
// defined after their first use.
#loc11 = loc("/scratch/qjiang/work/final_grad_prj/playground/MyPyApp/my_torch_jit_app/my_torch_jit_app1.py":96:12)
#loc16 = loc("aten::view"(#loc11))
```

jiangqucheng commented 9 months ago

Test of the pass rewriter workflow (continued)

Changes to Hwacc.cpp

After changing the op type in the cast to `MatmulOp`, this function acts as the filter that captures exactly the `linalg.matmul` operation.

MLIR before and after running `convert-linalg-to-hwacc`

image