Closed MaheshRavishankar closed 2 years ago
After fixing those, using
%> iree-opt -iree-flow-fusion-of-tensor-ops repro.mlir
will result in
// IR dump after -iree-flow-fusion-of-tensor-ops; the comments below flag the
// canonicalization opportunities discussed in this issue.
#map = affine_map<(d0) -> (d0)>
module {
// Entry point: imports three HAL buffers (i32 data, i32 mask, f32 values) and
// produces two buffers via two elementwise linalg.generic ops.
func @forward(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> (!hal.buffer_view, !hal.buffer_view) attributes {iree.abi.stub} {
%c42_i32 = arith.constant 42 : i32
%c5000_i32 = arith.constant 5000 : i32
%c1_i32 = arith.constant 1 : i32
%cst = arith.constant 2.000000e-04 : f64
%cst_0 = arith.constant 1.000000e+00 : f32
%c5000 = arith.constant 5000 : index
%cst_1 = arith.constant 4.000000e-01 : f32
%c0_i32 = arith.constant 0 : i32
%0 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<5000xi32>
%1 = hal.tensor.import %arg1 : !hal.buffer_view -> tensor<5000x1xi32>
%2 = hal.tensor.import %arg2 : !hal.buffer_view -> tensor<5000x1xf32>
%3 = tensor.collapse_shape %2 [[0, 1]] : tensor<5000x1xf32> into tensor<5000xf32>
%4 = linalg.init_tensor [5000] : tensor<5000xf32>
// First elementwise op: per element computes 1/(1 + exp(-(x + 1.0*log(0.4)))).
%5 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%3 : tensor<5000xf32>) outs(%4 : tensor<5000xf32>) {
^bb0(%arg3: f32, %arg4: f32): // no predecessors
%12 = math.log %cst_1 : f32
%13 = arith.mulf %12, %cst_0 : f32
%14 = arith.addf %arg3, %13 : f32
%15 = arith.negf %14 : f32
%16 = math.exp %15 : f32
%17 = arith.addf %16, %cst_0 : f32
%18 = arith.divf %cst_0, %17 : f32
linalg.yield %18 : f32
} -> tensor<5000xf32>
%6 = tensor.collapse_shape %1 [[0, 1]] : tensor<5000x1xi32> into tensor<5000xi32>
// NOTE(review): a statically shaped init_tensor is immediately cast to a
// dynamic shape; folding this tensor.cast away would keep the second
// linalg.generic statically shaped.
%7 = linalg.init_tensor [5000] : tensor<5000xi64>
%8 = tensor.cast %7 : tensor<5000xi64> to tensor<?xi64>
%9 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%5, %6 : tensor<5000xf32>, tensor<5000xi32>) outs(%8 : tensor<?xi64>) {
^bb0(%arg3: f32, %arg4: i32, %arg5: i64): // no predecessors
// NOTE(review): index -> i64 -> index round trip (%12 -> %13 -> %14, and
// again %13 -> %18 below) can be canonicalized to direct uses of %12.
%12 = linalg.index 0 : index
%13 = arith.index_cast %12 : index to i64
%14 = arith.index_cast %13 : i64 to index
%15 = tensor.extract %0[%14] : tensor<5000xi32>
// %16 duplicates %12 (both are linalg.index 0) and could be CSE'd.
%16 = linalg.index 0 : index
%17 = arith.addi %15, %c1_i32 : i32
%18 = arith.index_cast %13 : i64 to index
%19 = arith.cmpi eq, %18, %16 : index
%20 = select %19, %17, %c0_i32 : i32
%21 = arith.cmpi sgt, %20, %c42_i32 : i32
%22 = select %21, %c0_i32, %20 : i32
%23 = arith.cmpi sgt, %c0_i32, %20 : i32
%24 = select %23, %c0_i32, %22 : i32
%25 = arith.cmpi eq, %arg4, %c1_i32 : i32
%26 = arith.truncf %cst : f64 to f32
%27 = arith.divf %arg3, %26 : f32
%28 = select %25, %24, %c0_i32 : i32
%29 = math.ceil %27 : f32
%30 = arith.subf %29, %cst_0 : f32
%31 = arith.muli %28, %c5000_i32 : i32
%32 = arith.fptosi %30 : f32 to i64
%33 = arith.extsi %31 : i32 to i64
%34 = arith.addi %32, %33 : i64
linalg.yield %34 : i64
} -> tensor<?xi64>
// Exports: %9 is dynamically shaped only because of the tensor.cast above,
// even though its size is the static constant %c5000.
%10 = hal.tensor.export %5 : tensor<?xf32> as tensor<5000xf32>{%c5000} -> !hal.buffer_view
%11 = hal.tensor.export %9 : tensor<?xi64>{%c5000} -> !hal.buffer_view
return %10, %11 : !hal.buffer_view, !hal.buffer_view
}
}
The two generic ops can be fused further as well if the MLIR elementwise fusion is adapted to fuse operations that result in multiple outputs. That change has more fallout that needs to be worked through, but the canonicalizations mentioned above seem like reasonably simple things to add.
This input IR shows some missing canonicalizations that can help fusion
1) This sequence of operations (the `index` -> `i64` -> `index` round trip inside the second `linalg.generic`):
%13 = arith.index_cast %12 : index to i64
%14 = arith.index_cast %13 : i64 to index
can be replaced with a direct use of %12, the original `index` value.
2) Next, this sequence (a statically shaped `linalg.init_tensor` immediately cast to a dynamic shape):
%7 = linalg.init_tensor [5000] : tensor<5000xi64>
%8 = tensor.cast %7 : tensor<5000xi64> to tensor<?xi64>
can be replaced with using the statically shaped %7 directly, folding the `tensor.cast` into the consuming `linalg.generic`'s result type.