iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0
2.56k stars 572 forks source link

Missing canonicalizations for better fusion #8129

Closed MaheshRavishankar closed 2 years ago

MaheshRavishankar commented 2 years ago

This input IR shows some missing canonicalizations that would help fusion:

// Input repro for the missing-canonicalization issue. Imports three buffer
// views (%arg0 as tensor<5000xi32>, %arg1 as tensor<5000x1xi32>, %arg2 as
// tensor<5000x1xf32>) and exports two results: %58 (f32) and %42 (i64).
func @forward(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> (!hal.buffer_view, !hal.buffer_view) attributes {iree.abi.stub} {
  %c42_i32 = arith.constant 42 : i32
  %cst = arith.constant 0.000000e+00 : f64
  %c5000_i32 = arith.constant 5000 : i32
  %c1_i32 = arith.constant 1 : i32
  %c0_i32 = arith.constant 0 : i32
  %cst_0 = arith.constant 5.000000e-04 : f64
  %cst_1 = arith.constant 9.995000e-01 : f64
  %cst_2 = arith.constant 2.000000e-04 : f64
  %cst_3 = arith.constant dense<4.000000e-01> : tensor<f32>
  %cst_4 = arith.constant 1.000000e+00 : f32
  %c5000 = arith.constant 5000 : index
  %true = arith.constant true
  // Import the three buffer views as statically shaped tensors.
  %0 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<5000xi32>
  %1 = hal.tensor.import %arg1 : !hal.buffer_view -> tensor<5000x1xi32>
  %2 = hal.tensor.import %arg2 : !hal.buffer_view -> tensor<5000x1xf32>
  // %4 = log(0.4) as a 0-d tensor; the op has no loop dimensions and only
  // reads the constant %cst_3.
  %3 = linalg.init_tensor [] : tensor<f32>
  %4 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%cst_3 : tensor<f32>) outs(%3 : tensor<f32>) {
  ^bb0(%arg3: f32, %arg4: f32):  // no predecessors
    %61 = math.log %arg3 : f32
    linalg.yield %61 : f32
  } -> tensor<f32>
  // %7 = collapse(%2) + %4 * 1.0; the scalar %4 is broadcast through the
  // `(d0) -> ()` indexing map.
  %5 = tensor.collapse_shape %2 [[0, 1]] : tensor<5000x1xf32> into tensor<5000xf32>
  %6 = linalg.init_tensor [5000] : tensor<5000xf32>
  %7 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%5, %4 : tensor<5000xf32>, tensor<f32>) outs(%6 : tensor<5000xf32>) {
  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
    %61 = arith.mulf %arg4, %cst_4 : f32
    %62 = arith.addf %arg3, %61 : f32
    linalg.yield %62 : f32
  } -> tensor<5000xf32>
  // %9 = 1 / (1 + exp(-%7)), computed elementwise.
  %8 = linalg.init_tensor [5000] : tensor<5000xf32>
  %9 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%7 : tensor<5000xf32>) outs(%8 : tensor<5000xf32>) {
  ^bb0(%arg3: f32, %arg4: f32):  // no predecessors
    %61 = arith.negf %arg3 : f32
    %62 = math.exp %61 : f32
    %63 = arith.addf %62, %cst_4 : f32
    %64 = arith.divf %cst_4, %63 : f32
    linalg.yield %64 : f32
  } -> tensor<5000xf32>
  // Canonicalization opportunity: expand_shape -> cast -> collapse_shape is a
  // round trip. Its net effect is a cast of %9 from tensor<5000xf32> to
  // tensor<?xf32>.
  %10 = tensor.expand_shape %9 [[0, 1]] : tensor<5000xf32> into tensor<5000x1xf32>
  %11 = tensor.cast %10 : tensor<5000x1xf32> to tensor<?x?xf32>
  %12 = tensor.collapse_shape %11 [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>
  // Canonicalization opportunity: init_tensor followed by a fill with a
  // constant can be folded to `arith.constant dense<0> : tensor<5000xi32>`.
  // %13 is also reused below as the `outs` operand of several generics.
  %13 = linalg.init_tensor [5000] : tensor<5000xi32>
  %14 = linalg.fill(%c0_i32, %13) : i32, tensor<5000xi32> -> tensor<5000xi32> 
  // %16[i] = i as i64, for i in [0, 5000).
  %15 = linalg.init_tensor [5000] : tensor<5000xi64>
  %16 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} outs(%15 : tensor<5000xi64>) {
  ^bb0(%arg3: i64):  // no predecessors
    %61 = linalg.index 0 : index
    %62 = arith.index_cast %61 : index to i64
    linalg.yield %62 : i64
  } -> tensor<5000xi64>
  // %18[i] = %0[%16[i]]. Because %16 holds the iteration index, this reads
  // %0 element by element.
  %17 = linalg.init_tensor [%c5000] : tensor<?xi32>
  %18 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%16 : tensor<5000xi64>) outs(%17 : tensor<?xi32>) {
  ^bb0(%arg3: i64, %arg4: i32):  // no predecessors
    %61 = arith.index_cast %arg3 : i64 to index
    %62 = tensor.extract %0[%61] : tensor<5000xi32>
    linalg.yield %62 : i32
  } -> tensor<?xi32>
  // %20 = %18 + 1.
  %19 = linalg.init_tensor [%c5000] : tensor<?xi32>
  %20 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%18 : tensor<?xi32>) outs(%19 : tensor<?xi32>) {
  ^bb0(%arg3: i32, %arg4: i32):  // no predecessors
    %61 = arith.addi %arg3, %c1_i32 : i32
    linalg.yield %61 : i32
  } -> tensor<?xi32>
  // NOTE(review): this op looks malformed as pasted. The first map
  // `affine_map<(d0 -> (d0)>` is missing a ')', and the last map repeats `d0`
  // in its dimension list. The body also reads `linalg.index 1` with two
  // parallel iterators, although the other maps are 1-D. Confirm against the
  // original repro before feeding this to iree-opt.
  %21 = linalg.generic {indexing_maps = [affine_map<(d0 -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0, d0) -> (d0)>], iterator_types = ["parallel", "parallel"]} ins(%16, %14, %20 : tensor<5000xi64>, tensor<5000xi32>, tensor<?xi32>) outs(%14 : tensor<5000xi32>) {
  ^bb0(%arg3: i64, %arg4: i32, %arg5: i32, %arg6: i32):  // no predecessors
    %61 = arith.index_cast %arg3 : i64 to index
    %62 = linalg.index 1 : index
    %63 = arith.cmpi eq, %61, %62 : index
    %64 = select %63, %arg5, %arg6 : i32
    linalg.yield %64 : i32
  } -> tensor<5000xi32>
  // %22 is %1 collapsed to 1-D. %23 and %24 fill %13 with 0 and 1
  // respectively; %23 duplicates %14.
  %22 = tensor.collapse_shape %1 [[0, 1]] : tensor<5000x1xi32> into tensor<5000xi32>
  %23 = linalg.fill(%c0_i32, %13) : i32, tensor<5000xi32> -> tensor<5000xi32> 
  %24 = linalg.fill(%c1_i32, %13) : i32, tensor<5000xi32> -> tensor<5000xi32> 
  // Boolean masks: %26 = (%21 > 42), %27 = (0 > %21) (both signed), and
  // %28 = (%22 == 1).
  %25 = linalg.init_tensor [5000] : tensor<5000xi1>
  %26 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%21 : tensor<5000xi32>) outs(%25 : tensor<5000xi1>) {
  ^bb0(%arg3: i32, %arg4: i1):  // no predecessors
    %61 = arith.cmpi sgt, %arg3, %c42_i32 : i32
    linalg.yield %61 : i1
  } -> tensor<5000xi1>
  %27 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%23, %21 : tensor<5000xi32>, tensor<5000xi32>) outs(%25 : tensor<5000xi1>) {
  ^bb0(%arg3: i32, %arg4: i32, %arg5: i1):  // no predecessors
    %61 = arith.cmpi sgt, %arg3, %arg4 : i32
    linalg.yield %61 : i1
  } -> tensor<5000xi1>
  %28 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%22, %24 : tensor<5000xi32>, tensor<5000xi32>) outs(%25 : tensor<5000xi1>) {
  ^bb0(%arg3: i32, %arg4: i32, %arg5: i1):  // no predecessors
    %61 = arith.cmpi eq, %arg3, %arg4 : i32
    linalg.yield %61 : i1
  } -> tensor<5000xi1>
  // Chained selects: %29 = %26 ? 0 : %21, %30 = %27 ? 0 : %29,
  // %31 = %28 ? %30 : 0.
  %29 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%26, %23, %21 : tensor<5000xi1>, tensor<5000xi32>, tensor<5000xi32>) outs(%13 : tensor<5000xi32>) {
  ^bb0(%arg3: i1, %arg4: i32, %arg5: i32, %arg6: i32):  // no predecessors
    %61 = select %arg3, %arg4, %arg5 : i32
    linalg.yield %61 : i32
  } -> tensor<5000xi32>
  %30 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%27, %23, %29 : tensor<5000xi1>, tensor<5000xi32>, tensor<5000xi32>) outs(%13 : tensor<5000xi32>) {
  ^bb0(%arg3: i1, %arg4: i32, %arg5: i32, %arg6: i32):  // no predecessors
    %61 = select %arg3, %arg4, %arg5 : i32
    linalg.yield %61 : i32
  } -> tensor<5000xi32>
  %31 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%28, %30, %23 : tensor<5000xi1>, tensor<5000xi32>, tensor<5000xi32>) outs(%13 : tensor<5000xi32>) {
  ^bb0(%arg3: i1, %arg4: i32, %arg5: i32, %arg6: i32):  // no predecessors
    %61 = select %arg3, %arg4, %arg5 : i32
    linalg.yield %61 : i32
  } -> tensor<5000xi32>
  // %39 = fptosi(ceil(%12 / 2.0e-4) - 1.0). The truncf of the f64 constant
  // inside the %33 body does not depend on the element being processed.
  %32 = linalg.init_tensor [%c5000] : tensor<?xf32>
  %33 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%12 : tensor<?xf32>) outs(%32 : tensor<?xf32>) {
  ^bb0(%arg3: f32, %arg4: f32):  // no predecessors
    %61 = arith.truncf %cst_2 : f64 to f32
    %62 = arith.divf %arg3, %61 : f32
    linalg.yield %62 : f32
  } -> tensor<?xf32>
  %34 = linalg.init_tensor [%c5000] : tensor<?xf32>
  %35 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%33 : tensor<?xf32>) outs(%34 : tensor<?xf32>) {
  ^bb0(%arg3: f32, %arg4: f32):  // no predecessors
    %61 = math.ceil %arg3 : f32
    linalg.yield %61 : f32
  } -> tensor<?xf32>
  %36 = linalg.init_tensor [%c5000] : tensor<?xf32>
  %37 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%35 : tensor<?xf32>) outs(%36 : tensor<?xf32>) {
  ^bb0(%arg3: f32, %arg4: f32):  // no predecessors
    %61 = arith.subf %arg3, %cst_4 : f32
    linalg.yield %61 : f32
  } -> tensor<?xf32>
  %38 = linalg.init_tensor [%c5000] : tensor<?xi64>
  %39 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%37 : tensor<?xf32>) outs(%38 : tensor<?xi64>) {
  ^bb0(%arg3: f32, %arg4: i64):  // no predecessors
    %61 = arith.fptosi %arg3 : f32 to i64
    linalg.yield %61 : i64
  } -> tensor<?xi64>
  // %40 = %31 * 5000.
  %40 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%31 : tensor<5000xi32>) outs(%13 : tensor<5000xi32>) {
  ^bb0(%arg3: i32, %arg4: i32):  // no predecessors
    %61 = arith.muli %arg3, %c5000_i32 : i32
    linalg.yield %61 : i32
  } -> tensor<5000xi32>
  // Assert on the constant %true; this always holds.
  assert %true, "mismatched size for broadcast"
  // %42 = %39 + sext(%40); exported as the second result.
  %41 = linalg.init_tensor [%c5000] : tensor<?xi64>
  %42 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%39, %40 : tensor<?xi64>, tensor<5000xi32>) outs(%41 : tensor<?xi64>) {
  ^bb0(%arg3: i64, %arg4: i32, %arg5: i64):  // no predecessors
    %61 = arith.extsi %arg4 : i32 to i64
    %62 = arith.addi %arg3, %61 : i64
    linalg.yield %62 : i64
  } -> tensor<?xi64>
  // %44 and %45 are identical: both ignore their input and yield %cst (0.0).
  %43 = linalg.init_tensor [%c5000] : tensor<?xf64>
  %44 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%42 : tensor<?xi64>) outs(%43 : tensor<?xf64>) {
  ^bb0(%arg3: i64, %arg4: f64):  // no predecessors
    linalg.yield %cst : f64
  } -> tensor<?xf64>
  %45 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%42 : tensor<?xi64>) outs(%43 : tensor<?xf64>) {
  ^bb0(%arg3: i64, %arg4: f64):  // no predecessors
    linalg.yield %cst : f64
  } -> tensor<?xf64>
  assert %true, "mismatched size for broadcast"
  // %47 = %44 / %45, i.e. 0.0 / 0.0 elementwise.
  %46 = linalg.init_tensor [%c5000] : tensor<?xf64>
  %47 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%44, %45 : tensor<?xf64>, tensor<?xf64>) outs(%46 : tensor<?xf64>) {
  ^bb0(%arg3: f64, %arg4: f64, %arg5: f64):  // no predecessors
    %61 = arith.divf %arg3, %arg4 : f64
    linalg.yield %61 : f64
  } -> tensor<?xf64>
  // %49 = truncf(%47) to f32; %51 = %49 * truncf(0.9995).
  %48 = linalg.init_tensor [%c5000] : tensor<?xf32>
  %49 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%47 : tensor<?xf64>) outs(%48 : tensor<?xf32>) {
  ^bb0(%arg3: f64, %arg4: f32):  // no predecessors
    %61 = arith.truncf %arg3 : f64 to f32
    linalg.yield %61 : f32
  } -> tensor<?xf32>
  %50 = linalg.init_tensor [%c5000] : tensor<?xf32>
  %51 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%49 : tensor<?xf32>) outs(%50 : tensor<?xf32>) {
  ^bb0(%arg3: f32, %arg4: f32):  // no predecessors
    %61 = arith.truncf %cst_1 : f64 to f32
    %62 = arith.mulf %arg3, %61 : f32
    linalg.yield %62 : f32
  } -> tensor<?xf32>
  // %52 = %12 * truncf(5.0e-4); reuses %32 as its `outs` operand.
  %52 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%12 : tensor<?xf32>) outs(%32 : tensor<?xf32>) {
  ^bb0(%arg3: f32, %arg4: f32):  // no predecessors
    %61 = arith.truncf %cst_0 : f64 to f32
    %62 = arith.mulf %arg3, %61 : f32
    linalg.yield %62 : f32
  } -> tensor<?xf32>
  assert %true, "mismatched size for broadcast"
  // %54 = %51 + %52 * 1.0.
  %53 = linalg.init_tensor [%c5000] : tensor<?xf32>
  %54 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%51, %52 : tensor<?xf32>, tensor<?xf32>) outs(%53 : tensor<?xf32>) {
  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
    %61 = arith.mulf %arg4, %cst_4 : f32
    %62 = arith.addf %arg3, %61 : f32
    linalg.yield %62 : f32
  } -> tensor<?xf32>
  // %56 = (%45 > 0.0), using the unordered `ugt` comparison.
  %55 = linalg.init_tensor [%c5000] : tensor<?xi1>
  %56 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%45 : tensor<?xf64>) outs(%55 : tensor<?xi1>) {
  ^bb0(%arg3: f64, %arg4: i1):  // no predecessors
    %61 = arith.cmpf ugt, %arg3, %cst : f64
    linalg.yield %61 : i1
  } -> tensor<?xi1>
  assert %true, "mismatched size for broadcast"
  assert %true, "mismatched size for broadcast"
  // %58 = %56 ? %54 : %12; exported as the first result.
  %57 = linalg.init_tensor [%c5000] : tensor<?xf32>
  %58 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%56, %54, %12 : tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) outs(%57 : tensor<?xf32>) {
  ^bb0(%arg3: i1, %arg4: f32, %arg5: f32, %arg6: f32):  // no predecessors
    %61 = select %arg3, %arg4, %arg5 : f32
    linalg.yield %61 : f32
  } -> tensor<?xf32>
  // Export both results, binding the dynamic dimension to %c5000.
  %59 = hal.tensor.export %58 : tensor<?xf32>{%c5000} -> !hal.buffer_view
  %60 = hal.tensor.export %42 : tensor<?xi64>{%c5000} -> !hal.buffer_view
  return %59, %60 : !hal.buffer_view, !hal.buffer_view
}

1) This sequence of operations

 %10 = tensor.expand_shape %9 [[0, 1]] : tensor<5000xf32> into tensor<5000x1xf32>
  %11 = tensor.cast %10 : tensor<5000x1xf32> to tensor<?x?xf32>
  %12 = tensor.collapse_shape %11 [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>

can be replaced with

%12 = tensor.cast %9 : tensor<5000xf32> into tensor<?xf32>

2) Next, this sequence

%13 = linalg.init_tensor [5000] : tensor<5000xi32>
  %14 = linalg.fill(%c0_i32, %13) : i32, tensor<5000xi32> -> tensor<5000xi32> 

can be replaced with

%14 = arith.constant dense<0> : tensor<5000xi32>
MaheshRavishankar commented 2 years ago

After fixing those, using

%> iree-opt -iree-flow-fusion-of-tensor-ops repro.mlir

will result in

// Result of running `iree-opt -iree-flow-fusion-of-tensor-ops` on the repro
// after the two canonicalizations are applied. The elementwise chain is
// reduced to two linalg.generic ops, %5 and %9.
#map = affine_map<(d0) -> (d0)>
module  {
  func @forward(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> (!hal.buffer_view, !hal.buffer_view) attributes {iree.abi.stub} {
    %c42_i32 = arith.constant 42 : i32
    %c5000_i32 = arith.constant 5000 : i32
    %c1_i32 = arith.constant 1 : i32
    %cst = arith.constant 2.000000e-04 : f64
    %cst_0 = arith.constant 1.000000e+00 : f32
    %c5000 = arith.constant 5000 : index
    %cst_1 = arith.constant 4.000000e-01 : f32
    %c0_i32 = arith.constant 0 : i32
    %0 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<5000xi32>
    %1 = hal.tensor.import %arg1 : !hal.buffer_view -> tensor<5000x1xi32>
    %2 = hal.tensor.import %arg2 : !hal.buffer_view -> tensor<5000x1xf32>
    // %5: the fused log / add / 1 / (1 + exp(-x)) chain over the collapsed
    // %arg2. It corresponds to %4, %7 and %9 of the input IR.
    %3 = tensor.collapse_shape %2 [[0, 1]] : tensor<5000x1xf32> into tensor<5000xf32>
    %4 = linalg.init_tensor [5000] : tensor<5000xf32>
    %5 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%3 : tensor<5000xf32>) outs(%4 : tensor<5000xf32>) {
    ^bb0(%arg3: f32, %arg4: f32):  // no predecessors
      // Depends only on the constant %cst_1, not on the current element.
      %12 = math.log %cst_1 : f32
      %13 = arith.mulf %12, %cst_0 : f32
      %14 = arith.addf %arg3, %13 : f32
      %15 = arith.negf %14 : f32
      %16 = math.exp %15 : f32
      %17 = arith.addf %16, %cst_0 : f32
      %18 = arith.divf %cst_0, %17 : f32
      linalg.yield %18 : f32
    } -> tensor<5000xf32>
    // %6: %1 collapsed to 1-D.
    %6 = tensor.collapse_shape %1 [[0, 1]] : tensor<5000x1xi32> into tensor<5000xi32>
    // The i64 result's init tensor is cast to the dynamic type tensor<?xi64>.
    %7 = linalg.init_tensor [5000] : tensor<5000xi64>
    %8 = tensor.cast %7 : tensor<5000xi64> to tensor<?xi64>
    // %9: the fused integer/float chain producing the exported i64 result.
    // It consumes %5, which is also exported below. Folding %5 into this op
    // would require a multi-result fused op.
    %9 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%5, %6 : tensor<5000xf32>, tensor<5000xi32>) outs(%8 : tensor<?xi64>) {
    ^bb0(%arg3: f32, %arg4: i32, %arg5: i64):  // no predecessors
      %12 = linalg.index 0 : index
      %13 = arith.index_cast %12 : index to i64
      %14 = arith.index_cast %13 : i64 to index
      %15 = tensor.extract %0[%14] : tensor<5000xi32>
      %16 = linalg.index 0 : index
      %17 = arith.addi %15, %c1_i32 : i32
      %18 = arith.index_cast %13 : i64 to index
      %19 = arith.cmpi eq, %18, %16 : index
      %20 = select %19, %17, %c0_i32 : i32
      %21 = arith.cmpi sgt, %20, %c42_i32 : i32
      %22 = select %21, %c0_i32, %20 : i32
      %23 = arith.cmpi sgt, %c0_i32, %20 : i32
      %24 = select %23, %c0_i32, %22 : i32
      %25 = arith.cmpi eq, %arg4, %c1_i32 : i32
      %26 = arith.truncf %cst : f64 to f32
      %27 = arith.divf %arg3, %26 : f32
      %28 = select %25, %24, %c0_i32 : i32
      %29 = math.ceil %27 : f32
      %30 = arith.subf %29, %cst_0 : f32
      %31 = arith.muli %28, %c5000_i32 : i32
      %32 = arith.fptosi %30 : f32 to i64
      %33 = arith.extsi %31 : i32 to i64
      %34 = arith.addi %32, %33 : i64
      linalg.yield %34 : i64
    } -> tensor<?xi64>
    // NOTE(review): the source type written here (tensor<?xf32>) differs from
    // %5's result type (tensor<5000xf32>). Confirm this line is verbatim tool
    // output.
    %10 = hal.tensor.export %5 : tensor<?xf32> as tensor<5000xf32>{%c5000} -> !hal.buffer_view
    %11 = hal.tensor.export %9 : tensor<?xi64>{%c5000} -> !hal.buffer_view
    return %10, %11 : !hal.buffer_view, !hal.buffer_view
  }
}

The two generic ops could be fused further if the MLIR elementwise fusion were adapted to fuse operations into a single operation with multiple results. That has more fallout that needs to be worked through, but the canonicalizations mentioned above seem like reasonably simple things to add.

MaheshRavishankar commented 2 years ago

#9303 tracks the work on fusion that creates multi-result operations. Closing this one.