iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

Redundant memref.subview ops are created in bufferization #9360

Closed hanhanW closed 2 years ago

hanhanW commented 2 years ago

Found some cases where redundant memref.subview ops are created in IREE.

Input:

module {
  func.func @slice() {
    %0 = hal.interface.constant.load[0] : index
    %1 = hal.interface.constant.load[1] : index
    %2 = hal.interface.constant.load[2] : index
    %3 = hal.interface.constant.load[3] : index
    %4 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:?x?xi32>{%0, %1}
    %5 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<writeonly:?x?xi32>{%2, %3}
    %6 = flow.dispatch.tensor.load %4, offsets = [0, 0], sizes = [%0, %1], strides = [1, 1] : !flow.dispatch.tensor<readonly:?x?xi32>{%0, %1} -> tensor<?x?xi32>
    %7 = tensor.extract_slice %6[%0, %1] [%2, %3] [1, 1] : tensor<?x?xi32> to tensor<?x?xi32>
    flow.dispatch.tensor.store %7, %5, offsets = [0, 0], sizes = [%2, %3], strides = [1, 1] : tensor<?x?xi32> -> !flow.dispatch.tensor<writeonly:?x?xi32>{%2, %3}
    return
  }
}

Output:

#map0 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
module {
  func.func @slice() {
    %0 = hal.interface.constant.load[0] : index
    %1 = hal.interface.constant.load[1] : index
    %2 = hal.interface.constant.load[2] : index
    %3 = hal.interface.constant.load[3] : index
    %4 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : memref<?x?xi32>{%0, %1}
    %5 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:?x?xi32>{%0, %1}
    %6 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : memref<?x?xi32>{%2, %3}
    %7 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<writeonly:?x?xi32>{%2, %3}
    %8 = memref.subview %4[0, 0] [%0, %1] [1, 1] : memref<?x?xi32> to memref<?x?xi32, #map0>
    %9 = memref.subview %8[%0, %1] [%2, %3] [1, 1] : memref<?x?xi32, #map0> to memref<?x?xi32, #map0>
    %10 = memref.subview %6[0, 0] [%2, %3] [1, 1] : memref<?x?xi32> to memref<?x?xi32, #map0>
    linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%9 : memref<?x?xi32, #map0>) outs(%10 : memref<?x?xi32, #map0>) {
    ^bb0(%arg0: i32, %arg1: i32):
      linalg.yield %arg0 : i32
    }
    return
  }
}

%4 already carries the information that the dim sizes are %0 and %1, so we don't need %8 = memref.subview ... just to get the whole memref again. (%10 is redundant in the same way: it is a full-buffer subview of %6, whose dims are already %2 and %3.)
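For illustration, after folding the full-buffer subviews away, the bufferized IR could look like this (a hand-written sketch, not actual compiler output; the dead tensor-flavored subspans are also dropped for readability):

```mlir
#map0 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
module {
  func.func @slice() {
    %0 = hal.interface.constant.load[0] : index
    %1 = hal.interface.constant.load[1] : index
    %2 = hal.interface.constant.load[2] : index
    %3 = hal.interface.constant.load[3] : index
    %4 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : memref<?x?xi32>{%0, %1}
    %5 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : memref<?x?xi32>{%2, %3}
    // The slice now reads directly from the subspan buffer.
    %6 = memref.subview %4[%0, %1] [%2, %3] [1, 1] : memref<?x?xi32> to memref<?x?xi32, #map0>
    linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%6 : memref<?x?xi32, #map0>) outs(%5 : memref<?x?xi32>) {
    ^bb0(%arg0: i32, %arg1: i32):
      linalg.yield %arg0 : i32
    }
    return
  }
}
```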

matthias-springer commented 2 years ago

The input IR has the same redundancy:

    %4 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:?x?xi32>{%0, %1}
    %6 = flow.dispatch.tensor.load %4, offsets = [0, 0], sizes = [%0, %1], strides = [1, 1] : !flow.dispatch.tensor<readonly:?x?xi32>{%0, %1} -> tensor<?x?xi32>

I would add a folding pattern for subview(subspan) after bufferization. Does this sound reasonable?
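A subview(subspan) folding pattern along those lines could be sketched roughly as follows. This is only an illustration of the matching logic, not the actual implementation; the IREE op class and accessor names (`InterfaceBindingSubspanOp`, `getDynamicDims`) and the type reconciliation via `memref.cast` are assumptions:

```cpp
// Hypothetical sketch: fold a memref.subview that covers the entire
// subspan buffer (all offsets 0, all strides 1, sizes equal to the
// subspan's dynamic dims) into the subspan result itself.
struct FoldFullSubViewOfSubspan : public OpRewritePattern<memref::SubViewOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    auto subspanOp = subViewOp.getSource()
        .getDefiningOp<IREE::HAL::InterfaceBindingSubspanOp>();
    if (!subspanOp)
      return failure();
    // All offsets must be 0 and all strides must be 1.
    for (OpFoldResult offset : subViewOp.getMixedOffsets())
      if (!isConstantIntValue(offset, 0))
        return failure();
    for (OpFoldResult stride : subViewOp.getMixedStrides())
      if (!isConstantIntValue(stride, 1))
        return failure();
    // Each size must be exactly the dynamic-dim SSA value carried by the
    // subspan (accessor name is an assumption).
    auto sizes = subViewOp.getMixedSizes();
    auto dims = subspanOp.getDynamicDims();
    if (sizes.size() != dims.size())
      return failure();
    for (auto it : llvm::zip(sizes, dims))
      if (std::get<0>(it).dyn_cast<Value>() != std::get<1>(it))
        return failure();
    // The subview result may carry a more dynamic strided layout than the
    // subspan's identity layout, so reconcile the types with a cast.
    rewriter.replaceOpWithNewOp<memref::CastOp>(subViewOp, subViewOp.getType(),
                                                subspanOp.getResult());
    return success();
  }
};
```

Such a pattern would run after bufferization, e.g. as part of a cleanup/canonicalization step, so that the trivial full-buffer subviews never reach codegen.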