iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

LinalgExt ops don't support fusion #17392

Open IanWood1 opened 6 months ago

IanWood1 commented 6 months ago

IREE currently lacks support for fusing LinalgExt operations with other LinalgExt operations or with standard Linalg operations. The current fusion implementation relies on indexing maps to decide which operations can be fused, and LinalgExt ops do not expose indexing maps to it. See the example below:

    #map = affine_map<(d0, d1) -> (d0, d1)>
    #map1 = affine_map<(d0, d1) -> (d0, d1)>
    %3 = tensor.empty() : tensor<4x1xi32>
    %4 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%expanded : tensor<4x1xi64>) outs(%3 : tensor<4x1xi32>) {
    ^bb0(%in: i64, %out: i32):
      %10 = arith.trunci %in : i64 to i32
      linalg.yield %10 : i32
    } -> tensor<4x1xi32>
    %5 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(false) ins(%expanded_0, %4 : tensor<4x1x16x8x128xf32>, tensor<4x1xi32>) outs(%2 : tensor<8192x16x8x128xf32>) {
    ^bb0(%arg5: f32, %arg6: f32):
      iree_linalg_ext.yield %arg5 : f32
    } -> tensor<8192x16x8x128xf32>

This results in the following dispatch formation. Note that the linalg.generic producer and the scatter end up in separate dispatch regions instead of being fused:

  %3 = tensor.empty() : tensor<4x1xi32>
  %4 = flow.dispatch.region -> (tensor<4x1xi32>) {
    %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%expanded_0 : tensor<4x1xi64>) outs(%3 : tensor<4x1xi32>) {
    ^bb0(%in: i64, %out: i32):
      %11 = arith.trunci %in : i64 to i32
      linalg.yield %11 : i32
    } -> tensor<4x1xi32>
    flow.return %10 : tensor<4x1xi32>
  }
  %5 = flow.dispatch.region -> (tensor<8192x16x8x128xf32>) {
    %10 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(false) ins(%expanded, %4 : tensor<4x1x16x8x128xf32>, tensor<4x1xi32>) outs(%2 : tensor<8192x16x8x128xf32>) {
    ^bb0(%arg5: f32, %arg6: f32):
      iree_linalg_ext.yield %arg5 : f32
    } -> tensor<8192x16x8x128xf32>
    flow.return %10 : tensor<8192x16x8x128xf32>
  }
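To make the failure mode concrete, here is a minimal Python model of an indexing-map-based fusability check. It is loosely modeled on the preconditions of upstream MLIR's Linalg elementwise fusion, not on IREE's actual FormDispatchRegions logic. The Op class, can_fuse, and the tuple encoding of affine maps are all hypothetical. An op that exposes no indexing maps, like the scatter here, is rejected before any map is compared.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

# Hypothetical encoding: a projected-permutation affine map such as
# (d0, d1) -> (d1) is stored as the loop dimensions its results read, e.g. (1,).
AffineMap = Tuple[int, ...]


@dataclass(frozen=True)
class Op:
    name: str
    num_loops: int
    all_parallel: bool = True
    # Operand maps (inputs first, then the single output). None models an op
    # whose indexing maps the fusion logic cannot query.
    indexing_maps: Optional[Tuple[AffineMap, ...]] = None


def is_permutation(m: AffineMap, num_loops: int) -> bool:
    """True if m reads every loop dimension exactly once."""
    return sorted(m) == list(range(num_loops))


def can_fuse(producer: Op, consumer: Op, operand_index: int) -> bool:
    """Simplified check: can producer's result fuse into consumer's operand?"""
    # Without indexing maps on both sides there is nothing to compare.
    if producer.indexing_maps is None or consumer.indexing_maps is None:
        return False
    if not producer.all_parallel:
        return False
    # The producer's result map must be invertible (here: a permutation) ...
    if not is_permutation(producer.indexing_maps[-1], producer.num_loops):
        return False
    # ... and the consumer must index the fused operand with one result per
    # producer loop, so every producer loop lines up with a consumer loop.
    return len(consumer.indexing_maps[operand_index]) == producer.num_loops


identity_2d = (0, 1)  # (d0, d1) -> (d0, d1)
trunci = Op("linalg.generic", num_loops=2, indexing_maps=(identity_2d, identity_2d))
# Assumption: one loop per dimension of the 4x1x16x8x128 updates operand.
scatter = Op("iree_linalg_ext.scatter", num_loops=5)  # no indexing maps exposed
# Hypothetical elementwise consumer, for contrast.
consumer = Op("linalg.generic", num_loops=2, indexing_maps=(identity_2d, identity_2d))

print(can_fuse(trunci, scatter, operand_index=1))   # False
print(can_fuse(trunci, consumer, operand_index=0))  # True
```

The scatter is rejected only because its maps are missing; the same producer fuses with an ordinary Linalg consumer.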

Immediate solution

Add functionality to FormDispatchRegions to get indexing maps for specific LinalgExt ops. This would be a quick and easy way to make those ops visible to the existing fusion logic, and it would also lay the groundwork for a long-term solution.
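The immediate solution can be sketched as a lookup that returns hand-written indexing maps for specific op names and falls back to None for everything else. This is a hypothetical Python model, not IREE code: get_indexing_maps, LINALG_EXT_MAP_PROVIDERS, and the op name example_ext.elementwise are illustrative only, and real maps for iree_linalg_ext.scatter would have to be derived from its semantics.

```python
from typing import Callable, Dict, List, Optional, Tuple

# Hypothetical encoding: each map is the tuple of loop dims its results read,
# e.g. (d0, d1) -> (d1) is (1,).
AffineMap = Tuple[int, ...]
MapProvider = Callable[[List[int]], List[AffineMap]]


def identity_maps(operand_ranks: List[int]) -> List[AffineMap]:
    """Every operand is read with the identity map over a shared loop nest."""
    return [tuple(range(rank)) for rank in operand_ranks]


# Hypothetical table of ops whose indexing maps are written out by hand.
# Adding an entry is what makes an op visible to map-based fusion.
LINALG_EXT_MAP_PROVIDERS: Dict[str, MapProvider] = {
    "example_ext.elementwise": identity_maps,  # hypothetical op name
}


def get_indexing_maps(op_name: str, operand_ranks: List[int],
                      own_maps: Optional[List[AffineMap]] = None
                      ) -> Optional[List[AffineMap]]:
    """Return an op's indexing maps, or None if the op is opaque to fusion."""
    # Linalg ops already carry their maps as an attribute; use them directly.
    if own_maps is not None:
        return own_maps
    provider = LINALG_EXT_MAP_PROVIDERS.get(op_name)
    # Ops without a registered provider stay opaque and are never fused.
    return provider(operand_ranks) if provider else None


print(get_indexing_maps("example_ext.elementwise", [2, 2]))
print(get_indexing_maps("iree_linalg_ext.scatter", [5, 2, 4]))
```

Until a provider is registered for iree_linalg_ext.scatter, the lookup returns None for it, which matches the unfused dispatch regions shown earlier. Keeping the table keyed by op name makes each addition local, at the cost of one hand-written entry per op.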

Long-term solution

include "mlir/Interfaces/DestinationStyleOpInterface.td"
https://github.com/iree-org/iree/blob/a78cee1f0e84e99eaca8b0ae46e2da609916c6fb/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp#L608
Linalg TilingInterfaceImpl
LinalgExt TilingInterfaceImpl

MaheshRavishankar commented 6 months ago

Thanks @IanWood1 for capturing this. One amendment to your long-term solution: we will probably end up adding this to TilingInterface, which already has the notion of iteration spaces. TBD, though.