iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

Support rank-reduction slices in bubble-up extract slice #19173

Closed. MaheshRavishankar closed this issue 1 week ago.

MaheshRavishankar commented 1 week ago

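Currently, the bubble-up extract slice pattern does not support rank-reducing slices. As a result, the repro below does a lot of unnecessary computation. The linalg.generic computes the full 1x1x131072x131072 comparison result (131072 × 131072 ≈ 1.7 × 10^10 elements), even though the rank-reducing tensor.extract_slice keeps only a %arg2 x %arg2 block of it.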

module {
  func.func @repro(%arg0: tensor<131072xi64>, %arg1: tensor<1x1x131072xi64>, %arg2: index) -> tensor<?x?xi1> {
    %0 = tensor.empty() : tensor<1x1x131072x131072xi1>
    %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : tensor<131072xi64>, tensor<1x1x131072xi64>) outs(%0 : tensor<1x1x131072x131072xi1>) {
    ^bb0(%in: i64, %in_0: i64, %out: i1):
      %2 = arith.cmpi sge, %in, %in_0 : i64
      linalg.yield %2 : i1
    } -> tensor<1x1x131072x131072xi1>
    %extracted_slice = tensor.extract_slice %1[0, 0, 0, 0] [1, 1, %arg2, %arg2] [1, 1, 1, 1] : tensor<1x1x131072x131072xi1> to tensor<?x?xi1>
    return %extracted_slice : tensor<?x?xi1>
  }
}
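For comparison, the following is a hand-written sketch of the IR one might expect if the bubble-up pattern handled rank-reducing slices. The slice moves onto the operands, the unit dimensions are dropped, and the comparison runs only over the %arg2 x %arg2 result. This is an assumption about the desired output, not what IREE currently produces. The function name @repro_bubbled and the SSA value names are hypothetical.

```mlir
// Hypothetical target IR: slices are taken on the inputs before the generic.
func.func @repro_bubbled(%arg0: tensor<131072xi64>, %arg1: tensor<1x1x131072xi64>, %arg2: index) -> tensor<?x?xi1> {
  // Original d3 (result dim 1) indexes %arg0, so slice %arg0 to %arg2 elements.
  %lhs = tensor.extract_slice %arg0[0] [%arg2] [1] : tensor<131072xi64> to tensor<?xi64>
  // Original d2 (result dim 0) indexes %arg1; the rank-reducing slice drops the two unit dims.
  %rhs = tensor.extract_slice %arg1[0, 0, 0] [1, 1, %arg2] [1, 1, 1] : tensor<1x1x131072xi64> to tensor<?xi64>
  %empty = tensor.empty(%arg2, %arg2) : tensor<?x?xi1>
  // New d0 corresponds to original d2 and new d1 to original d3.
  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%lhs, %rhs : tensor<?xi64>, tensor<?xi64>) outs(%empty : tensor<?x?xi1>) {
  ^bb0(%in: i64, %in_0: i64, %out: i1):
    // Same comparison as the original body, now over %arg2 x %arg2 elements only.
    %1 = arith.cmpi sge, %in, %in_0 : i64
    linalg.yield %1 : i1
  } -> tensor<?x?xi1>
  return %0 : tensor<?x?xi1>
}
```

An implementation could instead keep the unit dimensions in the bubbled-up generic (producing tensor<1x1x?x?xi1>) and apply a rank-reducing slice at the end. Either form avoids computing the full 131072 x 131072 result.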