benvanik opened this issue 7 months ago
I changed the LinalgExt getDimValue helper to use createOrFold (as the tensor dialect and others do), and it at least spits out reasonable IR:
// -----// IR Dump After FormDispatchRegions (iree-flow-form-dispatch-regions) //----- //
mlir-asm-printer: Verifying operation: util.func
util.func public @matmul_accumulate_4x4xi8_times_4x4xi8_into_4x4xi32(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub, iree.reflection = {iree.abi.declaration = "sync func @matmul_accumulate_4x4xi8_times_4x4xi8_into_4x4xi32(%input0: tensor<4x4xi8>, %input1: tensor<4x4xi8>, %input2: tensor<4x4xi32>) -> (%output0: tensor<4x4xi32>)"}} {
%c0_i32 = arith.constant 0 : i32
%c0_i8 = arith.constant 0 : i8
%0 = hal.tensor.import %arg0 "input0" : !hal.buffer_view -> tensor<4x4xi8>
%1 = hal.tensor.import %arg1 "input1" : !hal.buffer_view -> tensor<4x4xi8>
%2 = hal.tensor.import %arg2 "input2" : !hal.buffer_view -> tensor<4x4xi32>
%3 = tensor.empty() : tensor<16x16xi8>
%4 = linalg.fill ins(%c0_i8 : i8) outs(%3 : tensor<16x16xi8>) -> tensor<16x16xi8>
%inserted_slice = tensor.insert_slice %0 into %4[0, 0] [4, 4] [1, 1] : tensor<4x4xi8> into tensor<16x16xi8>
%cast = tensor.cast %inserted_slice : tensor<16x16xi8> to tensor<?x?xi8>
%c0 = arith.constant 0 : index
%c16 = arith.constant 16 : index
%c1 = arith.constant 1 : index
%c16_0 = arith.constant 16 : index
%5 = flow.dispatch.region -> (tensor<?x?xi8, #iree_linalg_ext.encoding<role = LHS, element_types = [i8, i8, i32], original_type = tensor<4x4xi8>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>{%c16, %c16_0}) {
%13 = iree_linalg_ext.set_encoding %cast : tensor<?x?xi8> -> tensor<?x?xi8, #iree_linalg_ext.encoding<role = LHS, element_types = [i8, i8, i32], original_type = tensor<4x4xi8>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>
flow.return %13 : tensor<?x?xi8, #iree_linalg_ext.encoding<role = LHS, element_types = [i8, i8, i32], original_type = tensor<4x4xi8>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>
}
%inserted_slice_1 = tensor.insert_slice %1 into %4[0, 0] [4, 4] [1, 1] : tensor<4x4xi8> into tensor<16x16xi8>
%cast_2 = tensor.cast %inserted_slice_1 : tensor<16x16xi8> to tensor<?x?xi8>
%c0_3 = arith.constant 0 : index
%c16_4 = arith.constant 16 : index
%c1_5 = arith.constant 1 : index
%c16_6 = arith.constant 16 : index
%6 = flow.dispatch.region -> (tensor<?x?xi8, #iree_linalg_ext.encoding<role = RHS, element_types = [i8, i8, i32], original_type = tensor<4x4xi8>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>{%c16_4, %c16_6}) {
%13 = iree_linalg_ext.set_encoding %cast_2 : tensor<?x?xi8> -> tensor<?x?xi8, #iree_linalg_ext.encoding<role = RHS, element_types = [i8, i8, i32], original_type = tensor<4x4xi8>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>
flow.return %13 : tensor<?x?xi8, #iree_linalg_ext.encoding<role = RHS, element_types = [i8, i8, i32], original_type = tensor<4x4xi8>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>
}
%7 = tensor.empty() : tensor<16x16xi32>
%8 = linalg.fill ins(%c0_i32 : i32) outs(%7 : tensor<16x16xi32>) -> tensor<16x16xi32>
%inserted_slice_7 = tensor.insert_slice %2 into %8[0, 0] [4, 4] [1, 1] : tensor<4x4xi32> into tensor<16x16xi32>
%cast_8 = tensor.cast %inserted_slice_7 : tensor<16x16xi32> to tensor<?x?xi32>
%c0_9 = arith.constant 0 : index
%c16_10 = arith.constant 16 : index
%c1_11 = arith.constant 1 : index
%c16_12 = arith.constant 16 : index
%c0_13 = arith.constant 0 : index
%c16_14 = arith.constant 16 : index
%c1_15 = arith.constant 1 : index
%c16_16 = arith.constant 16 : index
%9 = flow.dispatch.region -> (tensor<?x?xi32, #iree_linalg_ext.encoding<role = RESULT, element_types = [i8, i8, i32], original_type = tensor<4x4xi32>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>{%c16_14, %c16_16}) {
%13 = iree_linalg_ext.set_encoding %cast_8 : tensor<?x?xi32> -> tensor<?x?xi32, #iree_linalg_ext.encoding<role = RESULT, element_types = [i8, i8, i32], original_type = tensor<4x4xi32>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>
flow.return %13 : tensor<?x?xi32, #iree_linalg_ext.encoding<role = RESULT, element_types = [i8, i8, i32], original_type = tensor<4x4xi32>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>
}
%10 = flow.dispatch.region -> (tensor<?x?xi32, #iree_linalg_ext.encoding<role = RESULT, element_types = [i8, i8, i32], original_type = tensor<4x4xi32>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>{%c16_10, %c16_12}) {
%13 = linalg.matmul ins(%5, %6 : tensor<?x?xi8, #iree_linalg_ext.encoding<role = LHS, element_types = [i8, i8, i32], original_type = tensor<4x4xi8>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>, tensor<?x?xi8, #iree_linalg_ext.encoding<role = RHS, element_types = [i8, i8, i32], original_type = tensor<4x4xi8>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>) outs(%9 : tensor<?x?xi32, #iree_linalg_ext.encoding<role = RESULT, element_types = [i8, i8, i32], original_type = tensor<4x4xi32>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>) -> tensor<?x?xi32, #iree_linalg_ext.encoding<role = RESULT, element_types = [i8, i8, i32], original_type = tensor<4x4xi32>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>
flow.return %13 : tensor<?x?xi32, #iree_linalg_ext.encoding<role = RESULT, element_types = [i8, i8, i32], original_type = tensor<4x4xi32>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>
}
%11 = flow.dispatch.region -> (tensor<4x4xi32>) {
%13 = iree_linalg_ext.unset_encoding %10 : tensor<?x?xi32, #iree_linalg_ext.encoding<role = RESULT, element_types = [i8, i8, i32], original_type = tensor<4x4xi32>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>> -> tensor<?x?xi32>
%extracted_slice = tensor.extract_slice %13[0, 0] [4, 4] [1, 1] : tensor<?x?xi32> to tensor<4x4xi32>
flow.return %extracted_slice : tensor<4x4xi32>
}
%12 = hal.tensor.export %11 "output0" : tensor<4x4xi32> -> !hal.buffer_view
util.return %12 : !hal.buffer_view
}
It's still creating the tensor.dim ops, folding them, and then erasing them just as many times, but the IR is better on the way out.
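For reference, the change amounts to switching the helper from builder.create to builder.createOrFold, which folds statically-known dims to (uniqued) constants instead of always materializing a fresh op. The toy below models that difference; its Builder class and method names are illustrative only, not MLIR's actual API:

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Toy model of an SSA value: either a fresh op result or a folded constant.
struct Value {
  bool isConstant;
  long constant; // valid when isConstant is true
  std::string name;
};

// Toy builder contrasting create() with createOrFold() for a dim query.
class Builder {
public:
  // create(): always materializes a new "op", even when the answer is
  // statically known (what the old getDimValue helper effectively did).
  Value createDim(const std::vector<long> &shape, unsigned idx) {
    ++numOpsCreated;
    return Value{false, 0, "dim" + std::to_string(numOpsCreated)};
  }

  // createOrFold(): static dims fold to a uniqued constant; only dynamic
  // dims (modeled here as -1) materialize a new dim op.
  Value createOrFoldDim(const std::vector<long> &shape, unsigned idx) {
    long d = shape[idx];
    if (d >= 0) {
      auto it = constants.find(d);
      if (it != constants.end())
        return it->second; // reuse the existing constant
      Value c{true, d, "c" + std::to_string(d)};
      constants.emplace(d, c);
      return c;
    }
    return createDim(shape, idx); // dynamic: must create a dim op
  }

  unsigned numOpsCreated = 0;

private:
  std::map<long, Value> constants;
};
```

With a static 16x16 shape, every dim query folds to the same c16 constant and no dim ops are created at all, which matches the repeated %c16 constants (rather than tensor.dim ops) in the dump above.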
OK, I thought this could be made better, but it really needs a bit of reworking to use a special listener that the reify methods can query to see if the op already exists and reuse it. This will take a bit of work upstream to plumb through. We should really move the Listener that is in FormDispatchRegions upstream and use it when available.
The only real way to fix this is to modify the interface method in ReifyRankedShapeTypeInterface. The reifyResultDim method should take an extra (optional) Listener object as an argument. When provided, the implementation should use the Listener to query for existing tensor.dim operations instead of creating new ones.
This doesn't need to be a wholesale rewrite; it can be added incrementally.
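A minimal sketch of the proposed direction. The DimListener type and reifyDim function below are hypothetical names for illustration, not the actual interface API: the reify path first asks an optional listener for an existing tensor.dim before creating (and registering) a new one.

```cpp
#include <cassert>
#include <map>
#include <optional>
#include <string>
#include <utility>

// Hypothetical listener: maps (tensor value, dim index) to an already-created
// dim op so reification can reuse it instead of making a duplicate.
struct DimListener {
  std::map<std::pair<std::string, unsigned>, std::string> existingDims;

  std::optional<std::string> lookup(const std::string &tensor,
                                    unsigned idx) const {
    auto it = existingDims.find({tensor, idx});
    if (it == existingDims.end())
      return std::nullopt;
    return it->second;
  }

  void notifyCreated(const std::string &tensor, unsigned idx,
                     const std::string &dimOp) {
    existingDims[{tensor, idx}] = dimOp;
  }
};

int dimOpsCreated = 0;

// Sketch of a reify method taking an optional listener: reuse an existing
// dim op when the listener knows of one, otherwise create and register it.
std::string reifyDim(const std::string &tensor, unsigned idx,
                     DimListener *listener) {
  if (listener)
    if (auto existing = listener->lookup(tensor, idx))
      return *existing;
  std::string newOp = "dim_" + tensor + "_" + std::to_string(++dimOpsCreated);
  if (listener)
    listener->notifyCreated(tensor, idx, newOp);
  return newOp;
}
```

Because the listener argument is optional, callers that don't pass one keep today's behavior, which is what makes an incremental rollout possible.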
The simplifyDims method in FormDispatchRegions.cpp is doing something wrong with dynamic dims. It seems to call reifyDynamicResultDims on each tensor.dim op, which in dynamic cases inserts a new tensor.dim op. This leads to an explosion of IR that is expensive to create. simplifyDims seems to be load-bearing instead of just simplifying dims, as returning early from it causes op domination issues. I can't figure out exactly what it's doing, but I suspect that instead of calling reifyDynamicResultDims it should be moving the ops around - things only work today because it happens to be recreating each dim op at a new location. The input is small, so I think us creating 300+ tensor.dim ops is the thing we should be able to fix.
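To illustrate why moving rather than recreating matters, here is a toy model (none of this is the actual FormDispatchRegions code): recreating a dim op at each new use site grows the op list on every call and leaves dead clones behind, while splicing the existing op to the new location keeps the op count flat.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Toy op list standing in for a block's operations.
struct Block {
  std::vector<std::string> ops;

  // What simplifyDims effectively does today: reify a fresh dim op at the
  // new location and leave the old one behind for later cleanup.
  void recreateDim(const std::string &name) {
    ops.push_back(name + "_clone");
  }

  // The suggested alternative: splice the existing op to the new location,
  // preserving dominance without growing the IR.
  void moveDimTo(size_t from, size_t to) {
    std::string op = ops[from];
    ops.erase(ops.begin() + from);
    ops.insert(ops.begin() + std::min(to, ops.size()), op);
  }
};
```

With hundreds of dim queries, the recreate path is what produces the 300+ tensor.dim ops; the move path touches the same number of queries without adding any.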
Input IR:
Produces the above when iree-opt is run with --pass-pipeline="builtin.module(util.func(iree-flow-form-dispatch-regions))":