iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0
2.59k stars 580 forks source link

[CPU] Unnecessary stack allocation for unpack + transpose + pack #16772

Open pzread opened 6 months ago

pzread commented 6 months ago

When compiling the example below with:

iree-compile unpack_transpose_pack.mlir -o /dev/null --iree-hal-target-backends=llvm-cpu --iree-input-type=none --iree-llvmcpu-target-triple=x86_64-unknown-linux-gnu --iree-llvmcpu-target-cpu=cascadelake --iree-opt-data-tiling=true --iree-llvmcpu-enable-ukernels=all --mlir-print-ir-after-all
// Reproducer: unpack (12x8x2x16x16 -> 12x128x32), swap the first two dims
// (-> 128x12x32), then pack (-> 8x12x32x16x1).
func.func @predict_dispatch_17_unpack_transpose_128x12x32_f32_pack(%6: tensor<12x8x2x16x16xf32>) -> tensor<8x12x32x16x1xf32> {
  %7 = tensor.empty() : tensor<8x12x32x16x1xf32>
  %8 = tensor.empty() : tensor<128x12x32xf32>
  %9 = tensor.empty() : tensor<12x128x32xf32>
  // Inner tiles [16, 16] on dims [1, 2]: 8 * 16 = 128 and 2 * 16 = 32.
  %unpack = tensor.unpack %6 outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 16] into %9 : tensor<12x8x2x16x16xf32> -> tensor<12x128x32xf32>
  // Pure copy with permuted indexing: out[d0, d1, d2] = in[d1, d0, d2].
  %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%unpack : tensor<12x128x32xf32>) outs(%8 : tensor<128x12x32xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<128x12x32xf32>
  // Inner tiles [16, 1] on dims [0, 2]: 128 / 16 = 8 and 32 / 1 = 32.
  %pack = tensor.pack %10 outer_dims_perm = [0, 1, 2] inner_dims_pos = [0, 2] inner_tiles = [16, 1] into %7 : tensor<128x12x32xf32> -> tensor<8x12x32x16x1xf32>
  return %pack : tensor<8x12x32x16x1xf32>
}

We can see that a tensor.insert_slice : tensor<16x16xf32> -> tensor<1x16x16xf32> is created for the transfer_read after vectorization. Bufferization turns it into a memref.alloca plus a memcpy, and both then persist through the rest of codegen:

// -----// IR Dump After GenericVectorization (iree-codegen-generic-vectorization) //----- //
// Dispatch after GenericVectorization. Two temporaries appear in the inner
// loop: the 1x16x16 tensor.empty filled by insert_slice (%16) and the 16x1x1
// tensor.empty written by transfer_write (%17).
func.func @predict_dispatch_17_unpack_transpose_128x12x32_f32_pack_dispatch_0_unpack_transpose_128x12x32_f32_pack() {
  %cst = arith.constant 0.000000e+00 : f32
  %c16 = arith.constant 16 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %c32 = arith.constant 32 : index
  %c12 = arith.constant 12 : index
  %c8 = arith.constant 8 : index
  %c0 = arith.constant 0 : index
  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<12x8x2x16x16xf32>>
  %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<8x12x32x16x1xf32>>
  %workgroup_id_x = hal.interface.workgroup.id[0] : index
  %workgroup_count_x = hal.interface.workgroup.count[0] : index
  %workgroup_id_y = hal.interface.workgroup.id[1] : index
  %workgroup_count_y = hal.interface.workgroup.count[1] : index
  %workgroup_id_z = hal.interface.workgroup.id[2] : index
  %workgroup_count_z = hal.interface.workgroup.count[2] : index
  %2 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_z]
  %3 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_count_z]
  scf.for %arg0 = %2 to %c8 step %3 {
    %4 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_y]
    %5 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_count_y]
    scf.for %arg1 = %4 to %c12 step %5 {
      %6 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_id_x]
      %7 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_count_x]
      scf.for %arg2 = %6 to %c32 step %7 {
        %8 = flow.dispatch.tensor.load %1, offsets = [%arg0, %arg1, %arg2, 0, 0], sizes = [4, 4, 16, 16, 1], strides = [1, 1, 1, 1, 1] : !flow.dispatch.tensor<writeonly:tensor<8x12x32x16x1xf32>> -> tensor<4x4x16x16x1xf32>
        %9 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg2)
        %10 = flow.dispatch.tensor.load %0, offsets = [%arg1, %arg0, %9, 0, 0], sizes = [4, 4, 1, 16, 16], strides = [1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<12x8x2x16x16xf32>> -> tensor<4x4x1x16x16xf32>
        %11 = scf.for %arg3 = %c0 to %c4 step %c1 iter_args(%arg4 = %8) -> (tensor<4x4x16x16x1xf32>) {
          %12 = scf.for %arg5 = %c0 to %c4 step %c1 iter_args(%arg6 = %arg4) -> (tensor<4x4x16x16x1xf32>) {
            %13 = scf.for %arg7 = %c0 to %c16 step %c1 iter_args(%arg8 = %arg6) -> (tensor<4x4x16x16x1xf32>) {
              // %arg7 ranges over [0, 16), so %14 is always 0 and %15 equals %arg7.
              %14 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg7)
              %15 = affine.apply affine_map<(d0) -> (d0 mod 16)>(%arg7)
              %extracted_slice = tensor.extract_slice %10[%arg5, %arg3, %14, 0, 0] [1, 1, 1, 16, 16] [1, 1, 1, 1, 1] : tensor<4x4x1x16x16xf32> to tensor<1x1x1x16x16xf32>
              %16 = tensor.empty() : tensor<1x16x16xf32>
              %extracted_slice_0 = tensor.extract_slice %extracted_slice[0, 0, 0, 0, 0] [1, 1, 1, 16, 16] [1, 1, 1, 1, 1] : tensor<1x1x1x16x16xf32> to tensor<16x16xf32>
              // This insert_slice turns into the alloca and memcpy after bufferization.
              %inserted_slice = tensor.insert_slice %extracted_slice_0 into %16[0, 0, 0] [1, 16, 16] [1, 1, 1] : tensor<16x16xf32> into tensor<1x16x16xf32>
              // Only a single 1x16x1 column (at %15) of the staged 16x16 tile is read.
              %extracted_slice_1 = tensor.extract_slice %inserted_slice[0, 0, %15] [1, 16, 1] [1, 1, 1] : tensor<1x16x16xf32> to tensor<1x16x1xf32>
              // The transposed column is staged in a second fresh tensor (%17)
              // before being inserted into %arg8 below.
              %17 = tensor.empty() : tensor<16x1x1xf32>
              %18 = vector.transfer_read %extracted_slice_1[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : tensor<1x16x1xf32>, vector<1x16x1xf32>
              %19 = vector.transpose %18, [1, 0, 2] : vector<1x16x1xf32> to vector<16x1x1xf32>
              %20 = vector.transfer_write %19, %17[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<16x1x1xf32>, tensor<16x1x1xf32>
              %extracted_slice_2 = tensor.extract_slice %arg8[%arg3, %arg5, %arg7, 0, 0] [1, 1, 1, 16, 1] [1, 1, 1, 1, 1] : tensor<4x4x16x16x1xf32> to tensor<1x1x1x16x1xf32>
              %extracted_slice_3 = tensor.extract_slice %20[0, 0, 0] [16, 1, 1] [1, 1, 1] : tensor<16x1x1xf32> to tensor<16x1xf32>
              %inserted_slice_4 = tensor.insert_slice %extracted_slice_3 into %extracted_slice_2[0, 0, 0, 0, 0] [1, 1, 1, 16, 1] [1, 1, 1, 1, 1] : tensor<16x1xf32> into tensor<1x1x1x16x1xf32>
              %inserted_slice_5 = tensor.insert_slice %inserted_slice_4 into %arg8[%arg3, %arg5, %arg7, 0, 0] [1, 1, 1, 16, 1] [1, 1, 1, 1, 1] : tensor<1x1x1x16x1xf32> into tensor<4x4x16x16x1xf32>
              scf.yield %inserted_slice_5 : tensor<4x4x16x16x1xf32>
            }
            scf.yield %13 : tensor<4x4x16x16x1xf32>
          }
          scf.yield %12 : tensor<4x4x16x16x1xf32>
        }
        flow.dispatch.tensor.store %11, %1, offsets = [%arg0, %arg1, %arg2, 0, 0], sizes = [4, 4, 16, 16, 1], strides = [1, 1, 1, 1, 1] : tensor<4x4x16x16x1xf32> -> !flow.dispatch.tensor<writeonly:tensor<8x12x32x16x1xf32>>
      }
    }
  }
  return
}

// -----// IR Dump After IREEComprehensiveBufferize (iree-codegen-iree-comprehensive-bufferize) //----- //
// After bufferization: %alloca_0 (1x16x16) backs the tensor.empty that was
// filled by insert_slice, and %alloca (16x1x1) backs the transfer_write
// destination tensor.empty.
module {
  func.func @predict_dispatch_17_unpack_transpose_128x12x32_f32_pack_dispatch_0_unpack_transpose_128x12x32_f32_pack() {
    %c0 = arith.constant 0 : index
    %c8 = arith.constant 8 : index
    %c12 = arith.constant 12 : index
    %c32 = arith.constant 32 : index
    %c4 = arith.constant 4 : index
    %c1 = arith.constant 1 : index
    %c16 = arith.constant 16 : index
    %cst = arith.constant 0.000000e+00 : f32
    %alloca = memref.alloca() {alignment = 64 : i64} : memref<16x1x1xf32>
    %alloca_0 = memref.alloca() {alignment = 64 : i64} : memref<1x16x16xf32>
    %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<12x8x2x16x16xf32, #hal.descriptor_type<storage_buffer>>
    memref.assume_alignment %0, 64 : memref<12x8x2x16x16xf32, #hal.descriptor_type<storage_buffer>>
    %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : memref<8x12x32x16x1xf32, #hal.descriptor_type<storage_buffer>>
    memref.assume_alignment %1, 64 : memref<8x12x32x16x1xf32, #hal.descriptor_type<storage_buffer>>
    %workgroup_id_x = hal.interface.workgroup.id[0] : index
    %workgroup_count_x = hal.interface.workgroup.count[0] : index
    %workgroup_id_y = hal.interface.workgroup.id[1] : index
    %workgroup_count_y = hal.interface.workgroup.count[1] : index
    %workgroup_id_z = hal.interface.workgroup.id[2] : index
    %workgroup_count_z = hal.interface.workgroup.count[2] : index
    %2 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_z]
    %3 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_count_z]
    %4 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_y]
    %5 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_count_y]
    %6 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_id_x]
    %7 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_count_x]
    scf.for %arg0 = %2 to %c8 step %3 {
      scf.for %arg1 = %4 to %c12 step %5 {
        scf.for %arg2 = %6 to %c32 step %7 {
          %subview = memref.subview %1[%arg0, %arg1, %arg2, 0, 0] [4, 4, 16, 16, 1] [1, 1, 1, 1, 1] : memref<8x12x32x16x1xf32, #hal.descriptor_type<storage_buffer>> to memref<4x4x16x16x1xf32, strided<[6144, 512, 16, 1, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
          %8 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg2)
          %subview_1 = memref.subview %0[%arg1, %arg0, %8, 0, 0] [4, 4, 1, 16, 16] [1, 1, 1, 1, 1] : memref<12x8x2x16x16xf32, #hal.descriptor_type<storage_buffer>> to memref<4x4x1x16x16xf32, strided<[4096, 512, 256, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
          %9 = scf.for %arg3 = %c0 to %c4 step %c1 iter_args(%arg4 = %subview) -> (memref<4x4x16x16x1xf32, strided<[6144, 512, 16, 1, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) {
            %10 = scf.for %arg5 = %c0 to %c4 step %c1 iter_args(%arg6 = %arg4) -> (memref<4x4x16x16x1xf32, strided<[6144, 512, 16, 1, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) {
              %11 = scf.for %arg7 = %c0 to %c16 step %c1 iter_args(%arg8 = %arg6) -> (memref<4x4x16x16x1xf32, strided<[6144, 512, 16, 1, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) {
                %12 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg7)
                %13 = affine.apply affine_map<(d0) -> (d0 mod 16)>(%arg7)
                %subview_3 = memref.subview %subview_1[%arg5, %arg3, %12, 0, 0] [1, 1, 1, 16, 16] [1, 1, 1, 1, 1] : memref<4x4x1x16x16xf32, strided<[4096, 512, 256, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1x1x16x16xf32, strided<[4096, 512, 256, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
                %subview_4 = memref.subview %subview_3[0, 0, 0, 0, 0] [1, 1, 1, 16, 16] [1, 1, 1, 1, 1] : memref<1x1x1x16x16xf32, strided<[4096, 512, 256, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<16x16xf32, strided<[16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
                %subview_5 = memref.subview %alloca_0[0, 0, 0] [1, 16, 16] [1, 1, 1] : memref<1x16x16xf32> to memref<16x16xf32, strided<[16, 1]>>
                // Copy of the whole 16x16 input tile into the stack buffer
                // %alloca_0 (the memcpy produced from the insert_slice).
                linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%subview_4 : memref<16x16xf32, strided<[16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) outs(%subview_5 : memref<16x16xf32, strided<[16, 1]>>) {
                ^bb0(%in: f32, %out: f32):
                  linalg.yield %in : f32
                }
                %14 = vector.transfer_read %alloca_0[%c0, %c0, %13], %cst {in_bounds = [true, true]} : memref<1x16x16xf32>, vector<16x1xf32>
                %15 = vector.broadcast %14 : vector<16x1xf32> to vector<1x16x1xf32>
                %16 = vector.transpose %15, [1, 0, 2] : vector<1x16x1xf32> to vector<16x1x1xf32>
                vector.transfer_write %16, %alloca[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<16x1x1xf32>, memref<16x1x1xf32>
                %subview_6 = memref.subview %arg8[%arg3, %arg5, %arg7, 0, 0] [1, 1, 1, 16, 1] [1, 1, 1, 1, 1] : memref<4x4x16x16x1xf32, strided<[6144, 512, 16, 1, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1x1x16x1xf32, strided<[6144, 512, 16, 1, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
                %subview_7 = memref.subview %alloca[0, 0, 0] [16, 1, 1] [1, 1, 1] : memref<16x1x1xf32> to memref<16x1xf32, strided<[1, 1]>>
                %subview_8 = memref.subview %subview_6[0, 0, 0, 0, 0] [1, 1, 1, 16, 1] [1, 1, 1, 1, 1] : memref<1x1x1x16x1xf32, strided<[6144, 512, 16, 1, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<16x1xf32, strided<[1, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
                // Copy of the transposed column from %alloca into the output tile.
                linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%subview_7 : memref<16x1xf32, strided<[1, 1]>>) outs(%subview_8 : memref<16x1xf32, strided<[1, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) {
                ^bb0(%in: f32, %out: f32):
                  linalg.yield %in : f32
                }
                // NOTE(review): %subview_9 has the same offsets/sizes/strides as
                // %subview_6, so the generic below appears to copy a subview onto
                // itself. Confirm that a later cleanup removes it.
                %subview_9 = memref.subview %arg8[%arg3, %arg5, %arg7, 0, 0] [1, 1, 1, 16, 1] [1, 1, 1, 1, 1] : memref<4x4x16x16x1xf32, strided<[6144, 512, 16, 1, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1x1x16x1xf32, strided<[6144, 512, 16, 1, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
                linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%subview_6 : memref<1x1x1x16x1xf32, strided<[6144, 512, 16, 1, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) outs(%subview_9 : memref<1x1x1x16x1xf32, strided<[6144, 512, 16, 1, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) {
                ^bb0(%in: f32, %out: f32):
                  linalg.yield %in : f32
                }
                scf.yield %arg8 : memref<4x4x16x16x1xf32, strided<[6144, 512, 16, 1, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
              }
              scf.yield %11 : memref<4x4x16x16x1xf32, strided<[6144, 512, 16, 1, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
            }
            scf.yield %10 : memref<4x4x16x16x1xf32, strided<[6144, 512, 16, 1, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
          }
          // NOTE(review): the loops yield their iter_arg unchanged, so %9 appears
          // to be %subview itself, which covers the same region of %1 as
          // %subview_2. This copy is presumably a no-op; confirm.
          %subview_2 = memref.subview %1[%arg0, %arg1, %arg2, 0, 0] [4, 4, 16, 16, 1] [1, 1, 1, 1, 1] : memref<8x12x32x16x1xf32, #hal.descriptor_type<storage_buffer>> to memref<4x4x16x16x1xf32, strided<[6144, 512, 16, 1, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
          linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%9 : memref<4x4x16x16x1xf32, strided<[6144, 512, 16, 1, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) outs(%subview_2 : memref<4x4x16x16x1xf32, strided<[6144, 512, 16, 1, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) {
          ^bb0(%in: f32, %out: f32):
            linalg.yield %in : f32
          }
        }
      }
    }
    return
  }
}

// -----// IR Dump After LLVMCPUVirtualVectorLowering (iree-llvmcpu-virtual-vector-lowering) //----- //
// After LLVMCPUVirtualVectorLowering: both stack allocations (%alloca,
// %alloca_0) and the 16x16 tile copy into %alloca_0 are still present.
func.func @predict_dispatch_17_unpack_transpose_128x12x32_f32_pack_dispatch_0_unpack_transpose_128x12x32_f32_pack() {
  %c0 = arith.constant 0 : index
  %c4 = arith.constant 4 : index
  %c1 = arith.constant 1 : index
  %c16 = arith.constant 16 : index
  %cst = arith.constant 0.000000e+00 : f32
  %alloca = memref.alloca() {alignment = 64 : i64} : memref<16x1x1xf32>
  %alloca_0 = memref.alloca() {alignment = 64 : i64} : memref<1x16x16xf32>
  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<12x8x2x16x16xf32, #hal.descriptor_type<storage_buffer>>
  memref.assume_alignment %0, 64 : memref<12x8x2x16x16xf32, #hal.descriptor_type<storage_buffer>>
  %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : memref<8x12x32x16x1xf32, #hal.descriptor_type<storage_buffer>>
  memref.assume_alignment %1, 64 : memref<8x12x32x16x1xf32, #hal.descriptor_type<storage_buffer>>
  %workgroup_id_x = hal.interface.workgroup.id[0] : index
  %workgroup_id_y = hal.interface.workgroup.id[1] : index
  %workgroup_id_z = hal.interface.workgroup.id[2] : index
  %2 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_z]
  %3 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_y]
  %4 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_id_x]
  %subview = memref.subview %1[%2, %3, %4, 0, 0] [4, 4, 16, 16, 1] [1, 1, 1, 1, 1] : memref<8x12x32x16x1xf32, #hal.descriptor_type<storage_buffer>> to memref<4x4x16x16x1xf32, strided<[6144, 512, 16, 1, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
  %5 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%4)
  %subview_1 = memref.subview %0[%3, %2, %5, 0, 0] [4, 4, 1, 16, 16] [1, 1, 1, 1, 1] : memref<12x8x2x16x16xf32, #hal.descriptor_type<storage_buffer>> to memref<4x4x1x16x16xf32, strided<[4096, 512, 256, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
  scf.for %arg0 = %c0 to %c4 step %c1 {
    scf.for %arg1 = %c0 to %c4 step %c1 {
      scf.for %arg2 = %c0 to %c16 step %c1 {
        %6 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg2)
        %7 = affine.apply affine_map<(d0) -> (d0 mod 16)>(%arg2)
        %subview_2 = memref.subview %subview_1[%arg1, %arg0, %6, 0, 0] [1, 1, 1, 16, 16] [1, 1, 1, 1, 1] : memref<4x4x1x16x16xf32, strided<[4096, 512, 256, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1x1x16x16xf32, strided<[4096, 512, 256, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
        %subview_3 = memref.subview %subview_2[0, 0, 0, 0, 0] [1, 1, 1, 16, 16] [1, 1, 1, 1, 1] : memref<1x1x1x16x16xf32, strided<[4096, 512, 256, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<16x16xf32, strided<[16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
        %subview_4 = memref.subview %alloca_0[0, 0, 0] [1, 16, 16] [1, 1, 1] : memref<1x16x16xf32> to memref<16x16xf32, strided<[16, 1]>>
        // The full 16x16 tile is still copied into %alloca_0 on every iteration,
        // although only one 16x1 column of it is read below.
        linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%subview_3 : memref<16x16xf32, strided<[16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) outs(%subview_4 : memref<16x16xf32, strided<[16, 1]>>) {
        ^bb0(%in: f32, %out: f32):
          linalg.yield %in : f32
        }
        %8 = vector.transfer_read %alloca_0[%c0, %c0, %7], %cst {in_bounds = [true, true]} : memref<1x16x16xf32>, vector<16x1xf32>
        %9 = vector.broadcast %8 : vector<16x1xf32> to vector<1x16x1xf32>
        %10 = vector.transpose %9, [1, 0, 2] : vector<1x16x1xf32> to vector<16x1x1xf32>
        %subview_5 = memref.subview %alloca[0, 0, 0] [16, 1, 1] [1, 1, 1] : memref<16x1x1xf32> to memref<16xf32>
        %11 = vector.shape_cast %10 : vector<16x1x1xf32> to vector<16xf32>
        vector.transfer_write %11, %subview_5[%c0] {in_bounds = [true]} : vector<16xf32>, memref<16xf32>
        %subview_6 = memref.subview %subview[%arg0, %arg1, %arg2, 0, 0] [1, 1, 1, 16, 1] [1, 1, 1, 1, 1] : memref<4x4x16x16x1xf32, strided<[6144, 512, 16, 1, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1x1x16x1xf32, strided<[6144, 512, 16, 1, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
        %subview_7 = memref.subview %alloca[0, 0, 0] [16, 1, 1] [1, 1, 1] : memref<16x1x1xf32> to memref<16x1xf32, strided<[1, 1]>>
        %subview_8 = memref.subview %subview_6[0, 0, 0, 0, 0] [1, 1, 1, 16, 1] [1, 1, 1, 1, 1] : memref<1x1x1x16x1xf32, strided<[6144, 512, 16, 1, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<16x1xf32, strided<[1, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
        // Copy of the transposed column from %alloca into the output tile.
        linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%subview_7 : memref<16x1xf32, strided<[1, 1]>>) outs(%subview_8 : memref<16x1xf32, strided<[1, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) {
        ^bb0(%in: f32, %out: f32):
          linalg.yield %in : f32
        }
      }
    }
  }
  return
}

I think that if we can get rid of the tensor.insert_slice : tensor<16x16xf32> -> tensor<1x16x16xf32> and instead generate something like the IR below, the stack allocation and memcpy can be avoided:

%extracted_slice = tensor.extract_slice ... : tensor<16x16xf32> -> tensor<16x1xf32>
%read = vector.transfer_read %extracted_slice : tensor<16x1xf32> -> vector<16x1xf32>
%transpose_in = vector.broadcast %read : vector<16x1xf32> -> vector<1x16x1xf32>
%transpose_out  = vector.transpose %transpose_in, [1, 0, 2] : vector<1x16x1xf32> -> vector<16x1x1xf32>

This causes a MiniLM latency regression in the work on #16682, which creates more pack/unpack fusions.

dcaballe commented 6 months ago

@hanhanW, @MaheshRavishankar, @matthias-springer?

matthias-springer commented 6 months ago

There are two tensor.empty ops:

              %16 = tensor.empty() : tensor<1x16x16xf32>
              %extracted_slice_0 = tensor.extract_slice %extracted_slice[0, 0, 0, 0, 0] [1, 1, 1, 16, 16] [1, 1, 1, 1, 1] : tensor<1x1x1x16x16xf32> to tensor<16x16xf32>
              // This insert_slice turns in the alloca and memcpy later.
              %inserted_slice = tensor.insert_slice %extracted_slice_0 into %16[0, 0, 0] [1, 16, 16] [1, 1, 1] : tensor<16x16xf32> into tensor<1x16x16xf32>
              %extracted_slice_1 = tensor.extract_slice %inserted_slice[0, 0, %15] [1, 16, 1] [1, 1, 1] : tensor<1x16x16xf32> to tensor<1x16x1xf32>

I think this one could be turned into a single tensor.extract_slice %extracted_slice. This looks like an extraction with a rank-reduction to me. A simple cleanup/canonicalization pattern that folds extract_slice-insert_slice-extract_slice chains should do. I think we even have a pass upstream that folds such ops. The tricky thing is that you don't want to fold too much because it could mess with bufferization, which looks for matching extract_slice-insert_slice pairs.

For the second tensor.empty (if you want to get rid of it), I would try to make the vector.transfer_write write directly into %arg8 instead of %17. You may be able to achieve that with the right extract_slice/insert_slice/transfer_write foldings.

MaheshRavishankar commented 6 months ago

Have you tried folding the transpose with the pack operation?

hanhanW commented 6 months ago

We should enable direct vectorization so that we no longer see any empty ops. I have a working chain, but there are things that still need to be fixed: https://github.com/openxla/iree/pull/16672

pzread commented 6 months ago

Merged https://github.com/llvm/llvm-project/pull/86328, which provides a pattern to fold insert_slice of extract_slice, as @matthias-springer suggested.

pzread commented 6 months ago

I tried to enable https://github.com/llvm/llvm-project/pull/86328 in IREE after vectorization, but it seems to regress DeepLab.

In one of the regressed dispatches I found that the folding of insert_slice of extract_slice seems to break the bufferization.

With the new patterns to fold insert_slice of extract_slice:

// -----// IR Dump After GenericVectorization (iree-codegen-generic-vectorization) //----- //
// DeepLab dispatch after GenericVectorization, before the insert_slice-of-
// extract_slice folding is applied.
func.func @main_dispatch_108_unpack_generic_1089x480_f32() {
  %cst = arith.constant dense<0.000000e+00> : vector<16x16xf32>
  %cst_0 = arith.constant dense<6.000000e+00> : vector<16x16xf32>
  %c48 = arith.constant 48 : index
  %c16 = arith.constant 16 : index
  %c0 = arith.constant 0 : index
  %c480 = arith.constant 480 : index
  %c1089 = arith.constant 1089 : index
  %cst_1 = arith.constant 0.000000e+00 : f32
  %0 = hal.interface.constant.load[0] : i32
  %1 = hal.interface.constant.load[1] : i32
  %2 = hal.interface.constant.load[2] : i32
  %3 = arith.index_castui %0 : i32 to index
  %4 = arith.index_castui %1 : i32 to index
  %5 = arith.index_castui %2 : i32 to index
  %6 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%3) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<69x30x16x16xf32>>
  %7 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%4) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<480xf32>>
  %8 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%5) : !flow.dispatch.tensor<writeonly:tensor<1089x480xf32>>
  %workgroup_id_x = hal.interface.workgroup.id[0] : index
  %workgroup_count_x = hal.interface.workgroup.count[0] : index
  %workgroup_id_y = hal.interface.workgroup.id[1] : index
  %workgroup_count_y = hal.interface.workgroup.count[1] : index
  %9 = affine.apply affine_map<()[s0] -> (s0 * 144)>()[%workgroup_id_y]
  %10 = affine.apply affine_map<()[s0] -> (s0 * 144)>()[%workgroup_count_y]
  scf.for %arg0 = %9 to %c1089 step %10 {
    %11 = affine.min affine_map<(d0) -> (-d0 + 1089, 144)>(%arg0)
    %12 = affine.apply affine_map<()[s0] -> (s0 * 48)>()[%workgroup_id_x]
    %13 = affine.apply affine_map<()[s0] -> (s0 * 48)>()[%workgroup_count_x]
    scf.for %arg1 = %12 to %c480 step %13 {
      %14 = flow.dispatch.tensor.load %8, offsets = [%arg0, %arg1], sizes = [%11, 48], strides = [1, 1] : !flow.dispatch.tensor<writeonly:tensor<1089x480xf32>> -> tensor<?x48xf32>
      %15 = affine.apply affine_map<(d0) -> (d0 ceildiv 16)>(%11)
      %16 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg0)
      %17 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg1)
      %18 = flow.dispatch.tensor.load %6, offsets = [%16, %17, 0, 0], sizes = [%15, 3, 16, 16], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<69x30x16x16xf32>> -> tensor<?x3x16x16xf32>
      %19 = flow.dispatch.tensor.load %7, offsets = [%arg1], sizes = [48], strides = [1] : !flow.dispatch.tensor<readonly:tensor<480xf32>> -> tensor<48xf32>
      %20 = scf.for %arg2 = %c0 to %11 step %c16 iter_args(%arg3 = %14) -> (tensor<?x48xf32>) {
        %21 = scf.for %arg4 = %c0 to %c48 step %c16 iter_args(%arg5 = %arg3) -> (tensor<?x48xf32>) {
          %22 = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 16)>(%arg2)[%11]
          %extracted_slice = tensor.extract_slice %19[%arg4] [16] [1] : tensor<48xf32> to tensor<16xf32>
          %23 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg2)
          %24 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg4)
          %extracted_slice_2 = tensor.extract_slice %18[%23, %24, 0, 0] [1, 1, 16, 16] [1, 1, 1, 1] : tensor<?x3x16x16xf32> to tensor<1x1x16x16xf32>
          %extracted_slice_3 = tensor.extract_slice %arg5[%arg2, %arg4] [%22, 16] [1, 1] : tensor<?x48xf32> to tensor<?x16xf32>
          %extracted_slice_4 = tensor.extract_slice %extracted_slice_2[0, 0, 0, 0] [1, 1, 16, 16] [1, 1, 1, 1] : tensor<1x1x16x16xf32> to tensor<16x16xf32>
          %extracted_slice_5 = tensor.extract_slice %extracted_slice_4[0, 0] [%22, 16] [1, 1] : tensor<16x16xf32> to tensor<?x16xf32>
          // The unpacked input tile is inserted into a slice of the output tile
          // %arg5 (%extracted_slice_3), so the transfer_write below updates a
          // slice of %arg5.
          %inserted_slice = tensor.insert_slice %extracted_slice_5 into %extracted_slice_3[0, 0] [%22, 16] [1, 1] : tensor<?x16xf32> into tensor<?x16xf32>
          %dim = tensor.dim %inserted_slice, %c0 : tensor<?x16xf32>
          %25 = vector.transfer_read %extracted_slice[%c0], %cst_1 {in_bounds = [true]} : tensor<16xf32>, vector<16xf32>
          %26 = vector.broadcast %25 : vector<16xf32> to vector<16x16xf32>
          %27 = vector.create_mask %dim, %c16 : vector<16x16xi1>
          %28 = vector.transfer_read %inserted_slice[%c0, %c0], %cst_1, %27 {in_bounds = [true, true]} : tensor<?x16xf32>, vector<16x16xf32>
          // Elementwise: add the broadcast bias row, then clamp to [0, 6].
          %29 = arith.addf %28, %26 : vector<16x16xf32>
          %30 = arith.minimumf %29, %cst_0 : vector<16x16xf32>
          %31 = arith.maximumf %30, %cst : vector<16x16xf32>
          %32 = vector.transfer_write %31, %inserted_slice[%c0, %c0], %27 {in_bounds = [true, true]} : vector<16x16xf32>, tensor<?x16xf32>
          %inserted_slice_6 = tensor.insert_slice %32 into %arg5[%arg2, %arg4] [%22, 16] [1, 1] : tensor<?x16xf32> into tensor<?x48xf32>
          scf.yield %inserted_slice_6 : tensor<?x48xf32>
        }
        scf.yield %21 : tensor<?x48xf32>
      }
      flow.dispatch.tensor.store %20, %8, offsets = [%arg0, %arg1], sizes = [%11, 48], strides = [1, 1] : tensor<?x48xf32> -> !flow.dispatch.tensor<writeonly:tensor<1089x480xf32>>
    }
  }
  return
}

// -----// IR Dump After OptimizeTensorInsertExtractSlices (iree-codegen-optimize-tensor-insert-extract-slices) //----- //
// Same DeepLab dispatch after OptimizeTensorInsertExtractSlices with the new
// insert_slice-of-extract_slice folding enabled.
func.func @main_dispatch_108_unpack_generic_1089x480_f32() {
  %cst = arith.constant dense<0.000000e+00> : vector<16x16xf32>
  %cst_0 = arith.constant dense<6.000000e+00> : vector<16x16xf32>
  %c48 = arith.constant 48 : index
  %c16 = arith.constant 16 : index
  %c0 = arith.constant 0 : index
  %c480 = arith.constant 480 : index
  %c1089 = arith.constant 1089 : index
  %cst_1 = arith.constant 0.000000e+00 : f32
  %0 = hal.interface.constant.load[0] : i32
  %1 = hal.interface.constant.load[1] : i32
  %2 = hal.interface.constant.load[2] : i32
  %3 = arith.index_castui %0 : i32 to index
  %4 = arith.index_castui %1 : i32 to index
  %5 = arith.index_castui %2 : i32 to index
  %6 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%3) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<69x30x16x16xf32>>
  %7 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%4) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<480xf32>>
  %8 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%5) : !flow.dispatch.tensor<writeonly:tensor<1089x480xf32>>
  %workgroup_id_x = hal.interface.workgroup.id[0] : index
  %workgroup_count_x = hal.interface.workgroup.count[0] : index
  %workgroup_id_y = hal.interface.workgroup.id[1] : index
  %workgroup_count_y = hal.interface.workgroup.count[1] : index
  %9 = affine.apply affine_map<()[s0] -> (s0 * 144)>()[%workgroup_id_y]
  %10 = affine.apply affine_map<()[s0] -> (s0 * 144)>()[%workgroup_count_y]
  %11 = affine.apply affine_map<()[s0] -> (s0 * 48)>()[%workgroup_id_x]
  %12 = affine.apply affine_map<()[s0] -> (s0 * 48)>()[%workgroup_count_x]
  scf.for %arg0 = %9 to %c1089 step %10 {
    %13 = affine.min affine_map<(d0) -> (-d0 + 1089, 144)>(%arg0)
    %14 = affine.apply affine_map<(d0) -> (d0 ceildiv 16)>(%13)
    %15 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg0)
    scf.for %arg1 = %11 to %c480 step %12 {
      %16 = flow.dispatch.tensor.load %8, offsets = [%arg0, %arg1], sizes = [%13, 48], strides = [1, 1] : !flow.dispatch.tensor<writeonly:tensor<1089x480xf32>> -> tensor<?x48xf32>
      %17 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg1)
      %18 = flow.dispatch.tensor.load %6, offsets = [%15, %17, 0, 0], sizes = [%14, 3, 16, 16], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<69x30x16x16xf32>> -> tensor<?x3x16x16xf32>
      %19 = flow.dispatch.tensor.load %7, offsets = [%arg1], sizes = [48], strides = [1] : !flow.dispatch.tensor<readonly:tensor<480xf32>> -> tensor<48xf32>
      %20 = scf.for %arg2 = %c0 to %13 step %c16 iter_args(%arg3 = %16) -> (tensor<?x48xf32>) {
        %21 = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 16)>(%arg2)[%13]
        %22 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg2)
        %23 = scf.for %arg4 = %c0 to %c48 step %c16 iter_args(%arg5 = %arg3) -> (tensor<?x48xf32>) {
          %24 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg4)
          %extracted_slice = tensor.extract_slice %18[%22, %24, 0, 0] [1, 1, 16, 16] [1, 1, 1, 1] : tensor<?x3x16x16xf32> to tensor<1x1x16x16xf32>
          %extracted_slice_2 = tensor.extract_slice %extracted_slice[0, 0, 0, 0] [1, 1, 16, 16] [1, 1, 1, 1] : tensor<1x1x16x16xf32> to tensor<16x16xf32>
          // NOTE(review): after the folding, %extracted_slice_3 is a slice of the
          // read-only input %18 instead of a slice of the output tile %arg5.
          %extracted_slice_3 = tensor.extract_slice %extracted_slice_2[0, 0] [%21, 16] [1, 1] : tensor<16x16xf32> to tensor<?x16xf32>
          %25 = vector.transfer_read %19[%arg4], %cst_1 {in_bounds = [true]} : tensor<48xf32>, vector<16xf32>
          %26 = vector.broadcast %25 : vector<16xf32> to vector<16x16xf32>
          %27 = vector.create_mask %21, %c16 : vector<16x16xi1>
          %28 = vector.transfer_read %extracted_slice_3[%c0, %c0], %cst_1, %27 {in_bounds = [true, true]} : tensor<?x16xf32>, vector<16x16xf32>
          %29 = arith.addf %28, %26 : vector<16x16xf32>
          %30 = arith.minimumf %29, %cst_0 : vector<16x16xf32>
          %31 = arith.maximumf %30, %cst : vector<16x16xf32>
          // NOTE(review): this write targets a slice of the read-only input, which
          // appears to be what forces bufferization to create a 16x16 alloca
          // instead of writing in place into the output.
          %32 = vector.transfer_write %31, %extracted_slice_3[%c0, %c0], %27 {in_bounds = [true, true]} : vector<16x16xf32>, tensor<?x16xf32>
          %inserted_slice = tensor.insert_slice %32 into %arg5[%arg2, %arg4] [%21, 16] [1, 1] : tensor<?x16xf32> into tensor<?x48xf32>
          scf.yield %inserted_slice : tensor<?x48xf32>
        }
        scf.yield %23 : tensor<?x48xf32>
      }
      flow.dispatch.tensor.store %20, %8, offsets = [%arg0, %arg1], sizes = [%13, 48], strides = [1, 1] : tensor<?x48xf32> -> !flow.dispatch.tensor<writeonly:tensor<1089x480xf32>>
    }
  }
  return
}

// ...

// -----// IR Dump After IREEComprehensiveBufferize (iree-codegen-iree-comprehensive-bufferize) //----- //
// Bufferized dispatch after the extra extract/insert_slice folding. Each
// inner-loop iteration reads a 16x16 tile of binding 0 (read-only), adds a
// broadcast 16-element slice of binding 1 (read-only), clamps the sum to
// [0, 6], and the result ends up in binding 2.
module {
  func.func @main_dispatch_108_unpack_generic_1089x480_f32() {
    %cst = arith.constant 0.000000e+00 : f32
    %c1089 = arith.constant 1089 : index
    %c480 = arith.constant 480 : index
    %c0 = arith.constant 0 : index
    %c16 = arith.constant 16 : index
    %c48 = arith.constant 48 : index
    %cst_0 = arith.constant dense<6.000000e+00> : vector<16x16xf32>
    %cst_1 = arith.constant dense<0.000000e+00> : vector<16x16xf32>
    // NOTE: this 16x16 stack buffer is the unnecessary allocation reported in
    // this issue. It receives a copy of the read-only input tile, so that the
    // vector.transfer_write in the loop body has a writable destination.
    %alloca = memref.alloca() {alignment = 64 : i64} : memref<16x16xf32, #hal.descriptor_type<storage_buffer>>
    %0 = hal.interface.constant.load[0] : i32
    %1 = hal.interface.constant.load[1] : i32
    %2 = hal.interface.constant.load[2] : i32
    %3 = arith.index_castui %0 : i32 to index
    %4 = arith.index_castui %1 : i32 to index
    %5 = arith.index_castui %2 : i32 to index
    %6 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%3) flags(ReadOnly) : memref<69x30x16x16xf32, strided<[7680, 256, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
    memref.assume_alignment %6, 1 : memref<69x30x16x16xf32, strided<[7680, 256, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
    %7 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%4) flags(ReadOnly) : memref<480xf32, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
    memref.assume_alignment %7, 1 : memref<480xf32, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
    %8 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%5) : memref<1089x480xf32, strided<[480, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
    memref.assume_alignment %8, 1 : memref<1089x480xf32, strided<[480, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
    %workgroup_id_x = hal.interface.workgroup.id[0] : index
    %workgroup_count_x = hal.interface.workgroup.count[0] : index
    %workgroup_id_y = hal.interface.workgroup.id[1] : index
    %workgroup_count_y = hal.interface.workgroup.count[1] : index
    %9 = affine.apply affine_map<()[s0] -> (s0 * 144)>()[%workgroup_id_y]
    %10 = affine.apply affine_map<()[s0] -> (s0 * 144)>()[%workgroup_count_y]
    %11 = affine.apply affine_map<()[s0] -> (s0 * 48)>()[%workgroup_id_x]
    %12 = affine.apply affine_map<()[s0] -> (s0 * 48)>()[%workgroup_count_x]
    scf.for %arg0 = %9 to %c1089 step %10 {
      %13 = affine.min affine_map<(d0) -> (-d0 + 1089, 144)>(%arg0)
      %14 = affine.apply affine_map<(d0) -> (d0 ceildiv 16)>(%13)
      %15 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg0)
      scf.for %arg1 = %11 to %c480 step %12 {
        %subview = memref.subview %8[%arg0, %arg1] [%13, 48] [1, 1] : memref<1089x480xf32, strided<[480, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<?x48xf32, strided<[480, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
        %16 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg1)
        %subview_2 = memref.subview %6[%15, %16, 0, 0] [%14, 3, 16, 16] [1, 1, 1, 1] : memref<69x30x16x16xf32, strided<[7680, 256, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<?x3x16x16xf32, strided<[7680, 256, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
        %subview_3 = memref.subview %7[%arg1] [48] [1] : memref<480xf32, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<48xf32, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
        %17 = scf.for %arg2 = %c0 to %13 step %c16 iter_args(%arg3 = %subview) -> (memref<?x48xf32, strided<[480, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) {
          %18 = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 16)>(%arg2)[%13]
          %19 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg2)
          %20 = scf.for %arg4 = %c0 to %c48 step %c16 iter_args(%arg5 = %arg3) -> (memref<?x48xf32, strided<[480, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) {
            %21 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg4)
            // %subview_7 is the first %18 rows of a 16x16 tile of the read-only
            // input binding 0.
            %subview_5 = memref.subview %subview_2[%19, %21, 0, 0] [1, 1, 16, 16] [1, 1, 1, 1] : memref<?x3x16x16xf32, strided<[7680, 256, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1x16x16xf32, strided<[7680, 256, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
            %subview_6 = memref.subview %subview_5[0, 0, 0, 0] [1, 1, 16, 16] [1, 1, 1, 1] : memref<1x1x16x16xf32, strided<[7680, 256, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<16x16xf32, strided<[16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
            %subview_7 = memref.subview %subview_6[0, 0] [%18, 16] [1, 1] : memref<16x16xf32, strided<[16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<?x16xf32, strided<[16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
            %22 = vector.transfer_read %subview_3[%arg4], %cst {in_bounds = [true]} : memref<48xf32, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<16xf32>
            %23 = vector.broadcast %22 : vector<16xf32> to vector<16x16xf32>
            // Mask enables the first %18 rows and all 16 columns; masked-off
            // lanes of the read take the padding value %cst.
            %24 = vector.create_mask %18, %c16 : vector<16x16xi1>
            %25 = vector.transfer_read %subview_7[%c0, %c0], %cst, %24 {in_bounds = [true, true]} : memref<?x16xf32, strided<[16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<16x16xf32>
            // Add the broadcast bias row, then clamp to [0, 6].
            %26 = arith.addf %25, %23 : vector<16x16xf32>
            %27 = arith.minimumf %26, %cst_0 : vector<16x16xf32>
            %28 = arith.maximumf %27, %cst_1 : vector<16x16xf32>
            // Stage a copy of the read-only input tile in the stack buffer.
            %subview_8 = memref.subview %alloca[0, 0] [%18, 16] [1, 1] : memref<16x16xf32, #hal.descriptor_type<storage_buffer>> to memref<?x16xf32, strided<[16, 1]>, #hal.descriptor_type<storage_buffer>>
            linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%subview_7 : memref<?x16xf32, strided<[16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) outs(%subview_8 : memref<?x16xf32, strided<[16, 1]>, #hal.descriptor_type<storage_buffer>>) {
            ^bb0(%in: f32, %out: f32):
              linalg.yield %in : f32
            }
            // The clamped result overwrites the staged copy in the stack buffer.
            vector.transfer_write %28, %subview_8[%c0, %c0], %24 {in_bounds = [true, true]} : vector<16x16xf32>, memref<?x16xf32, strided<[16, 1]>, #hal.descriptor_type<storage_buffer>>
            // Copy the stack buffer into the matching tile of the output binding 2.
            %subview_9 = memref.subview %arg5[%arg2, %arg4] [%18, 16] [1, 1] : memref<?x48xf32, strided<[480, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<?x16xf32, strided<[480, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
            linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%subview_8 : memref<?x16xf32, strided<[16, 1]>, #hal.descriptor_type<storage_buffer>>) outs(%subview_9 : memref<?x16xf32, strided<[480, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) {
            ^bb0(%in: f32, %out: f32):
              linalg.yield %in : f32
            }
            scf.yield %arg5 : memref<?x48xf32, strided<[480, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
          }
          scf.yield %20 : memref<?x48xf32, strided<[480, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
        }
        // %17 is the loop result, which is seeded from %subview (the same
        // %8[%arg0, %arg1] [%13, 48] region as %subview_4). This copy therefore
        // has identical source and destination regions; presumably a later pass
        // removes it. TODO confirm.
        %subview_4 = memref.subview %8[%arg0, %arg1] [%13, 48] [1, 1] : memref<1089x480xf32, strided<[480, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<?x48xf32, strided<[480, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
        linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%17 : memref<?x48xf32, strided<[480, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) outs(%subview_4 : memref<?x48xf32, strided<[480, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) {
        ^bb0(%in: f32, %out: f32):
          linalg.yield %in : f32
        }
      }
    }
    return
  }
}

Originally, the IR was:

// -----// IR Dump After GenericVectorization (iree-codegen-generic-vectorization) //----- //
// Unchanged

// -----// IR Dump After OptimizeTensorInsertExtractSlices (iree-codegen-optimize-tensor-insert-extract-slices) //----- //
// Tensor-level IR from the original pipeline, before bufferization. The input
// tile is inserted into the corresponding slice of the output (%arg5), and the
// vector read / compute / write then operate on that output slice.
func.func @main_dispatch_108_unpack_generic_1089x480_f32() {
  %cst = arith.constant dense<0.000000e+00> : vector<16x16xf32>
  %cst_0 = arith.constant dense<6.000000e+00> : vector<16x16xf32>
  %c48 = arith.constant 48 : index
  %c16 = arith.constant 16 : index
  %c0 = arith.constant 0 : index
  %c480 = arith.constant 480 : index
  %c1089 = arith.constant 1089 : index
  %cst_1 = arith.constant 0.000000e+00 : f32
  %0 = hal.interface.constant.load[0] : i32
  %1 = hal.interface.constant.load[1] : i32
  %2 = hal.interface.constant.load[2] : i32
  %3 = arith.index_castui %0 : i32 to index
  %4 = arith.index_castui %1 : i32 to index
  %5 = arith.index_castui %2 : i32 to index
  %6 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%3) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<69x30x16x16xf32>>
  %7 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%4) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<480xf32>>
  %8 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%5) : !flow.dispatch.tensor<writeonly:tensor<1089x480xf32>>
  %workgroup_id_x = hal.interface.workgroup.id[0] : index
  %workgroup_count_x = hal.interface.workgroup.count[0] : index
  %workgroup_id_y = hal.interface.workgroup.id[1] : index
  %workgroup_count_y = hal.interface.workgroup.count[1] : index
  %9 = affine.apply affine_map<()[s0] -> (s0 * 144)>()[%workgroup_id_y]
  %10 = affine.apply affine_map<()[s0] -> (s0 * 144)>()[%workgroup_count_y]
  %11 = affine.apply affine_map<()[s0] -> (s0 * 48)>()[%workgroup_id_x]
  %12 = affine.apply affine_map<()[s0] -> (s0 * 48)>()[%workgroup_count_x]
  scf.for %arg0 = %9 to %c1089 step %10 {
    %13 = affine.min affine_map<(d0) -> (-d0 + 1089, 144)>(%arg0)
    %14 = affine.apply affine_map<(d0) -> (d0 ceildiv 16)>(%13)
    %15 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg0)
    scf.for %arg1 = %11 to %c480 step %12 {
      // %16 is a tensor view of the output tile; it seeds the loop-carried
      // value %arg3.
      %16 = flow.dispatch.tensor.load %8, offsets = [%arg0, %arg1], sizes = [%13, 48], strides = [1, 1] : !flow.dispatch.tensor<writeonly:tensor<1089x480xf32>> -> tensor<?x48xf32>
      %17 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg1)
      %18 = flow.dispatch.tensor.load %6, offsets = [%15, %17, 0, 0], sizes = [%14, 3, 16, 16], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<69x30x16x16xf32>> -> tensor<?x3x16x16xf32>
      %19 = flow.dispatch.tensor.load %7, offsets = [%arg1], sizes = [48], strides = [1] : !flow.dispatch.tensor<readonly:tensor<480xf32>> -> tensor<48xf32>
      %20 = scf.for %arg2 = %c0 to %13 step %c16 iter_args(%arg3 = %16) -> (tensor<?x48xf32>) {
        %21 = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 16)>(%arg2)[%13]
        %22 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg2)
        %23 = scf.for %arg4 = %c0 to %c48 step %c16 iter_args(%arg5 = %arg3) -> (tensor<?x48xf32>) {
          %24 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg4)
          %extracted_slice = tensor.extract_slice %18[%22, %24, 0, 0] [1, 1, 16, 16] [1, 1, 1, 1] : tensor<?x3x16x16xf32> to tensor<1x1x16x16xf32>
          %extracted_slice_2 = tensor.extract_slice %arg5[%arg2, %arg4] [%21, 16] [1, 1] : tensor<?x48xf32> to tensor<?x16xf32>
          %extracted_slice_3 = tensor.extract_slice %extracted_slice[0, 0, 0, 0] [1, 1, 16, 16] [1, 1, 1, 1] : tensor<1x1x16x16xf32> to tensor<16x16xf32>
          %extracted_slice_4 = tensor.extract_slice %extracted_slice_3[0, 0] [%21, 16] [1, 1] : tensor<16x16xf32> to tensor<?x16xf32>
          // Insert the input tile into the output slice. The vector read and
          // write below operate on %inserted_slice, i.e. on output data, rather
          // than on a slice of the read-only input %18.
          %inserted_slice = tensor.insert_slice %extracted_slice_4 into %extracted_slice_2[0, 0] [%21, 16] [1, 1] : tensor<?x16xf32> into tensor<?x16xf32>
          %dim = tensor.dim %inserted_slice, %c0 : tensor<?x16xf32>
          %25 = vector.transfer_read %19[%arg4], %cst_1 {in_bounds = [true]} : tensor<48xf32>, vector<16xf32>
          %26 = vector.broadcast %25 : vector<16xf32> to vector<16x16xf32>
          %27 = vector.create_mask %dim, %c16 : vector<16x16xi1>
          %28 = vector.transfer_read %inserted_slice[%c0, %c0], %cst_1, %27 {in_bounds = [true, true]} : tensor<?x16xf32>, vector<16x16xf32>
          // Add the broadcast bias row, then clamp to [0, 6].
          %29 = arith.addf %28, %26 : vector<16x16xf32>
          %30 = arith.minimumf %29, %cst_0 : vector<16x16xf32>
          %31 = arith.maximumf %30, %cst : vector<16x16xf32>
          %32 = vector.transfer_write %31, %inserted_slice[%c0, %c0], %27 {in_bounds = [true, true]} : vector<16x16xf32>, tensor<?x16xf32>
          // Write the updated tile back into the loop-carried output tensor.
          %inserted_slice_5 = tensor.insert_slice %32 into %arg5[%arg2, %arg4] [%21, 16] [1, 1] : tensor<?x16xf32> into tensor<?x48xf32>
          scf.yield %inserted_slice_5 : tensor<?x48xf32>
        }
        scf.yield %23 : tensor<?x48xf32>
      }
      flow.dispatch.tensor.store %20, %8, offsets = [%arg0, %arg1], sizes = [%13, 48], strides = [1, 1] : tensor<?x48xf32> -> !flow.dispatch.tensor<writeonly:tensor<1089x480xf32>>
    }
  }
  return
}

// ...

// -----// IR Dump After IREEComprehensiveBufferize (iree-codegen-iree-comprehensive-bufferize) //----- //
// Bufferized form of the original IR. No memref.alloca is created here: the
// input tile is copied directly into the output subview %subview_6, which is
// then read, updated, and written in place.
module {
  func.func @main_dispatch_108_unpack_generic_1089x480_f32() {
    %cst = arith.constant dense<0.000000e+00> : vector<16x16xf32>
    %cst_0 = arith.constant dense<6.000000e+00> : vector<16x16xf32>
    %c48 = arith.constant 48 : index
    %c16 = arith.constant 16 : index
    %c0 = arith.constant 0 : index
    %c480 = arith.constant 480 : index
    %c1089 = arith.constant 1089 : index
    %cst_1 = arith.constant 0.000000e+00 : f32
    %0 = hal.interface.constant.load[0] : i32
    %1 = hal.interface.constant.load[1] : i32
    %2 = hal.interface.constant.load[2] : i32
    %3 = arith.index_castui %0 : i32 to index
    %4 = arith.index_castui %1 : i32 to index
    %5 = arith.index_castui %2 : i32 to index
    %6 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%3) flags(ReadOnly) : memref<69x30x16x16xf32, strided<[7680, 256, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
    memref.assume_alignment %6, 1 : memref<69x30x16x16xf32, strided<[7680, 256, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
    %7 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%4) flags(ReadOnly) : memref<480xf32, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
    memref.assume_alignment %7, 1 : memref<480xf32, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
    %8 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%5) : memref<1089x480xf32, strided<[480, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
    memref.assume_alignment %8, 1 : memref<1089x480xf32, strided<[480, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
    %workgroup_id_x = hal.interface.workgroup.id[0] : index
    %workgroup_count_x = hal.interface.workgroup.count[0] : index
    %workgroup_id_y = hal.interface.workgroup.id[1] : index
    %workgroup_count_y = hal.interface.workgroup.count[1] : index
    %9 = affine.apply affine_map<()[s0] -> (s0 * 144)>()[%workgroup_id_y]
    %10 = affine.apply affine_map<()[s0] -> (s0 * 144)>()[%workgroup_count_y]
    %11 = affine.apply affine_map<()[s0] -> (s0 * 48)>()[%workgroup_id_x]
    %12 = affine.apply affine_map<()[s0] -> (s0 * 48)>()[%workgroup_count_x]
    scf.for %arg0 = %9 to %c1089 step %10 {
      %13 = affine.min affine_map<(d0) -> (-d0 + 1089, 144)>(%arg0)
      %14 = affine.apply affine_map<(d0) -> (d0 ceildiv 16)>(%13)
      %15 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg0)
      scf.for %arg1 = %11 to %c480 step %12 {
        %subview = memref.subview %8[%arg0, %arg1] [%13, 48] [1, 1] : memref<1089x480xf32, strided<[480, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<?x48xf32, strided<[480, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
        %16 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg1)
        %subview_2 = memref.subview %6[%15, %16, 0, 0] [%14, 3, 16, 16] [1, 1, 1, 1] : memref<69x30x16x16xf32, strided<[7680, 256, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<?x3x16x16xf32, strided<[7680, 256, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
        %subview_3 = memref.subview %7[%arg1] [48] [1] : memref<480xf32, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<48xf32, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
        %17 = scf.for %arg2 = %c0 to %13 step %c16 iter_args(%arg3 = %subview) -> (memref<?x48xf32, strided<[480, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) {
          %18 = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 16)>(%arg2)[%13]
          %19 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg2)
          %20 = scf.for %arg4 = %c0 to %c48 step %c16 iter_args(%arg5 = %arg3) -> (memref<?x48xf32, strided<[480, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) {
            %21 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg4)
            %subview_5 = memref.subview %subview_2[%19, %21, 0, 0] [1, 1, 16, 16] [1, 1, 1, 1] : memref<?x3x16x16xf32, strided<[7680, 256, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1x16x16xf32, strided<[7680, 256, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
            %subview_6 = memref.subview %arg5[%arg2, %arg4] [%18, 16] [1, 1] : memref<?x48xf32, strided<[480, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<?x16xf32, strided<[480, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
            %subview_7 = memref.subview %subview_5[0, 0, 0, 0] [1, 1, 16, 16] [1, 1, 1, 1] : memref<1x1x16x16xf32, strided<[7680, 256, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<16x16xf32, strided<[16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
            %subview_8 = memref.subview %subview_7[0, 0] [%18, 16] [1, 1] : memref<16x16xf32, strided<[16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<?x16xf32, strided<[16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
            // Copy the input tile straight into the output subview. No stack
            // temporary is involved.
            linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%subview_8 : memref<?x16xf32, strided<[16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) outs(%subview_6 : memref<?x16xf32, strided<[480, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) {
            ^bb0(%in: f32, %out: f32):
              linalg.yield %in : f32
            }
            %22 = vector.transfer_read %subview_3[%arg4], %cst_1 {in_bounds = [true]} : memref<48xf32, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<16xf32>
            %23 = vector.broadcast %22 : vector<16xf32> to vector<16x16xf32>
            // Mask enables the first %18 rows and all 16 columns.
            %24 = vector.create_mask %18, %c16 : vector<16x16xi1>
            %25 = vector.transfer_read %subview_6[%c0, %c0], %cst_1, %24 {in_bounds = [true, true]} : memref<?x16xf32, strided<[480, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<16x16xf32>
            // Add the broadcast bias row, then clamp to [0, 6].
            %26 = arith.addf %25, %23 : vector<16x16xf32>
            %27 = arith.minimumf %26, %cst_0 : vector<16x16xf32>
            %28 = arith.maximumf %27, %cst : vector<16x16xf32>
            // The result is written in place into the output subview.
            vector.transfer_write %28, %subview_6[%c0, %c0], %24 {in_bounds = [true, true]} : vector<16x16xf32>, memref<?x16xf32, strided<[480, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
            // %subview_9 and %subview_6 cover the same %arg5[%arg2, %arg4]
            // [%18, 16] region, so this copy has identical source and
            // destination regions.
            %subview_9 = memref.subview %arg5[%arg2, %arg4] [%18, 16] [1, 1] : memref<?x48xf32, strided<[480, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<?x16xf32, strided<[480, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
            linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%subview_6 : memref<?x16xf32, strided<[480, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) outs(%subview_9 : memref<?x16xf32, strided<[480, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) {
            ^bb0(%in: f32, %out: f32):
              linalg.yield %in : f32
            }
            scf.yield %arg5 : memref<?x48xf32, strided<[480, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
          }
          scf.yield %20 : memref<?x48xf32, strided<[480, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
        }
        // %17 is the loop result, which is seeded from %subview (the same
        // %8[%arg0, %arg1] [%13, 48] region as %subview_4). This copy therefore
        // has identical source and destination regions; presumably a later pass
        // removes it. TODO confirm.
        %subview_4 = memref.subview %8[%arg0, %arg1] [%13, 48] [1, 1] : memref<1089x480xf32, strided<[480, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<?x48xf32, strided<[480, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
        linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%17 : memref<?x48xf32, strided<[480, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) outs(%subview_4 : memref<?x48xf32, strided<[480, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) {
        ^bb0(%in: f32, %out: f32):
          linalg.yield %in : f32
        }
      }
    }
    return
  }
}
pzread commented 6 months ago

I think the case in https://github.com/openxla/iree/issues/16772#issuecomment-2043448959 is hitting the issue mentioned in https://github.com/openxla/iree/issues/16772#issuecomment-1999030341, where we fold extract/insert_slice ops too aggressively.

I'm not very familiar with bufferization. @matthias-springer, do you have any suggestions for this issue?