iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

FormDispatchRegions inserts thousands of redundant tensor.dim ops with dynamic dims. #16683

Open benvanik opened 7 months ago

benvanik commented 7 months ago

The `simplifyDims` method in FormDispatchRegions.cpp mishandles dynamic dims. It appears to call `reifyDynamicResultDims` on each `tensor.dim` op, and when the dims are dynamic this inserts a new `tensor.dim` op. The result is an explosion of IR that is expensive to create.

`simplifyDims` seems to be load-bearing rather than just simplifying dims: returning early from it causes op dominance errors. I can't tell exactly what it's doing, but I suspect that instead of calling `reifyDynamicResultDims` it should be moving the ops. Things only work today because it happens to recreate each dim op at a new location.
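To make the suspected failure mode concrete, here is a hypothetical C++ toy model. It is not IREE's `simplifyDims` code and does not use MLIR's API; `ToyDimBuilder`, `simplifySweep`, and `countDimOps` are illustrative names. The model assumes that each simplification sweep re-reifies every existing dim op and that such sweeps repeat. Under that assumption, recreating ops compounds. The model also contrasts this with reusing an existing op keyed on (source, index), which is one possible alternative and not necessarily the fix proposed above. It is not meant to reproduce the exact counts in the dump.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy model of a block of `tensor.dim %source, %cN` ops.
// Not IREE's or MLIR's API; all names are illustrative only.
using DimKey = std::pair<std::string, int>;  // (source value, dimension index)

class ToyDimBuilder {
 public:
  explicit ToyDimBuilder(bool reuseExisting) : reuseExisting_(reuseExisting) {}

  // Returns the id of a dim op for (source, index). In "recreate" mode a new
  // op is appended on every call; in "reuse" mode an existing op is returned.
  std::size_t getDim(const std::string &source, int index) {
    DimKey key{source, index};
    if (reuseExisting_) {
      auto it = cache_.find(key);
      if (it != cache_.end()) return it->second;  // reuse, no new op
    }
    std::size_t id = ops_.size();
    ops_.push_back(key);  // materialize a new (possibly redundant) dim op
    if (reuseExisting_) cache_[key] = id;
    return id;
  }

  DimKey op(std::size_t i) const {
    assert(i < ops_.size() && "op id out of range");
    return ops_[i];  // returned by value: later push_backs may reallocate
  }

  std::size_t numOps() const { return ops_.size(); }

 private:
  bool reuseExisting_;
  std::vector<DimKey> ops_;
  std::map<DimKey, std::size_t> cache_;
};

// One "simplify" sweep: for every dim op that exists at the start of the
// sweep, ask for an equivalent dim op (mimicking re-reifying each tensor.dim).
void simplifySweep(ToyDimBuilder &builder) {
  const std::size_t n = builder.numOps();  // ops added mid-sweep are not revisited
  for (std::size_t i = 0; i < n; ++i) {
    DimKey key = builder.op(i);
    builder.getDim(key.first, key.second);
  }
}

// Seeds both dims of a rank-2 value, runs `sweeps` sweeps, returns the op count.
std::size_t countDimOps(bool reuseExisting, int sweeps) {
  ToyDimBuilder builder(reuseExisting);
  builder.getDim("%cast_56", 0);
  builder.getDim("%cast_56", 1);
  for (int s = 0; s < sweeps; ++s) simplifySweep(builder);
  return builder.numOps();
}
```

In this model `countDimOps(false, 3)` returns 16 because the op count doubles on every sweep, while `countDimOps(true, 3)` stays at 2, one op per dimension of the rank-2 value.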

The input is small, so producing 300+ `tensor.dim` ops from it is something we should be able to fix.

Input IR:

// -----// IR Dump Before FormDispatchRegions (iree-flow-form-dispatch-regions) //----- //
util.func public @matmul_accumulate_4x4xi8_times_4x4xi8_into_4x4xi32(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub, iree.reflection = {iree.abi.declaration = "sync func @matmul_accumulate_4x4xi8_times_4x4xi8_into_4x4xi32(%input0: tensor<4x4xi8>, %input1: tensor<4x4xi8>, %input2: tensor<4x4xi32>) -> (%output0: tensor<4x4xi32>)"}} {
  %c0_i32 = arith.constant 0 : i32
  %c0_i8 = arith.constant 0 : i8
  %0 = hal.tensor.import %arg0 "input0" : !hal.buffer_view -> tensor<4x4xi8>
  %1 = hal.tensor.import %arg1 "input1" : !hal.buffer_view -> tensor<4x4xi8>
  %2 = hal.tensor.import %arg2 "input2" : !hal.buffer_view -> tensor<4x4xi32>
  %3 = tensor.empty() : tensor<16x16xi8>
  %4 = linalg.fill ins(%c0_i8 : i8) outs(%3 : tensor<16x16xi8>) -> tensor<16x16xi8>
  %inserted_slice = tensor.insert_slice %0 into %4[0, 0] [4, 4] [1, 1] : tensor<4x4xi8> into tensor<16x16xi8>
  %cast = tensor.cast %inserted_slice : tensor<16x16xi8> to tensor<?x?xi8>
  %5 = iree_linalg_ext.set_encoding %cast : tensor<?x?xi8> -> tensor<?x?xi8, #iree_linalg_ext.encoding<role =  LHS, element_types = [i8, i8, i32], original_type = tensor<4x4xi8>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>
  %inserted_slice_0 = tensor.insert_slice %1 into %4[0, 0] [4, 4] [1, 1] : tensor<4x4xi8> into tensor<16x16xi8>
  %cast_1 = tensor.cast %inserted_slice_0 : tensor<16x16xi8> to tensor<?x?xi8>
  %6 = iree_linalg_ext.set_encoding %cast_1 : tensor<?x?xi8> -> tensor<?x?xi8, #iree_linalg_ext.encoding<role =  RHS, element_types = [i8, i8, i32], original_type = tensor<4x4xi8>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>
  %7 = tensor.empty() : tensor<16x16xi32>
  %8 = linalg.fill ins(%c0_i32 : i32) outs(%7 : tensor<16x16xi32>) -> tensor<16x16xi32>
  %inserted_slice_2 = tensor.insert_slice %2 into %8[0, 0] [4, 4] [1, 1] : tensor<4x4xi32> into tensor<16x16xi32>
  %cast_3 = tensor.cast %inserted_slice_2 : tensor<16x16xi32> to tensor<?x?xi32>
  %9 = iree_linalg_ext.set_encoding %cast_3 : tensor<?x?xi32> -> tensor<?x?xi32, #iree_linalg_ext.encoding<role =  RESULT, element_types = [i8, i8, i32], original_type = tensor<4x4xi32>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>
  %10 = linalg.matmul ins(%5, %6 : tensor<?x?xi8, #iree_linalg_ext.encoding<role =  LHS, element_types = [i8, i8, i32], original_type = tensor<4x4xi8>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>, tensor<?x?xi8, #iree_linalg_ext.encoding<role =  RHS, element_types = [i8, i8, i32], original_type = tensor<4x4xi8>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>) outs(%9 : tensor<?x?xi32, #iree_linalg_ext.encoding<role =  RESULT, element_types = [i8, i8, i32], original_type = tensor<4x4xi32>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>) -> tensor<?x?xi32, #iree_linalg_ext.encoding<role =  RESULT, element_types = [i8, i8, i32], original_type = tensor<4x4xi32>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>
  %11 = iree_linalg_ext.unset_encoding %10 : tensor<?x?xi32, #iree_linalg_ext.encoding<role =  RESULT, element_types = [i8, i8, i32], original_type = tensor<4x4xi32>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>> -> tensor<?x?xi32>
  %extracted_slice = tensor.extract_slice %11[0, 0] [4, 4] [1, 1] : tensor<?x?xi32> to tensor<4x4xi32>
  %12 = hal.tensor.export %extracted_slice "output0" : tensor<4x4xi32> -> !hal.buffer_view
  util.return %12 : !hal.buffer_view
}

Running it through iree-opt with `"--pass-pipeline=builtin.module(util.func(iree-flow-form-dispatch-regions))"` produces:

// -----// IR Dump After FormDispatchRegions (iree-flow-form-dispatch-regions) //----- //
mlir-asm-printer: Verifying operation: util.func
util.func public @matmul_accumulate_4x4xi8_times_4x4xi8_into_4x4xi32(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub, iree.reflection = {iree.abi.declaration = "sync func @matmul_accumulate_4x4xi8_times_4x4xi8_into_4x4xi32(%input0: tensor<4x4xi8>, %input1: tensor<4x4xi8>, %input2: tensor<4x4xi32>) -> (%output0: tensor<4x4xi32>)"}} {
  %c0_i32 = arith.constant 0 : i32
  %c0_i8 = arith.constant 0 : i8
  %0 = hal.tensor.import %arg0 "input0" : !hal.buffer_view -> tensor<4x4xi8>
  %1 = hal.tensor.import %arg1 "input1" : !hal.buffer_view -> tensor<4x4xi8>
  %2 = hal.tensor.import %arg2 "input2" : !hal.buffer_view -> tensor<4x4xi32>
  %3 = tensor.empty() : tensor<16x16xi8>
  %4 = linalg.fill ins(%c0_i8 : i8) outs(%3 : tensor<16x16xi8>) -> tensor<16x16xi8>
  %inserted_slice = tensor.insert_slice %0 into %4[0, 0] [4, 4] [1, 1] : tensor<4x4xi8> into tensor<16x16xi8>
  %cast = tensor.cast %inserted_slice : tensor<16x16xi8> to tensor<?x?xi8>
  %c0 = arith.constant 0 : index
  %dim = tensor.dim %cast, %c0 : tensor<?x?xi8>
  %c1 = arith.constant 1 : index
  %dim_0 = tensor.dim %cast, %c1 : tensor<?x?xi8>
  %c0_1 = arith.constant 0 : index
  %dim_2 = tensor.dim %cast, %c0_1 : tensor<?x?xi8>
  %c1_3 = arith.constant 1 : index
  %dim_4 = tensor.dim %cast, %c1_3 : tensor<?x?xi8>
  %c0_5 = arith.constant 0 : index
  %c1_6 = arith.constant 1 : index
  %5 = flow.dispatch.region -> (tensor<?x?xi8, #iree_linalg_ext.encoding<role =  LHS, element_types = [i8, i8, i32], original_type = tensor<4x4xi8>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>{%dim_2, %dim_0}) {
    %13 = iree_linalg_ext.set_encoding %cast : tensor<?x?xi8> -> tensor<?x?xi8, #iree_linalg_ext.encoding<role =  LHS, element_types = [i8, i8, i32], original_type = tensor<4x4xi8>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>
    flow.return %13 : tensor<?x?xi8, #iree_linalg_ext.encoding<role =  LHS, element_types = [i8, i8, i32], original_type = tensor<4x4xi8>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>
  }
  %inserted_slice_7 = tensor.insert_slice %1 into %4[0, 0] [4, 4] [1, 1] : tensor<4x4xi8> into tensor<16x16xi8>
  %cast_8 = tensor.cast %inserted_slice_7 : tensor<16x16xi8> to tensor<?x?xi8>
  %c0_9 = arith.constant 0 : index
  %dim_10 = tensor.dim %cast_8, %c0_9 : tensor<?x?xi8>
  %c1_11 = arith.constant 1 : index
  %dim_12 = tensor.dim %cast_8, %c1_11 : tensor<?x?xi8>
  %c0_13 = arith.constant 0 : index
  %dim_14 = tensor.dim %cast_8, %c0_13 : tensor<?x?xi8>
  %c1_15 = arith.constant 1 : index
  %dim_16 = tensor.dim %cast_8, %c1_15 : tensor<?x?xi8>
  %c0_17 = arith.constant 0 : index
  %dim_18 = tensor.dim %cast_8, %c0_17 : tensor<?x?xi8>
  %c1_19 = arith.constant 1 : index
  %dim_20 = tensor.dim %cast_8, %c1_19 : tensor<?x?xi8>
  %c0_21 = arith.constant 0 : index
  %dim_22 = tensor.dim %cast_8, %c0_21 : tensor<?x?xi8>
  %c1_23 = arith.constant 1 : index
  %dim_24 = tensor.dim %cast_8, %c1_23 : tensor<?x?xi8>
  %c0_25 = arith.constant 0 : index
  %dim_26 = tensor.dim %cast_8, %c0_25 : tensor<?x?xi8>
  %c1_27 = arith.constant 1 : index
  %dim_28 = tensor.dim %cast_8, %c1_27 : tensor<?x?xi8>
  %c0_29 = arith.constant 0 : index
  %dim_30 = tensor.dim %cast_8, %c0_29 : tensor<?x?xi8>
  %c1_31 = arith.constant 1 : index
  %dim_32 = tensor.dim %cast_8, %c1_31 : tensor<?x?xi8>
  %c0_33 = arith.constant 0 : index
  %dim_34 = tensor.dim %cast_8, %c0_33 : tensor<?x?xi8>
  %c1_35 = arith.constant 1 : index
  %dim_36 = tensor.dim %cast_8, %c1_35 : tensor<?x?xi8>
  %c0_37 = arith.constant 0 : index
  %dim_38 = tensor.dim %cast_8, %c0_37 : tensor<?x?xi8>
  %c1_39 = arith.constant 1 : index
  %dim_40 = tensor.dim %cast_8, %c1_39 : tensor<?x?xi8>
  %c0_41 = arith.constant 0 : index
  %c1_42 = arith.constant 1 : index
  %c0_43 = arith.constant 0 : index
  %c1_44 = arith.constant 1 : index
  %c0_45 = arith.constant 0 : index
  %c1_46 = arith.constant 1 : index
  %c0_47 = arith.constant 0 : index
  %c1_48 = arith.constant 1 : index
  %c0_49 = arith.constant 0 : index
  %c1_50 = arith.constant 1 : index
  %c0_51 = arith.constant 0 : index
  %c1_52 = arith.constant 1 : index
  %c0_53 = arith.constant 0 : index
  %c1_54 = arith.constant 1 : index
  %6 = flow.dispatch.region -> (tensor<?x?xi8, #iree_linalg_ext.encoding<role =  RHS, element_types = [i8, i8, i32], original_type = tensor<4x4xi8>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>{%dim_14, %dim_20}) {
    %13 = iree_linalg_ext.set_encoding %cast_8 : tensor<?x?xi8> -> tensor<?x?xi8, #iree_linalg_ext.encoding<role =  RHS, element_types = [i8, i8, i32], original_type = tensor<4x4xi8>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>
    flow.return %13 : tensor<?x?xi8, #iree_linalg_ext.encoding<role =  RHS, element_types = [i8, i8, i32], original_type = tensor<4x4xi8>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>
  }
  %7 = tensor.empty() : tensor<16x16xi32>
  %8 = linalg.fill ins(%c0_i32 : i32) outs(%7 : tensor<16x16xi32>) -> tensor<16x16xi32>
  %inserted_slice_55 = tensor.insert_slice %2 into %8[0, 0] [4, 4] [1, 1] : tensor<4x4xi32> into tensor<16x16xi32>
  %cast_56 = tensor.cast %inserted_slice_55 : tensor<16x16xi32> to tensor<?x?xi32>
  %c0_57 = arith.constant 0 : index
  %dim_58 = tensor.dim %cast_56, %c0_57 : tensor<?x?xi32>
  %c1_59 = arith.constant 1 : index
  %dim_60 = tensor.dim %cast_56, %c1_59 : tensor<?x?xi32>
  %c0_61 = arith.constant 0 : index
  %dim_62 = tensor.dim %cast_56, %c0_61 : tensor<?x?xi32>
  %c1_63 = arith.constant 1 : index
  %dim_64 = tensor.dim %cast_56, %c1_63 : tensor<?x?xi32>
  %c0_65 = arith.constant 0 : index
  %dim_66 = tensor.dim %cast_56, %c0_65 : tensor<?x?xi32>
  %c1_67 = arith.constant 1 : index
  %dim_68 = tensor.dim %cast_56, %c1_67 : tensor<?x?xi32>
  %c0_69 = arith.constant 0 : index
  %dim_70 = tensor.dim %cast_56, %c0_69 : tensor<?x?xi32>
  %c1_71 = arith.constant 1 : index
  %dim_72 = tensor.dim %cast_56, %c1_71 : tensor<?x?xi32>
  %c0_73 = arith.constant 0 : index
  %dim_74 = tensor.dim %cast_56, %c0_73 : tensor<?x?xi32>
  %c1_75 = arith.constant 1 : index
  %dim_76 = tensor.dim %cast_56, %c1_75 : tensor<?x?xi32>
  %c0_77 = arith.constant 0 : index
  %dim_78 = tensor.dim %cast_56, %c0_77 : tensor<?x?xi32>
  %c1_79 = arith.constant 1 : index
  %dim_80 = tensor.dim %cast_56, %c1_79 : tensor<?x?xi32>
  %c0_81 = arith.constant 0 : index
  %dim_82 = tensor.dim %cast_56, %c0_81 : tensor<?x?xi32>
  %c1_83 = arith.constant 1 : index
  %dim_84 = tensor.dim %cast_56, %c1_83 : tensor<?x?xi32>
  %c0_85 = arith.constant 0 : index
  %dim_86 = tensor.dim %cast_56, %c0_85 : tensor<?x?xi32>
  %c1_87 = arith.constant 1 : index
  %dim_88 = tensor.dim %cast_56, %c1_87 : tensor<?x?xi32>
  %c0_89 = arith.constant 0 : index
  %dim_90 = tensor.dim %cast_56, %c0_89 : tensor<?x?xi32>
  %c1_91 = arith.constant 1 : index
  %dim_92 = tensor.dim %cast_56, %c1_91 : tensor<?x?xi32>
  %c0_93 = arith.constant 0 : index
  %dim_94 = tensor.dim %cast_56, %c0_93 : tensor<?x?xi32>
  %c1_95 = arith.constant 1 : index
  %dim_96 = tensor.dim %cast_56, %c1_95 : tensor<?x?xi32>
  %c0_97 = arith.constant 0 : index
  %dim_98 = tensor.dim %cast_56, %c0_97 : tensor<?x?xi32>
  %c1_99 = arith.constant 1 : index
  %dim_100 = tensor.dim %cast_56, %c1_99 : tensor<?x?xi32>
  %c0_101 = arith.constant 0 : index
  %dim_102 = tensor.dim %cast_56, %c0_101 : tensor<?x?xi32>
  %c1_103 = arith.constant 1 : index
  %dim_104 = tensor.dim %cast_56, %c1_103 : tensor<?x?xi32>
  %c0_105 = arith.constant 0 : index
  %dim_106 = tensor.dim %cast_56, %c0_105 : tensor<?x?xi32>
  %c1_107 = arith.constant 1 : index
  %dim_108 = tensor.dim %cast_56, %c1_107 : tensor<?x?xi32>
  %c0_109 = arith.constant 0 : index
  %dim_110 = tensor.dim %cast_56, %c0_109 : tensor<?x?xi32>
  %c1_111 = arith.constant 1 : index
  %dim_112 = tensor.dim %cast_56, %c1_111 : tensor<?x?xi32>
  %c0_113 = arith.constant 0 : index
  %dim_114 = tensor.dim %cast_56, %c0_113 : tensor<?x?xi32>
  %c1_115 = arith.constant 1 : index
  %dim_116 = tensor.dim %cast_56, %c1_115 : tensor<?x?xi32>
  %c0_117 = arith.constant 0 : index
  %dim_118 = tensor.dim %cast_56, %c0_117 : tensor<?x?xi32>
  %c1_119 = arith.constant 1 : index
  %dim_120 = tensor.dim %cast_56, %c1_119 : tensor<?x?xi32>
  %c0_121 = arith.constant 0 : index
  %dim_122 = tensor.dim %cast_56, %c0_121 : tensor<?x?xi32>
  %c1_123 = arith.constant 1 : index
  %dim_124 = tensor.dim %cast_56, %c1_123 : tensor<?x?xi32>
  %c0_125 = arith.constant 0 : index
  %dim_126 = tensor.dim %cast_56, %c0_125 : tensor<?x?xi32>
  %c1_127 = arith.constant 1 : index
  %dim_128 = tensor.dim %cast_56, %c1_127 : tensor<?x?xi32>
  %c0_129 = arith.constant 0 : index
  %dim_130 = tensor.dim %cast_56, %c0_129 : tensor<?x?xi32>
  %c1_131 = arith.constant 1 : index
  %dim_132 = tensor.dim %cast_56, %c1_131 : tensor<?x?xi32>
  %c0_133 = arith.constant 0 : index
  %dim_134 = tensor.dim %cast_56, %c0_133 : tensor<?x?xi32>
  %c1_135 = arith.constant 1 : index
  %dim_136 = tensor.dim %cast_56, %c1_135 : tensor<?x?xi32>
  %c0_137 = arith.constant 0 : index
  %dim_138 = tensor.dim %cast_56, %c0_137 : tensor<?x?xi32>
  %c1_139 = arith.constant 1 : index
  %dim_140 = tensor.dim %cast_56, %c1_139 : tensor<?x?xi32>
  %c0_141 = arith.constant 0 : index
  %dim_142 = tensor.dim %cast_56, %c0_141 : tensor<?x?xi32>
  %c1_143 = arith.constant 1 : index
  %dim_144 = tensor.dim %cast_56, %c1_143 : tensor<?x?xi32>
  %c0_145 = arith.constant 0 : index
  %dim_146 = tensor.dim %cast_56, %c0_145 : tensor<?x?xi32>
  %c1_147 = arith.constant 1 : index
  %dim_148 = tensor.dim %cast_56, %c1_147 : tensor<?x?xi32>
  %c0_149 = arith.constant 0 : index
  %dim_150 = tensor.dim %cast_56, %c0_149 : tensor<?x?xi32>
  %c1_151 = arith.constant 1 : index
  %dim_152 = tensor.dim %cast_56, %c1_151 : tensor<?x?xi32>
  %c0_153 = arith.constant 0 : index
  %dim_154 = tensor.dim %cast_56, %c0_153 : tensor<?x?xi32>
  %c1_155 = arith.constant 1 : index
  %dim_156 = tensor.dim %cast_56, %c1_155 : tensor<?x?xi32>
  %c0_157 = arith.constant 0 : index
  %dim_158 = tensor.dim %cast_56, %c0_157 : tensor<?x?xi32>
  %c1_159 = arith.constant 1 : index
  %dim_160 = tensor.dim %cast_56, %c1_159 : tensor<?x?xi32>
  %c0_161 = arith.constant 0 : index
  %dim_162 = tensor.dim %cast_56, %c0_161 : tensor<?x?xi32>
  %c1_163 = arith.constant 1 : index
  %dim_164 = tensor.dim %cast_56, %c1_163 : tensor<?x?xi32>
  %c0_165 = arith.constant 0 : index
  %dim_166 = tensor.dim %cast_56, %c0_165 : tensor<?x?xi32>
  %c1_167 = arith.constant 1 : index
  %dim_168 = tensor.dim %cast_56, %c1_167 : tensor<?x?xi32>
  %c0_169 = arith.constant 0 : index
  %dim_170 = tensor.dim %cast_56, %c0_169 : tensor<?x?xi32>
  %c1_171 = arith.constant 1 : index
  %dim_172 = tensor.dim %cast_56, %c1_171 : tensor<?x?xi32>
  %c0_173 = arith.constant 0 : index
  %dim_174 = tensor.dim %cast_56, %c0_173 : tensor<?x?xi32>
  %c1_175 = arith.constant 1 : index
  %dim_176 = tensor.dim %cast_56, %c1_175 : tensor<?x?xi32>
  %c0_177 = arith.constant 0 : index
  %dim_178 = tensor.dim %cast_56, %c0_177 : tensor<?x?xi32>
  %c1_179 = arith.constant 1 : index
  %dim_180 = tensor.dim %cast_56, %c1_179 : tensor<?x?xi32>
  %c0_181 = arith.constant 0 : index
  %dim_182 = tensor.dim %cast_56, %c0_181 : tensor<?x?xi32>
  %c1_183 = arith.constant 1 : index
  %dim_184 = tensor.dim %cast_56, %c1_183 : tensor<?x?xi32>
  %c0_185 = arith.constant 0 : index
  %dim_186 = tensor.dim %cast_56, %c0_185 : tensor<?x?xi32>
  %c1_187 = arith.constant 1 : index
  %dim_188 = tensor.dim %cast_56, %c1_187 : tensor<?x?xi32>
  %c0_189 = arith.constant 0 : index
  %dim_190 = tensor.dim %cast_56, %c0_189 : tensor<?x?xi32>
  %c1_191 = arith.constant 1 : index
  %dim_192 = tensor.dim %cast_56, %c1_191 : tensor<?x?xi32>
  %c0_193 = arith.constant 0 : index
  %dim_194 = tensor.dim %cast_56, %c0_193 : tensor<?x?xi32>
  %c1_195 = arith.constant 1 : index
  %dim_196 = tensor.dim %cast_56, %c1_195 : tensor<?x?xi32>
  %c0_197 = arith.constant 0 : index
  %dim_198 = tensor.dim %cast_56, %c0_197 : tensor<?x?xi32>
  %c1_199 = arith.constant 1 : index
  %dim_200 = tensor.dim %cast_56, %c1_199 : tensor<?x?xi32>
  %c0_201 = arith.constant 0 : index
  %dim_202 = tensor.dim %cast_56, %c0_201 : tensor<?x?xi32>
  %c1_203 = arith.constant 1 : index
  %dim_204 = tensor.dim %cast_56, %c1_203 : tensor<?x?xi32>
  %c0_205 = arith.constant 0 : index
  %dim_206 = tensor.dim %cast_56, %c0_205 : tensor<?x?xi32>
  %c1_207 = arith.constant 1 : index
  %dim_208 = tensor.dim %cast_56, %c1_207 : tensor<?x?xi32>
  %c0_209 = arith.constant 0 : index
  %dim_210 = tensor.dim %cast_56, %c0_209 : tensor<?x?xi32>
  %c1_211 = arith.constant 1 : index
  %dim_212 = tensor.dim %cast_56, %c1_211 : tensor<?x?xi32>
  %c0_213 = arith.constant 0 : index
  %dim_214 = tensor.dim %cast_56, %c0_213 : tensor<?x?xi32>
  %c1_215 = arith.constant 1 : index
  %dim_216 = tensor.dim %cast_56, %c1_215 : tensor<?x?xi32>
  %c0_217 = arith.constant 0 : index
  %dim_218 = tensor.dim %cast_56, %c0_217 : tensor<?x?xi32>
  %c1_219 = arith.constant 1 : index
  %dim_220 = tensor.dim %cast_56, %c1_219 : tensor<?x?xi32>
  %c0_221 = arith.constant 0 : index
  %dim_222 = tensor.dim %cast_56, %c0_221 : tensor<?x?xi32>
  %c1_223 = arith.constant 1 : index
  %dim_224 = tensor.dim %cast_56, %c1_223 : tensor<?x?xi32>
  %c0_225 = arith.constant 0 : index
  %dim_226 = tensor.dim %cast_56, %c0_225 : tensor<?x?xi32>
  %c1_227 = arith.constant 1 : index
  %dim_228 = tensor.dim %cast_56, %c1_227 : tensor<?x?xi32>
  %c0_229 = arith.constant 0 : index
  %dim_230 = tensor.dim %cast_56, %c0_229 : tensor<?x?xi32>
  %c1_231 = arith.constant 1 : index
  %dim_232 = tensor.dim %cast_56, %c1_231 : tensor<?x?xi32>
  %c0_233 = arith.constant 0 : index
  %dim_234 = tensor.dim %cast_56, %c0_233 : tensor<?x?xi32>
  %c1_235 = arith.constant 1 : index
  %dim_236 = tensor.dim %cast_56, %c1_235 : tensor<?x?xi32>
  %c0_237 = arith.constant 0 : index
  %dim_238 = tensor.dim %cast_56, %c0_237 : tensor<?x?xi32>
  %c1_239 = arith.constant 1 : index
  %dim_240 = tensor.dim %cast_56, %c1_239 : tensor<?x?xi32>
  %c0_241 = arith.constant 0 : index
  %dim_242 = tensor.dim %cast_56, %c0_241 : tensor<?x?xi32>
  %c1_243 = arith.constant 1 : index
  %dim_244 = tensor.dim %cast_56, %c1_243 : tensor<?x?xi32>
  %c0_245 = arith.constant 0 : index
  %dim_246 = tensor.dim %cast_56, %c0_245 : tensor<?x?xi32>
  %c1_247 = arith.constant 1 : index
  %dim_248 = tensor.dim %cast_56, %c1_247 : tensor<?x?xi32>
  %c0_249 = arith.constant 0 : index
  %dim_250 = tensor.dim %cast_56, %c0_249 : tensor<?x?xi32>
  %c1_251 = arith.constant 1 : index
  %dim_252 = tensor.dim %cast_56, %c1_251 : tensor<?x?xi32>
  %c0_253 = arith.constant 0 : index
  %dim_254 = tensor.dim %cast_56, %c0_253 : tensor<?x?xi32>
  %c1_255 = arith.constant 1 : index
  %dim_256 = tensor.dim %cast_56, %c1_255 : tensor<?x?xi32>
  %c0_257 = arith.constant 0 : index
  %dim_258 = tensor.dim %cast_56, %c0_257 : tensor<?x?xi32>
  %c1_259 = arith.constant 1 : index
  %dim_260 = tensor.dim %cast_56, %c1_259 : tensor<?x?xi32>
  %c0_261 = arith.constant 0 : index
  %dim_262 = tensor.dim %cast_56, %c0_261 : tensor<?x?xi32>
  %c1_263 = arith.constant 1 : index
  %dim_264 = tensor.dim %cast_56, %c1_263 : tensor<?x?xi32>
  %c0_265 = arith.constant 0 : index
  %dim_266 = tensor.dim %cast_56, %c0_265 : tensor<?x?xi32>
  %c1_267 = arith.constant 1 : index
  %dim_268 = tensor.dim %cast_56, %c1_267 : tensor<?x?xi32>
  %c0_269 = arith.constant 0 : index
  %dim_270 = tensor.dim %cast_56, %c0_269 : tensor<?x?xi32>
  %c1_271 = arith.constant 1 : index
  %dim_272 = tensor.dim %cast_56, %c1_271 : tensor<?x?xi32>
  %c0_273 = arith.constant 0 : index
  %dim_274 = tensor.dim %cast_56, %c0_273 : tensor<?x?xi32>
  %c1_275 = arith.constant 1 : index
  %dim_276 = tensor.dim %cast_56, %c1_275 : tensor<?x?xi32>
  %c0_277 = arith.constant 0 : index
  %dim_278 = tensor.dim %cast_56, %c0_277 : tensor<?x?xi32>
  %c1_279 = arith.constant 1 : index
  %dim_280 = tensor.dim %cast_56, %c1_279 : tensor<?x?xi32>
  %c0_281 = arith.constant 0 : index
  %dim_282 = tensor.dim %cast_56, %c0_281 : tensor<?x?xi32>
  %c1_283 = arith.constant 1 : index
  %dim_284 = tensor.dim %cast_56, %c1_283 : tensor<?x?xi32>
  %c0_285 = arith.constant 0 : index
  %dim_286 = tensor.dim %cast_56, %c0_285 : tensor<?x?xi32>
  %c1_287 = arith.constant 1 : index
  %dim_288 = tensor.dim %cast_56, %c1_287 : tensor<?x?xi32>
  %c0_289 = arith.constant 0 : index
  %dim_290 = tensor.dim %cast_56, %c0_289 : tensor<?x?xi32>
  %c1_291 = arith.constant 1 : index
  %dim_292 = tensor.dim %cast_56, %c1_291 : tensor<?x?xi32>
  %c0_293 = arith.constant 0 : index
  %dim_294 = tensor.dim %cast_56, %c0_293 : tensor<?x?xi32>
  %c1_295 = arith.constant 1 : index
  %dim_296 = tensor.dim %cast_56, %c1_295 : tensor<?x?xi32>
  %c0_297 = arith.constant 0 : index
  %dim_298 = tensor.dim %cast_56, %c0_297 : tensor<?x?xi32>
  %c1_299 = arith.constant 1 : index
  %dim_300 = tensor.dim %cast_56, %c1_299 : tensor<?x?xi32>
  %c0_301 = arith.constant 0 : index
  %dim_302 = tensor.dim %cast_56, %c0_301 : tensor<?x?xi32>
  %c1_303 = arith.constant 1 : index
  %dim_304 = tensor.dim %cast_56, %c1_303 : tensor<?x?xi32>
  %c0_305 = arith.constant 0 : index
  %dim_306 = tensor.dim %cast_56, %c0_305 : tensor<?x?xi32>
  %c1_307 = arith.constant 1 : index
  %dim_308 = tensor.dim %cast_56, %c1_307 : tensor<?x?xi32>
  %c0_309 = arith.constant 0 : index
  %dim_310 = tensor.dim %cast_56, %c0_309 : tensor<?x?xi32>
  %c1_311 = arith.constant 1 : index
  %dim_312 = tensor.dim %cast_56, %c1_311 : tensor<?x?xi32>
  %c0_313 = arith.constant 0 : index
  %dim_314 = tensor.dim %cast_56, %c0_313 : tensor<?x?xi32>
  %c1_315 = arith.constant 1 : index
  %dim_316 = tensor.dim %cast_56, %c1_315 : tensor<?x?xi32>
  %c0_317 = arith.constant 0 : index
  %dim_318 = tensor.dim %cast_56, %c0_317 : tensor<?x?xi32>
  %c1_319 = arith.constant 1 : index
  %dim_320 = tensor.dim %cast_56, %c1_319 : tensor<?x?xi32>
  %c0_321 = arith.constant 0 : index
  %dim_322 = tensor.dim %cast_56, %c0_321 : tensor<?x?xi32>
  %c1_323 = arith.constant 1 : index
  %dim_324 = tensor.dim %cast_56, %c1_323 : tensor<?x?xi32>
  %c0_325 = arith.constant 0 : index
  %dim_326 = tensor.dim %cast_56, %c0_325 : tensor<?x?xi32>
  %c1_327 = arith.constant 1 : index
  %dim_328 = tensor.dim %cast_56, %c1_327 : tensor<?x?xi32>
  %c0_329 = arith.constant 0 : index
  %dim_330 = tensor.dim %cast_56, %c0_329 : tensor<?x?xi32>
  %c1_331 = arith.constant 1 : index
  %dim_332 = tensor.dim %cast_56, %c1_331 : tensor<?x?xi32>
  %c0_333 = arith.constant 0 : index
  %dim_334 = tensor.dim %cast_56, %c0_333 : tensor<?x?xi32>
  %c1_335 = arith.constant 1 : index
  %dim_336 = tensor.dim %cast_56, %c1_335 : tensor<?x?xi32>
  %c0_337 = arith.constant 0 : index
  %dim_338 = tensor.dim %cast_56, %c0_337 : tensor<?x?xi32>
  %c1_339 = arith.constant 1 : index
  %dim_340 = tensor.dim %cast_56, %c1_339 : tensor<?x?xi32>
  %c0_341 = arith.constant 0 : index
  %dim_342 = tensor.dim %cast_56, %c0_341 : tensor<?x?xi32>
  %c1_343 = arith.constant 1 : index
  %dim_344 = tensor.dim %cast_56, %c1_343 : tensor<?x?xi32>
  %c0_345 = arith.constant 0 : index
  %dim_346 = tensor.dim %cast_56, %c0_345 : tensor<?x?xi32>
  %c1_347 = arith.constant 1 : index
  %dim_348 = tensor.dim %cast_56, %c1_347 : tensor<?x?xi32>
  %c0_349 = arith.constant 0 : index
  %dim_350 = tensor.dim %cast_56, %c0_349 : tensor<?x?xi32>
  %c1_351 = arith.constant 1 : index
  %dim_352 = tensor.dim %cast_56, %c1_351 : tensor<?x?xi32>
  %c0_353 = arith.constant 0 : index
  %dim_354 = tensor.dim %cast_56, %c0_353 : tensor<?x?xi32>
  %c1_355 = arith.constant 1 : index
  %dim_356 = tensor.dim %cast_56, %c1_355 : tensor<?x?xi32>
  %c0_357 = arith.constant 0 : index
  %dim_358 = tensor.dim %cast_56, %c0_357 : tensor<?x?xi32>
  %c1_359 = arith.constant 1 : index
  %dim_360 = tensor.dim %cast_56, %c1_359 : tensor<?x?xi32>
  %c0_361 = arith.constant 0 : index
  %dim_362 = tensor.dim %cast_56, %c0_361 : tensor<?x?xi32>
  %c1_363 = arith.constant 1 : index
  %dim_364 = tensor.dim %cast_56, %c1_363 : tensor<?x?xi32>
  %c0_365 = arith.constant 0 : index
  %dim_366 = tensor.dim %cast_56, %c0_365 : tensor<?x?xi32>
  %c1_367 = arith.constant 1 : index
  %dim_368 = tensor.dim %cast_56, %c1_367 : tensor<?x?xi32>
  %c0_369 = arith.constant 0 : index
  %dim_370 = tensor.dim %cast_56, %c0_369 : tensor<?x?xi32>
  %c1_371 = arith.constant 1 : index
  %dim_372 = tensor.dim %cast_56, %c1_371 : tensor<?x?xi32>
  %c0_373 = arith.constant 0 : index
  %dim_374 = tensor.dim %cast_56, %c0_373 : tensor<?x?xi32>
  %c1_375 = arith.constant 1 : index
  %dim_376 = tensor.dim %cast_56, %c1_375 : tensor<?x?xi32>
  %c0_377 = arith.constant 0 : index
  %dim_378 = tensor.dim %cast_56, %c0_377 : tensor<?x?xi32>
  %c1_379 = arith.constant 1 : index
  %dim_380 = tensor.dim %cast_56, %c1_379 : tensor<?x?xi32>
  %c0_381 = arith.constant 0 : index
  %dim_382 = tensor.dim %cast_56, %c0_381 : tensor<?x?xi32>
  %c1_383 = arith.constant 1 : index
  %dim_384 = tensor.dim %cast_56, %c1_383 : tensor<?x?xi32>
  %c0_385 = arith.constant 0 : index
  %dim_386 = tensor.dim %cast_56, %c0_385 : tensor<?x?xi32>
  %c1_387 = arith.constant 1 : index
  %dim_388 = tensor.dim %cast_56, %c1_387 : tensor<?x?xi32>
  %c0_389 = arith.constant 0 : index
  %dim_390 = tensor.dim %cast_56, %c0_389 : tensor<?x?xi32>
  %c1_391 = arith.constant 1 : index
  %dim_392 = tensor.dim %cast_56, %c1_391 : tensor<?x?xi32>
  %c0_393 = arith.constant 0 : index
  %dim_394 = tensor.dim %cast_56, %c0_393 : tensor<?x?xi32>
  %c1_395 = arith.constant 1 : index
  %dim_396 = tensor.dim %cast_56, %c1_395 : tensor<?x?xi32>
  %c0_397 = arith.constant 0 : index
  %dim_398 = tensor.dim %cast_56, %c0_397 : tensor<?x?xi32>
  %c1_399 = arith.constant 1 : index
  %dim_400 = tensor.dim %cast_56, %c1_399 : tensor<?x?xi32>
  %c0_401 = arith.constant 0 : index
  %dim_402 = tensor.dim %cast_56, %c0_401 : tensor<?x?xi32>
  %c1_403 = arith.constant 1 : index
  %dim_404 = tensor.dim %cast_56, %c1_403 : tensor<?x?xi32>
  %c0_405 = arith.constant 0 : index
  %dim_406 = tensor.dim %cast_56, %c0_405 : tensor<?x?xi32>
  %c1_407 = arith.constant 1 : index
  %dim_408 = tensor.dim %cast_56, %c1_407 : tensor<?x?xi32>
  %c0_409 = arith.constant 0 : index
  %dim_410 = tensor.dim %cast_56, %c0_409 : tensor<?x?xi32>
  %c1_411 = arith.constant 1 : index
  %dim_412 = tensor.dim %cast_56, %c1_411 : tensor<?x?xi32>
  %c0_413 = arith.constant 0 : index
  %dim_414 = tensor.dim %cast_56, %c0_413 : tensor<?x?xi32>
  %c1_415 = arith.constant 1 : index
  %dim_416 = tensor.dim %cast_56, %c1_415 : tensor<?x?xi32>
  %c0_417 = arith.constant 0 : index
  %dim_418 = tensor.dim %cast_56, %c0_417 : tensor<?x?xi32>
  %c1_419 = arith.constant 1 : index
  %dim_420 = tensor.dim %cast_56, %c1_419 : tensor<?x?xi32>
  %c0_421 = arith.constant 0 : index
  %dim_422 = tensor.dim %cast_56, %c0_421 : tensor<?x?xi32>
  %c1_423 = arith.constant 1 : index
  %dim_424 = tensor.dim %cast_56, %c1_423 : tensor<?x?xi32>
  %c0_425 = arith.constant 0 : index
  %dim_426 = tensor.dim %cast_56, %c0_425 : tensor<?x?xi32>
  %c1_427 = arith.constant 1 : index
  %dim_428 = tensor.dim %cast_56, %c1_427 : tensor<?x?xi32>
  %c0_429 = arith.constant 0 : index
  %dim_430 = tensor.dim %cast_56, %c0_429 : tensor<?x?xi32>
  %c1_431 = arith.constant 1 : index
  %dim_432 = tensor.dim %cast_56, %c1_431 : tensor<?x?xi32>
  %c0_433 = arith.constant 0 : index
  %dim_434 = tensor.dim %cast_56, %c0_433 : tensor<?x?xi32>
  %c1_435 = arith.constant 1 : index
  %dim_436 = tensor.dim %cast_56, %c1_435 : tensor<?x?xi32>
  %c0_437 = arith.constant 0 : index
  %dim_438 = tensor.dim %cast_56, %c0_437 : tensor<?x?xi32>
  %c1_439 = arith.constant 1 : index
  %dim_440 = tensor.dim %cast_56, %c1_439 : tensor<?x?xi32>
  %c0_441 = arith.constant 0 : index
  %dim_442 = tensor.dim %cast_56, %c0_441 : tensor<?x?xi32>
  %c1_443 = arith.constant 1 : index
  %dim_444 = tensor.dim %cast_56, %c1_443 : tensor<?x?xi32>
  %c0_445 = arith.constant 0 : index
  %dim_446 = tensor.dim %cast_56, %c0_445 : tensor<?x?xi32>
  %c1_447 = arith.constant 1 : index
  %dim_448 = tensor.dim %cast_56, %c1_447 : tensor<?x?xi32>
  %c0_449 = arith.constant 0 : index
  %dim_450 = tensor.dim %cast_56, %c0_449 : tensor<?x?xi32>
  %c1_451 = arith.constant 1 : index
  %dim_452 = tensor.dim %cast_56, %c1_451 : tensor<?x?xi32>
  %c0_453 = arith.constant 0 : index
  %dim_454 = tensor.dim %cast_56, %c0_453 : tensor<?x?xi32>
  %c1_455 = arith.constant 1 : index
  %dim_456 = tensor.dim %cast_56, %c1_455 : tensor<?x?xi32>
  %c0_457 = arith.constant 0 : index
  %dim_458 = tensor.dim %cast_56, %c0_457 : tensor<?x?xi32>
  %c1_459 = arith.constant 1 : index
  %dim_460 = tensor.dim %cast_56, %c1_459 : tensor<?x?xi32>
  %c0_461 = arith.constant 0 : index
  %dim_462 = tensor.dim %cast_56, %c0_461 : tensor<?x?xi32>
  %c1_463 = arith.constant 1 : index
  %dim_464 = tensor.dim %cast_56, %c1_463 : tensor<?x?xi32>
  %c0_465 = arith.constant 0 : index
  %dim_466 = tensor.dim %cast_56, %c0_465 : tensor<?x?xi32>
  %c1_467 = arith.constant 1 : index
  %dim_468 = tensor.dim %cast_56, %c1_467 : tensor<?x?xi32>
  %c0_469 = arith.constant 0 : index
  %dim_470 = tensor.dim %cast_56, %c0_469 : tensor<?x?xi32>
  %c1_471 = arith.constant 1 : index
  %dim_472 = tensor.dim %cast_56, %c1_471 : tensor<?x?xi32>
  %c0_473 = arith.constant 0 : index
  %dim_474 = tensor.dim %cast_56, %c0_473 : tensor<?x?xi32>
  %c1_475 = arith.constant 1 : index
  %dim_476 = tensor.dim %cast_56, %c1_475 : tensor<?x?xi32>
  %c0_477 = arith.constant 0 : index
  %dim_478 = tensor.dim %cast_56, %c0_477 : tensor<?x?xi32>
  %c1_479 = arith.constant 1 : index
  %dim_480 = tensor.dim %cast_56, %c1_479 : tensor<?x?xi32>
  %c0_481 = arith.constant 0 : index
  %dim_482 = tensor.dim %cast_56, %c0_481 : tensor<?x?xi32>
  %c1_483 = arith.constant 1 : index
  %dim_484 = tensor.dim %cast_56, %c1_483 : tensor<?x?xi32>
  %c0_485 = arith.constant 0 : index
  %dim_486 = tensor.dim %cast_56, %c0_485 : tensor<?x?xi32>
  %c1_487 = arith.constant 1 : index
  %dim_488 = tensor.dim %cast_56, %c1_487 : tensor<?x?xi32>
  %c0_489 = arith.constant 0 : index
  %dim_490 = tensor.dim %cast_56, %c0_489 : tensor<?x?xi32>
  %c1_491 = arith.constant 1 : index
  %dim_492 = tensor.dim %cast_56, %c1_491 : tensor<?x?xi32>
  %c0_493 = arith.constant 0 : index
  %dim_494 = tensor.dim %cast_56, %c0_493 : tensor<?x?xi32>
  %c1_495 = arith.constant 1 : index
  %dim_496 = tensor.dim %cast_56, %c1_495 : tensor<?x?xi32>
  %c0_497 = arith.constant 0 : index
  %dim_498 = tensor.dim %cast_56, %c0_497 : tensor<?x?xi32>
  %c1_499 = arith.constant 1 : index
  %dim_500 = tensor.dim %cast_56, %c1_499 : tensor<?x?xi32>
  %c0_501 = arith.constant 0 : index
  %dim_502 = tensor.dim %cast_56, %c0_501 : tensor<?x?xi32>
  %c1_503 = arith.constant 1 : index
  %dim_504 = tensor.dim %cast_56, %c1_503 : tensor<?x?xi32>
  %c0_505 = arith.constant 0 : index
  %dim_506 = tensor.dim %cast_56, %c0_505 : tensor<?x?xi32>
  %c1_507 = arith.constant 1 : index
  %dim_508 = tensor.dim %cast_56, %c1_507 : tensor<?x?xi32>
  %c0_509 = arith.constant 0 : index
  %dim_510 = tensor.dim %cast_56, %c0_509 : tensor<?x?xi32>
  %c1_511 = arith.constant 1 : index
  %dim_512 = tensor.dim %cast_56, %c1_511 : tensor<?x?xi32>
  %c0_513 = arith.constant 0 : index
  %dim_514 = tensor.dim %cast_56, %c0_513 : tensor<?x?xi32>
  %c1_515 = arith.constant 1 : index
  %dim_516 = tensor.dim %cast_56, %c1_515 : tensor<?x?xi32>
  %c0_517 = arith.constant 0 : index
  %dim_518 = tensor.dim %cast_56, %c0_517 : tensor<?x?xi32>
  %c1_519 = arith.constant 1 : index
  %dim_520 = tensor.dim %cast_56, %c1_519 : tensor<?x?xi32>
  %c0_521 = arith.constant 0 : index
  %dim_522 = tensor.dim %cast_56, %c0_521 : tensor<?x?xi32>
  %c1_523 = arith.constant 1 : index
  %dim_524 = tensor.dim %cast_56, %c1_523 : tensor<?x?xi32>
  %c0_525 = arith.constant 0 : index
  %dim_526 = tensor.dim %cast_56, %c0_525 : tensor<?x?xi32>
  %c1_527 = arith.constant 1 : index
  %dim_528 = tensor.dim %cast_56, %c1_527 : tensor<?x?xi32>
  %c0_529 = arith.constant 0 : index
  %dim_530 = tensor.dim %cast_56, %c0_529 : tensor<?x?xi32>
  %c1_531 = arith.constant 1 : index
  %dim_532 = tensor.dim %cast_56, %c1_531 : tensor<?x?xi32>
  %c0_533 = arith.constant 0 : index
  %dim_534 = tensor.dim %cast_56, %c0_533 : tensor<?x?xi32>
  %c1_535 = arith.constant 1 : index
  %dim_536 = tensor.dim %cast_56, %c1_535 : tensor<?x?xi32>
  %c0_537 = arith.constant 0 : index
  %dim_538 = tensor.dim %cast_56, %c0_537 : tensor<?x?xi32>
  %c1_539 = arith.constant 1 : index
  %dim_540 = tensor.dim %cast_56, %c1_539 : tensor<?x?xi32>
  %c0_541 = arith.constant 0 : index
  %dim_542 = tensor.dim %cast_56, %c0_541 : tensor<?x?xi32>
  %c1_543 = arith.constant 1 : index
  %dim_544 = tensor.dim %cast_56, %c1_543 : tensor<?x?xi32>
  %c0_545 = arith.constant 0 : index
  %dim_546 = tensor.dim %cast_56, %c0_545 : tensor<?x?xi32>
  %c1_547 = arith.constant 1 : index
  %dim_548 = tensor.dim %cast_56, %c1_547 : tensor<?x?xi32>
  %c0_549 = arith.constant 0 : index
  %dim_550 = tensor.dim %cast_56, %c0_549 : tensor<?x?xi32>
  %c1_551 = arith.constant 1 : index
  %dim_552 = tensor.dim %cast_56, %c1_551 : tensor<?x?xi32>
  %c0_553 = arith.constant 0 : index
  %dim_554 = tensor.dim %cast_56, %c0_553 : tensor<?x?xi32>
  %c1_555 = arith.constant 1 : index
  %dim_556 = tensor.dim %cast_56, %c1_555 : tensor<?x?xi32>
  %c0_557 = arith.constant 0 : index
  %dim_558 = tensor.dim %cast_56, %c0_557 : tensor<?x?xi32>
  %c1_559 = arith.constant 1 : index
  %dim_560 = tensor.dim %cast_56, %c1_559 : tensor<?x?xi32>
  %c0_561 = arith.constant 0 : index
  %dim_562 = tensor.dim %cast_56, %c0_561 : tensor<?x?xi32>
  %c1_563 = arith.constant 1 : index
  %dim_564 = tensor.dim %cast_56, %c1_563 : tensor<?x?xi32>
  %c0_565 = arith.constant 0 : index
  %dim_566 = tensor.dim %cast_56, %c0_565 : tensor<?x?xi32>
  %c1_567 = arith.constant 1 : index
  %dim_568 = tensor.dim %cast_56, %c1_567 : tensor<?x?xi32>
  %c0_569 = arith.constant 0 : index
  %dim_570 = tensor.dim %cast_56, %c0_569 : tensor<?x?xi32>
  %c1_571 = arith.constant 1 : index
  %dim_572 = tensor.dim %cast_56, %c1_571 : tensor<?x?xi32>
  %c0_573 = arith.constant 0 : index
  %dim_574 = tensor.dim %cast_56, %c0_573 : tensor<?x?xi32>
  %c1_575 = arith.constant 1 : index
  %dim_576 = tensor.dim %cast_56, %c1_575 : tensor<?x?xi32>
  %c0_577 = arith.constant 0 : index
  %dim_578 = tensor.dim %cast_56, %c0_577 : tensor<?x?xi32>
  %c1_579 = arith.constant 1 : index
  %dim_580 = tensor.dim %cast_56, %c1_579 : tensor<?x?xi32>
  %c0_581 = arith.constant 0 : index
  %dim_582 = tensor.dim %cast_56, %c0_581 : tensor<?x?xi32>
  %c1_583 = arith.constant 1 : index
  %dim_584 = tensor.dim %cast_56, %c1_583 : tensor<?x?xi32>
  %c0_585 = arith.constant 0 : index
  %dim_586 = tensor.dim %cast_56, %c0_585 : tensor<?x?xi32>
  %c1_587 = arith.constant 1 : index
  %dim_588 = tensor.dim %cast_56, %c1_587 : tensor<?x?xi32>
  %c0_589 = arith.constant 0 : index
  %dim_590 = tensor.dim %cast_56, %c0_589 : tensor<?x?xi32>
  %c1_591 = arith.constant 1 : index
  %dim_592 = tensor.dim %cast_56, %c1_591 : tensor<?x?xi32>
  %c0_593 = arith.constant 0 : index
  %dim_594 = tensor.dim %cast_56, %c0_593 : tensor<?x?xi32>
  %c1_595 = arith.constant 1 : index
  %dim_596 = tensor.dim %cast_56, %c1_595 : tensor<?x?xi32>
  %c0_597 = arith.constant 0 : index
  %dim_598 = tensor.dim %cast_56, %c0_597 : tensor<?x?xi32>
  %c1_599 = arith.constant 1 : index
  %dim_600 = tensor.dim %cast_56, %c1_599 : tensor<?x?xi32>
  %c0_601 = arith.constant 0 : index
  %dim_602 = tensor.dim %cast_56, %c0_601 : tensor<?x?xi32>
  %c1_603 = arith.constant 1 : index
  %dim_604 = tensor.dim %cast_56, %c1_603 : tensor<?x?xi32>
  %c0_605 = arith.constant 0 : index
  %dim_606 = tensor.dim %cast_56, %c0_605 : tensor<?x?xi32>
  %c1_607 = arith.constant 1 : index
  %dim_608 = tensor.dim %cast_56, %c1_607 : tensor<?x?xi32>
  %c0_609 = arith.constant 0 : index
  %dim_610 = tensor.dim %cast_56, %c0_609 : tensor<?x?xi32>
  %c1_611 = arith.constant 1 : index
  %dim_612 = tensor.dim %cast_56, %c1_611 : tensor<?x?xi32>
  %c0_613 = arith.constant 0 : index
  %dim_614 = tensor.dim %cast_56, %c0_613 : tensor<?x?xi32>
  %c1_615 = arith.constant 1 : index
  %dim_616 = tensor.dim %cast_56, %c1_615 : tensor<?x?xi32>
  %c0_617 = arith.constant 0 : index
  %dim_618 = tensor.dim %cast_56, %c0_617 : tensor<?x?xi32>
  %c1_619 = arith.constant 1 : index
  %dim_620 = tensor.dim %cast_56, %c1_619 : tensor<?x?xi32>
  %c0_621 = arith.constant 0 : index
  %dim_622 = tensor.dim %cast_56, %c0_621 : tensor<?x?xi32>
  %c1_623 = arith.constant 1 : index
  %dim_624 = tensor.dim %cast_56, %c1_623 : tensor<?x?xi32>
  %c0_625 = arith.constant 0 : index
  %dim_626 = tensor.dim %cast_56, %c0_625 : tensor<?x?xi32>
  %c1_627 = arith.constant 1 : index
  %dim_628 = tensor.dim %cast_56, %c1_627 : tensor<?x?xi32>
  %c0_629 = arith.constant 0 : index
  %dim_630 = tensor.dim %cast_56, %c0_629 : tensor<?x?xi32>
  %c1_631 = arith.constant 1 : index
  %dim_632 = tensor.dim %cast_56, %c1_631 : tensor<?x?xi32>
  %c0_633 = arith.constant 0 : index
  %dim_634 = tensor.dim %cast_56, %c0_633 : tensor<?x?xi32>
  %c1_635 = arith.constant 1 : index
  %dim_636 = tensor.dim %cast_56, %c1_635 : tensor<?x?xi32>
  %c0_637 = arith.constant 0 : index
  %dim_638 = tensor.dim %cast_56, %c0_637 : tensor<?x?xi32>
  %c1_639 = arith.constant 1 : index
  %dim_640 = tensor.dim %cast_56, %c1_639 : tensor<?x?xi32>
  %c0_641 = arith.constant 0 : index
  %dim_642 = tensor.dim %cast_56, %c0_641 : tensor<?x?xi32>
  %c1_643 = arith.constant 1 : index
  %dim_644 = tensor.dim %cast_56, %c1_643 : tensor<?x?xi32>
  %c0_645 = arith.constant 0 : index
  %dim_646 = tensor.dim %cast_56, %c0_645 : tensor<?x?xi32>
  %c1_647 = arith.constant 1 : index
  %dim_648 = tensor.dim %cast_56, %c1_647 : tensor<?x?xi32>
  %c0_649 = arith.constant 0 : index
  %dim_650 = tensor.dim %cast_56, %c0_649 : tensor<?x?xi32>
  %c1_651 = arith.constant 1 : index
  %dim_652 = tensor.dim %cast_56, %c1_651 : tensor<?x?xi32>
  %c0_653 = arith.constant 0 : index
  %dim_654 = tensor.dim %cast_56, %c0_653 : tensor<?x?xi32>
  %c1_655 = arith.constant 1 : index
  %dim_656 = tensor.dim %cast_56, %c1_655 : tensor<?x?xi32>
  %c0_657 = arith.constant 0 : index
  %dim_658 = tensor.dim %cast_56, %c0_657 : tensor<?x?xi32>
  %c1_659 = arith.constant 1 : index
  %dim_660 = tensor.dim %cast_56, %c1_659 : tensor<?x?xi32>
  %c0_661 = arith.constant 0 : index
  %dim_662 = tensor.dim %cast_56, %c0_661 : tensor<?x?xi32>
  %c1_663 = arith.constant 1 : index
  %dim_664 = tensor.dim %cast_56, %c1_663 : tensor<?x?xi32>
  %c0_665 = arith.constant 0 : index
  %dim_666 = tensor.dim %cast_56, %c0_665 : tensor<?x?xi32>
  %c1_667 = arith.constant 1 : index
  %dim_668 = tensor.dim %cast_56, %c1_667 : tensor<?x?xi32>
  %c0_669 = arith.constant 0 : index
  %dim_670 = tensor.dim %cast_56, %c0_669 : tensor<?x?xi32>
  %c1_671 = arith.constant 1 : index
  %dim_672 = tensor.dim %cast_56, %c1_671 : tensor<?x?xi32>
  %c0_673 = arith.constant 0 : index
  %dim_674 = tensor.dim %cast_56, %c0_673 : tensor<?x?xi32>
  %c1_675 = arith.constant 1 : index
  %dim_676 = tensor.dim %cast_56, %c1_675 : tensor<?x?xi32>
  %c0_677 = arith.constant 0 : index
  %dim_678 = tensor.dim %cast_56, %c0_677 : tensor<?x?xi32>
  %c1_679 = arith.constant 1 : index
  %dim_680 = tensor.dim %cast_56, %c1_679 : tensor<?x?xi32>
  %c0_681 = arith.constant 0 : index
  %dim_682 = tensor.dim %cast_56, %c0_681 : tensor<?x?xi32>
  %c1_683 = arith.constant 1 : index
  %dim_684 = tensor.dim %cast_56, %c1_683 : tensor<?x?xi32>
  %c0_685 = arith.constant 0 : index
  %dim_686 = tensor.dim %cast_56, %c0_685 : tensor<?x?xi32>
  %c1_687 = arith.constant 1 : index
  %dim_688 = tensor.dim %cast_56, %c1_687 : tensor<?x?xi32>
  %c0_689 = arith.constant 0 : index
  %dim_690 = tensor.dim %cast_56, %c0_689 : tensor<?x?xi32>
  %c1_691 = arith.constant 1 : index
  %dim_692 = tensor.dim %cast_56, %c1_691 : tensor<?x?xi32>
  %c0_693 = arith.constant 0 : index
  %dim_694 = tensor.dim %cast_56, %c0_693 : tensor<?x?xi32>
  %c1_695 = arith.constant 1 : index
  %dim_696 = tensor.dim %cast_56, %c1_695 : tensor<?x?xi32>
  %c0_697 = arith.constant 0 : index
  %c1_698 = arith.constant 1 : index
  %c0_699 = arith.constant 0 : index
  %c1_700 = arith.constant 1 : index
  %c0_701 = arith.constant 0 : index
  %c1_702 = arith.constant 1 : index
  %c0_703 = arith.constant 0 : index
  %c1_704 = arith.constant 1 : index
  %c0_705 = arith.constant 0 : index
  %c1_706 = arith.constant 1 : index
  %c0_707 = arith.constant 0 : index
  %c1_708 = arith.constant 1 : index
  %c0_709 = arith.constant 0 : index
  %c1_710 = arith.constant 1 : index
  %c0_711 = arith.constant 0 : index
  %c1_712 = arith.constant 1 : index
  %c0_713 = arith.constant 0 : index
  %c1_714 = arith.constant 1 : index
  %c0_715 = arith.constant 0 : index
  %c1_716 = arith.constant 1 : index
  %c0_717 = arith.constant 0 : index
  %c1_718 = arith.constant 1 : index
  %c0_719 = arith.constant 0 : index
  %c1_720 = arith.constant 1 : index
  %c0_721 = arith.constant 0 : index
  %c1_722 = arith.constant 1 : index
  %c0_723 = arith.constant 0 : index
  %c1_724 = arith.constant 1 : index
  %c0_725 = arith.constant 0 : index
  %c1_726 = arith.constant 1 : index
  %c0_727 = arith.constant 0 : index
  %c1_728 = arith.constant 1 : index
  %c0_729 = arith.constant 0 : index
  %c1_730 = arith.constant 1 : index
  %c0_731 = arith.constant 0 : index
  %c1_732 = arith.constant 1 : index
  %c0_733 = arith.constant 0 : index
  %c1_734 = arith.constant 1 : index
  %c0_735 = arith.constant 0 : index
  %c1_736 = arith.constant 1 : index
  %c0_737 = arith.constant 0 : index
  %c1_738 = arith.constant 1 : index
  %c0_739 = arith.constant 0 : index
  %c1_740 = arith.constant 1 : index
  %c0_741 = arith.constant 0 : index
  %c1_742 = arith.constant 1 : index
  %c0_743 = arith.constant 0 : index
  %c1_744 = arith.constant 1 : index
  %c0_745 = arith.constant 0 : index
  %c1_746 = arith.constant 1 : index
  %c0_747 = arith.constant 0 : index
  %c1_748 = arith.constant 1 : index
  %c0_749 = arith.constant 0 : index
  %c1_750 = arith.constant 1 : index
  %c0_751 = arith.constant 0 : index
  %c1_752 = arith.constant 1 : index
  %c0_753 = arith.constant 0 : index
  %c1_754 = arith.constant 1 : index
  %c0_755 = arith.constant 0 : index
  %c1_756 = arith.constant 1 : index
  %c0_757 = arith.constant 0 : index
  %c1_758 = arith.constant 1 : index
  %c0_759 = arith.constant 0 : index
  %c1_760 = arith.constant 1 : index
  %c0_761 = arith.constant 0 : index
  %c1_762 = arith.constant 1 : index
  %c0_763 = arith.constant 0 : index
  %c1_764 = arith.constant 1 : index
  %c0_765 = arith.constant 0 : index
  %c1_766 = arith.constant 1 : index
  %c0_767 = arith.constant 0 : index
  %c1_768 = arith.constant 1 : index
  %c0_769 = arith.constant 0 : index
  %c1_770 = arith.constant 1 : index
  %c0_771 = arith.constant 0 : index
  %c1_772 = arith.constant 1 : index
  %c0_773 = arith.constant 0 : index
  %c1_774 = arith.constant 1 : index
  %c0_775 = arith.constant 0 : index
  %c1_776 = arith.constant 1 : index
  %c0_777 = arith.constant 0 : index
  %c1_778 = arith.constant 1 : index
  %c0_779 = arith.constant 0 : index
  %c1_780 = arith.constant 1 : index
  %c0_781 = arith.constant 0 : index
  %c1_782 = arith.constant 1 : index
  %c0_783 = arith.constant 0 : index
  %c1_784 = arith.constant 1 : index
  %c0_785 = arith.constant 0 : index
  %c1_786 = arith.constant 1 : index
  %c0_787 = arith.constant 0 : index
  %c1_788 = arith.constant 1 : index
  %c0_789 = arith.constant 0 : index
  %c1_790 = arith.constant 1 : index
  %c0_791 = arith.constant 0 : index
  %c1_792 = arith.constant 1 : index
  %c0_793 = arith.constant 0 : index
  %c1_794 = arith.constant 1 : index
  %c0_795 = arith.constant 0 : index
  %c1_796 = arith.constant 1 : index
  %c0_797 = arith.constant 0 : index
  %c1_798 = arith.constant 1 : index
  %c0_799 = arith.constant 0 : index
  %c1_800 = arith.constant 1 : index
  %c0_801 = arith.constant 0 : index
  %c1_802 = arith.constant 1 : index
  %c0_803 = arith.constant 0 : index
  %c1_804 = arith.constant 1 : index
  %c0_805 = arith.constant 0 : index
  %c1_806 = arith.constant 1 : index
  %c0_807 = arith.constant 0 : index
  %c1_808 = arith.constant 1 : index
  %c0_809 = arith.constant 0 : index
  %c1_810 = arith.constant 1 : index
  %c0_811 = arith.constant 0 : index
  %c1_812 = arith.constant 1 : index
  %c0_813 = arith.constant 0 : index
  %c1_814 = arith.constant 1 : index
  %c0_815 = arith.constant 0 : index
  %c1_816 = arith.constant 1 : index
  %c0_817 = arith.constant 0 : index
  %c1_818 = arith.constant 1 : index
  %c0_819 = arith.constant 0 : index
  %c1_820 = arith.constant 1 : index
  %c0_821 = arith.constant 0 : index
  %c1_822 = arith.constant 1 : index
  %c0_823 = arith.constant 0 : index
  %c1_824 = arith.constant 1 : index
  %c0_825 = arith.constant 0 : index
  %c1_826 = arith.constant 1 : index
  %c0_827 = arith.constant 0 : index
  %c1_828 = arith.constant 1 : index
  %c0_829 = arith.constant 0 : index
  %c1_830 = arith.constant 1 : index
  %c0_831 = arith.constant 0 : index
  %c1_832 = arith.constant 1 : index
  %c0_833 = arith.constant 0 : index
  %c1_834 = arith.constant 1 : index
  %c0_835 = arith.constant 0 : index
  %c1_836 = arith.constant 1 : index
  %c0_837 = arith.constant 0 : index
  %c1_838 = arith.constant 1 : index
  %c0_839 = arith.constant 0 : index
  %c1_840 = arith.constant 1 : index
  %c0_841 = arith.constant 0 : index
  %c1_842 = arith.constant 1 : index
  %c0_843 = arith.constant 0 : index
  %c1_844 = arith.constant 1 : index
  %c0_845 = arith.constant 0 : index
  %c1_846 = arith.constant 1 : index
  %c0_847 = arith.constant 0 : index
  %c1_848 = arith.constant 1 : index
  %c0_849 = arith.constant 0 : index
  %c1_850 = arith.constant 1 : index
  %c0_851 = arith.constant 0 : index
  %c1_852 = arith.constant 1 : index
  %c0_853 = arith.constant 0 : index
  %c1_854 = arith.constant 1 : index
  %c0_855 = arith.constant 0 : index
  %c1_856 = arith.constant 1 : index
  %c0_857 = arith.constant 0 : index
  %c1_858 = arith.constant 1 : index
  %c0_859 = arith.constant 0 : index
  %c1_860 = arith.constant 1 : index
  %c0_861 = arith.constant 0 : index
  %c1_862 = arith.constant 1 : index
  %c0_863 = arith.constant 0 : index
  %c1_864 = arith.constant 1 : index
  %c0_865 = arith.constant 0 : index
  %c1_866 = arith.constant 1 : index
  %c0_867 = arith.constant 0 : index
  %c1_868 = arith.constant 1 : index
  %c0_869 = arith.constant 0 : index
  %c1_870 = arith.constant 1 : index
  %c0_871 = arith.constant 0 : index
  %c1_872 = arith.constant 1 : index
  %c0_873 = arith.constant 0 : index
  %c1_874 = arith.constant 1 : index
  %c0_875 = arith.constant 0 : index
  %c1_876 = arith.constant 1 : index
  %c0_877 = arith.constant 0 : index
  %c1_878 = arith.constant 1 : index
  %c0_879 = arith.constant 0 : index
  %c1_880 = arith.constant 1 : index
  %c0_881 = arith.constant 0 : index
  %c1_882 = arith.constant 1 : index
  %c0_883 = arith.constant 0 : index
  %c1_884 = arith.constant 1 : index
  %c0_885 = arith.constant 0 : index
  %c1_886 = arith.constant 1 : index
  %c0_887 = arith.constant 0 : index
  %c1_888 = arith.constant 1 : index
  %c0_889 = arith.constant 0 : index
  %c1_890 = arith.constant 1 : index
  %c0_891 = arith.constant 0 : index
  %c1_892 = arith.constant 1 : index
  %c0_893 = arith.constant 0 : index
  %c1_894 = arith.constant 1 : index
  %c0_895 = arith.constant 0 : index
  %c1_896 = arith.constant 1 : index
  %c0_897 = arith.constant 0 : index
  %c1_898 = arith.constant 1 : index
  %c0_899 = arith.constant 0 : index
  %c1_900 = arith.constant 1 : index
  %c0_901 = arith.constant 0 : index
  %c1_902 = arith.constant 1 : index
  %c0_903 = arith.constant 0 : index
  %c1_904 = arith.constant 1 : index
  %c0_905 = arith.constant 0 : index
  %c1_906 = arith.constant 1 : index
  %c0_907 = arith.constant 0 : index
  %c1_908 = arith.constant 1 : index
  %c0_909 = arith.constant 0 : index
  %c1_910 = arith.constant 1 : index
  %c0_911 = arith.constant 0 : index
  %c1_912 = arith.constant 1 : index
  %c0_913 = arith.constant 0 : index
  %c1_914 = arith.constant 1 : index
  %c0_915 = arith.constant 0 : index
  %c1_916 = arith.constant 1 : index
  %c0_917 = arith.constant 0 : index
  %c1_918 = arith.constant 1 : index
  %c0_919 = arith.constant 0 : index
  %c1_920 = arith.constant 1 : index
  %c0_921 = arith.constant 0 : index
  %c1_922 = arith.constant 1 : index
  %c0_923 = arith.constant 0 : index
  %c1_924 = arith.constant 1 : index
  %c0_925 = arith.constant 0 : index
  %c1_926 = arith.constant 1 : index
  %c0_927 = arith.constant 0 : index
  %c1_928 = arith.constant 1 : index
  %c0_929 = arith.constant 0 : index
  %c1_930 = arith.constant 1 : index
  %c0_931 = arith.constant 0 : index
  %c1_932 = arith.constant 1 : index
  %c0_933 = arith.constant 0 : index
  %c1_934 = arith.constant 1 : index
  %c0_935 = arith.constant 0 : index
  %c1_936 = arith.constant 1 : index
  %c0_937 = arith.constant 0 : index
  %c1_938 = arith.constant 1 : index
  %c0_939 = arith.constant 0 : index
  %c1_940 = arith.constant 1 : index
  %c0_941 = arith.constant 0 : index
  %c1_942 = arith.constant 1 : index
  %c0_943 = arith.constant 0 : index
  %c1_944 = arith.constant 1 : index
  %c0_945 = arith.constant 0 : index
  %c1_946 = arith.constant 1 : index
  %c0_947 = arith.constant 0 : index
  %c1_948 = arith.constant 1 : index
  %c0_949 = arith.constant 0 : index
  %c1_950 = arith.constant 1 : index
  %c0_951 = arith.constant 0 : index
  %c1_952 = arith.constant 1 : index
  %c0_953 = arith.constant 0 : index
  %c1_954 = arith.constant 1 : index
  %c0_955 = arith.constant 0 : index
  %c1_956 = arith.constant 1 : index
  %c0_957 = arith.constant 0 : index
  %c1_958 = arith.constant 1 : index
  %c0_959 = arith.constant 0 : index
  %c1_960 = arith.constant 1 : index
  %c0_961 = arith.constant 0 : index
  %c1_962 = arith.constant 1 : index
  %c0_963 = arith.constant 0 : index
  %c1_964 = arith.constant 1 : index
  %c0_965 = arith.constant 0 : index
  %c1_966 = arith.constant 1 : index
  %c0_967 = arith.constant 0 : index
  %c1_968 = arith.constant 1 : index
  %c0_969 = arith.constant 0 : index
  %c1_970 = arith.constant 1 : index
  %c0_971 = arith.constant 0 : index
  %c1_972 = arith.constant 1 : index
  %c0_973 = arith.constant 0 : index
  %c1_974 = arith.constant 1 : index
  %c0_975 = arith.constant 0 : index
  %c1_976 = arith.constant 1 : index
  %c0_977 = arith.constant 0 : index
  %c1_978 = arith.constant 1 : index
  %c0_979 = arith.constant 0 : index
  %c1_980 = arith.constant 1 : index
  %c0_981 = arith.constant 0 : index
  %c1_982 = arith.constant 1 : index
  %c0_983 = arith.constant 0 : index
  %c1_984 = arith.constant 1 : index
  %c0_985 = arith.constant 0 : index
  %c1_986 = arith.constant 1 : index
  %c0_987 = arith.constant 0 : index
  %c1_988 = arith.constant 1 : index
  %c0_989 = arith.constant 0 : index
  %c1_990 = arith.constant 1 : index
  %c0_991 = arith.constant 0 : index
  %c1_992 = arith.constant 1 : index
  %c0_993 = arith.constant 0 : index
  %c1_994 = arith.constant 1 : index
  %c0_995 = arith.constant 0 : index
  %c1_996 = arith.constant 1 : index
  %c0_997 = arith.constant 0 : index
  %c1_998 = arith.constant 1 : index
  %c0_999 = arith.constant 0 : index
  %c1_1000 = arith.constant 1 : index
  %c0_1001 = arith.constant 0 : index
  %c1_1002 = arith.constant 1 : index
  %c0_1003 = arith.constant 0 : index
  %c1_1004 = arith.constant 1 : index
  %c0_1005 = arith.constant 0 : index
  %c1_1006 = arith.constant 1 : index
  %c0_1007 = arith.constant 0 : index
  %c1_1008 = arith.constant 1 : index
  %c0_1009 = arith.constant 0 : index
  %c1_1010 = arith.constant 1 : index
  %c0_1011 = arith.constant 0 : index
  %c1_1012 = arith.constant 1 : index
  %9 = flow.dispatch.region -> (tensor<?x?xi32, #iree_linalg_ext.encoding<role =  RESULT, element_types = [i8, i8, i32], original_type = tensor<4x4xi32>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>{%dim_274, %dim_480}) {
    %13 = iree_linalg_ext.set_encoding %cast_56 : tensor<?x?xi32> -> tensor<?x?xi32, #iree_linalg_ext.encoding<role =  RESULT, element_types = [i8, i8, i32], original_type = tensor<4x4xi32>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>
    flow.return %13 : tensor<?x?xi32, #iree_linalg_ext.encoding<role =  RESULT, element_types = [i8, i8, i32], original_type = tensor<4x4xi32>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>
  }
  %10 = flow.dispatch.region -> (tensor<?x?xi32, #iree_linalg_ext.encoding<role =  RESULT, element_types = [i8, i8, i32], original_type = tensor<4x4xi32>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>{%dim_330, %dim_360}) {
    %13 = linalg.matmul ins(%5, %6 : tensor<?x?xi8, #iree_linalg_ext.encoding<role =  LHS, element_types = [i8, i8, i32], original_type = tensor<4x4xi8>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>, tensor<?x?xi8, #iree_linalg_ext.encoding<role =  RHS, element_types = [i8, i8, i32], original_type = tensor<4x4xi8>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>) outs(%9 : tensor<?x?xi32, #iree_linalg_ext.encoding<role =  RESULT, element_types = [i8, i8, i32], original_type = tensor<4x4xi32>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>) -> tensor<?x?xi32, #iree_linalg_ext.encoding<role =  RESULT, element_types = [i8, i8, i32], original_type = tensor<4x4xi32>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>
    flow.return %13 : tensor<?x?xi32, #iree_linalg_ext.encoding<role =  RESULT, element_types = [i8, i8, i32], original_type = tensor<4x4xi32>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>
  }
  %11 = flow.dispatch.region -> (tensor<4x4xi32>) {
    %13 = iree_linalg_ext.unset_encoding %10 : tensor<?x?xi32, #iree_linalg_ext.encoding<role =  RESULT, element_types = [i8, i8, i32], original_type = tensor<4x4xi32>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>> -> tensor<?x?xi32>
    %extracted_slice = tensor.extract_slice %13[0, 0] [4, 4] [1, 1] : tensor<?x?xi32> to tensor<4x4xi32>
    flow.return %extracted_slice : tensor<4x4xi32>
  }
  %12 = hal.tensor.export %11 "output0" : tensor<4x4xi32> -> !hal.buffer_view
  util.return %12 : !hal.buffer_view
}
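To put a number on the blow-up in a dump like the one above, and to compare it with the dump after the createOrFold change in the next comment, a small tally of op names is handy. The sketch below is a hypothetical helper (countOps is not part of IREE or MLIR). It is a rough line-based text heuristic that assumes the pretty-printed form with one op per line, so it ignores generic-form ops ("tensor.dim"(...)) and does not parse nested regions specially.

```cpp
#include <cassert>
#include <cstddef>
#include <sstream>
#include <string>

// Hypothetical helper: counts lines in a pretty-printed MLIR dump whose op
// name equals `opName`. For "%x = op.name ..." lines the op name is the token
// after "="; otherwise it is the first token on the line (e.g. flow.return).
std::size_t countOps(const std::string &dump, const std::string &opName) {
  assert(!opName.empty() && "opName must be non-empty");
  std::istringstream in(dump);
  std::string line;
  std::size_t count = 0;
  while (std::getline(in, line)) {
    std::istringstream tokens(line);
    std::string tok;
    if (!(tokens >> tok)) continue;  // skip blank lines
    if (tok.front() == '%') {
      // Skip result names ("%a", "%b:2", "%c,") up to the "=".
      while (tok != "=" && (tokens >> tok)) {
      }
      if (tok != "=" || !(tokens >> tok)) continue;  // no op name found
    }
    if (tok == opName) ++count;
  }
  return count;
}
```

Calling countOps(dump, "tensor.dim") on the dumps before and after the getDimValue change gives a quick before/after comparison without reading through the IR by hand.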
benvanik commented 7 months ago

I changed the LinalgExt getDimValue helper to use createOrFold (as the tensor dialect and others do), and it now at least spits out reasonable IR:

// -----// IR Dump After FormDispatchRegions (iree-flow-form-dispatch-regions) //----- //
util.func public @matmul_accumulate_4x4xi8_times_4x4xi8_into_4x4xi32(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub, iree.reflection = {iree.abi.declaration = "sync func @matmul_accumulate_4x4xi8_times_4x4xi8_into_4x4xi32(%input0: tensor<4x4xi8>, %input1: tensor<4x4xi8>, %input2: tensor<4x4xi32>) -> (%output0: tensor<4x4xi32>)"}} {
  %c0_i32 = arith.constant 0 : i32
  %c0_i8 = arith.constant 0 : i8
  %0 = hal.tensor.import %arg0 "input0" : !hal.buffer_view -> tensor<4x4xi8>
  %1 = hal.tensor.import %arg1 "input1" : !hal.buffer_view -> tensor<4x4xi8>
  %2 = hal.tensor.import %arg2 "input2" : !hal.buffer_view -> tensor<4x4xi32>
  %3 = tensor.empty() : tensor<16x16xi8>
  %4 = linalg.fill ins(%c0_i8 : i8) outs(%3 : tensor<16x16xi8>) -> tensor<16x16xi8>
  %inserted_slice = tensor.insert_slice %0 into %4[0, 0] [4, 4] [1, 1] : tensor<4x4xi8> into tensor<16x16xi8>
  %cast = tensor.cast %inserted_slice : tensor<16x16xi8> to tensor<?x?xi8>
  %c0 = arith.constant 0 : index
  %c16 = arith.constant 16 : index
  %c1 = arith.constant 1 : index
  %c16_0 = arith.constant 16 : index
  %5 = flow.dispatch.region -> (tensor<?x?xi8, #iree_linalg_ext.encoding<role =  LHS, element_types = [i8, i8, i32], original_type = tensor<4x4xi8>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>{%c16, %c16_0}) {
    %13 = iree_linalg_ext.set_encoding %cast : tensor<?x?xi8> -> tensor<?x?xi8, #iree_linalg_ext.encoding<role =  LHS, element_types = [i8, i8, i32], original_type = tensor<4x4xi8>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>
    flow.return %13 : tensor<?x?xi8, #iree_linalg_ext.encoding<role =  LHS, element_types = [i8, i8, i32], original_type = tensor<4x4xi8>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>
  }
  %inserted_slice_1 = tensor.insert_slice %1 into %4[0, 0] [4, 4] [1, 1] : tensor<4x4xi8> into tensor<16x16xi8>
  %cast_2 = tensor.cast %inserted_slice_1 : tensor<16x16xi8> to tensor<?x?xi8>
  %c0_3 = arith.constant 0 : index
  %c16_4 = arith.constant 16 : index
  %c1_5 = arith.constant 1 : index
  %c16_6 = arith.constant 16 : index
  %6 = flow.dispatch.region -> (tensor<?x?xi8, #iree_linalg_ext.encoding<role =  RHS, element_types = [i8, i8, i32], original_type = tensor<4x4xi8>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>{%c16_4, %c16_6}) {
    %13 = iree_linalg_ext.set_encoding %cast_2 : tensor<?x?xi8> -> tensor<?x?xi8, #iree_linalg_ext.encoding<role =  RHS, element_types = [i8, i8, i32], original_type = tensor<4x4xi8>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>
    flow.return %13 : tensor<?x?xi8, #iree_linalg_ext.encoding<role =  RHS, element_types = [i8, i8, i32], original_type = tensor<4x4xi8>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>
  }
  %7 = tensor.empty() : tensor<16x16xi32>
  %8 = linalg.fill ins(%c0_i32 : i32) outs(%7 : tensor<16x16xi32>) -> tensor<16x16xi32>
  %inserted_slice_7 = tensor.insert_slice %2 into %8[0, 0] [4, 4] [1, 1] : tensor<4x4xi32> into tensor<16x16xi32>
  %cast_8 = tensor.cast %inserted_slice_7 : tensor<16x16xi32> to tensor<?x?xi32>
  %c0_9 = arith.constant 0 : index
  %c16_10 = arith.constant 16 : index
  %c1_11 = arith.constant 1 : index
  %c16_12 = arith.constant 16 : index
  %c0_13 = arith.constant 0 : index
  %c16_14 = arith.constant 16 : index
  %c1_15 = arith.constant 1 : index
  %c16_16 = arith.constant 16 : index
  %9 = flow.dispatch.region -> (tensor<?x?xi32, #iree_linalg_ext.encoding<role =  RESULT, element_types = [i8, i8, i32], original_type = tensor<4x4xi32>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>{%c16_14, %c16_16}) {
    %13 = iree_linalg_ext.set_encoding %cast_8 : tensor<?x?xi32> -> tensor<?x?xi32, #iree_linalg_ext.encoding<role =  RESULT, element_types = [i8, i8, i32], original_type = tensor<4x4xi32>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>
    flow.return %13 : tensor<?x?xi32, #iree_linalg_ext.encoding<role =  RESULT, element_types = [i8, i8, i32], original_type = tensor<4x4xi32>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>
  }
  %10 = flow.dispatch.region -> (tensor<?x?xi32, #iree_linalg_ext.encoding<role =  RESULT, element_types = [i8, i8, i32], original_type = tensor<4x4xi32>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>{%c16_10, %c16_12}) {
    %13 = linalg.matmul ins(%5, %6 : tensor<?x?xi8, #iree_linalg_ext.encoding<role =  LHS, element_types = [i8, i8, i32], original_type = tensor<4x4xi8>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>, tensor<?x?xi8, #iree_linalg_ext.encoding<role =  RHS, element_types = [i8, i8, i32], original_type = tensor<4x4xi8>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>) outs(%9 : tensor<?x?xi32, #iree_linalg_ext.encoding<role =  RESULT, element_types = [i8, i8, i32], original_type = tensor<4x4xi32>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>) -> tensor<?x?xi32, #iree_linalg_ext.encoding<role =  RESULT, element_types = [i8, i8, i32], original_type = tensor<4x4xi32>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>
    flow.return %13 : tensor<?x?xi32, #iree_linalg_ext.encoding<role =  RESULT, element_types = [i8, i8, i32], original_type = tensor<4x4xi32>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>
  }
  %11 = flow.dispatch.region -> (tensor<4x4xi32>) {
    %13 = iree_linalg_ext.unset_encoding %10 : tensor<?x?xi32, #iree_linalg_ext.encoding<role =  RESULT, element_types = [i8, i8, i32], original_type = tensor<4x4xi32>, matmul_narrow_M = 4 : index, matmul_narrow_N = 4 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>> -> tensor<?x?xi32>
    %extracted_slice = tensor.extract_slice %13[0, 0] [4, 4] [1, 1] : tensor<?x?xi32> to tensor<4x4xi32>
    flow.return %extracted_slice : tensor<4x4xi32>
  }
  %12 = hal.tensor.export %11 "output0" : tensor<4x4xi32> -> !hal.buffer_view
  util.return %12 : !hal.buffer_view
}

It's still creating the tensor.dim ops, folding them, and then erasing them just as many times, but the IR is better on the way out.

MaheshRavishankar commented 7 months ago

Ok, I thought this could be made better, but it really needs some reworking: the reify methods need a special listener they can query to see whether the op already exists, so they can reuse it. It will take a bit of work upstream to plumb this all through. We should really move the Listener that is in FormDispatchRegions upstream and use it when available.
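The reuse idea can be sketched as follows. This is a minimal sketch under stated assumptions, not IREE or MLIR code. IR values and ops are modeled as plain integers, and the listener is assumed to be notified of every inserted tensor.dim op and queryable by (source tensor, dimension index). DimOpCache, FakeBuilder, and getOrCreateDim are hypothetical names chosen for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <utility>

// Hypothetical model: IR values and ops are identified by integer ids.
// This is NOT the MLIR API; it only illustrates the reuse idea.
using ValueId = int;
using OpId = int;

// A listener-like cache that remembers every dim op it has seen created,
// keyed by (source tensor, dimension index).
class DimOpCache {
 public:
  // Assumed to be called whenever a dim op is inserted (analogous to a
  // rewriter listener's "operation inserted" notification).
  void notifyDimCreated(ValueId source, std::size_t dim, OpId op) {
    cache_.emplace(std::make_pair(source, dim), op);
  }

  // Returns the existing dim op for (source, dim), or -1 if none is known.
  OpId lookup(ValueId source, std::size_t dim) const {
    auto it = cache_.find(std::make_pair(source, dim));
    return it == cache_.end() ? -1 : it->second;
  }

 private:
  std::map<std::pair<ValueId, std::size_t>, OpId> cache_;
};

// Stand-in for an op builder: hands out fresh op ids and counts them.
struct FakeBuilder {
  int created = 0;
  OpId createDim(ValueId /*source*/, std::size_t /*dim*/) { return created++; }
};

// Reify one dimension, reusing an existing dim op when the cache has one
// and only creating (and recording) a new op otherwise.
OpId getOrCreateDim(FakeBuilder &b, DimOpCache &cache, ValueId source,
                    std::size_t dim) {
  OpId existing = cache.lookup(source, dim);
  if (existing != -1) return existing;
  OpId op = b.createDim(source, dim);
  cache.notifyDimCreated(source, dim, op);
  return op;
}
```

With such a cache, hundreds of reify requests against the same two dimensions of one tensor create only two dim ops instead of one per request.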

MaheshRavishankar commented 7 months ago

The only real way to fix this is to modify the interface method in ReifyRankedShapedTypeOpInterface. The reifyResultDim method should take an extra (optional) Listener object as an argument. When one is provided, the implementation should use the Listener to query for existing tensor.dim operations instead of creating new ones.

This doesn't need to be a world rewrite. It can be added incrementally.
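A rough sketch of how an optional listener parameter could allow incremental adoption is below. It uses the same simplification (integer ids instead of MLIR values) and does not reproduce the actual signatures of the MLIR interface discussed above. DimQueryListener, ReifyDimInterface, and PassThroughOp are hypothetical names, with PassThroughOp standing in for an op whose result dims equal its operand's dims (as with set_encoding).

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <utility>

// Hypothetical, simplified model; not the real MLIR interface.
using ValueId = int;
using OpId = int;

// The optional query object handed to the reify method.
struct DimQueryListener {
  std::map<std::pair<ValueId, std::size_t>, OpId> known;

  // Returns an existing dim op for (source, dim), or -1 if none is known.
  OpId findExistingDim(ValueId source, std::size_t dim) const {
    auto it = known.find(std::make_pair(source, dim));
    return it == known.end() ? -1 : it->second;
  }
  void recordDim(ValueId source, std::size_t dim, OpId op) {
    known.emplace(std::make_pair(source, dim), op);
  }
};

// Interface-like base class. The listener defaults to nullptr, so callers
// that don't pass one keep today's behavior.
struct ReifyDimInterface {
  virtual ~ReifyDimInterface() = default;
  virtual OpId reifyResultDim(std::size_t dim,
                              DimQueryListener *listener = nullptr) = 0;
};

// Example op whose result dims are its operand's dims.
struct PassThroughOp : ReifyDimInterface {
  PassThroughOp(ValueId operand, int *createdCount)
      : operand(operand), createdCount(createdCount) {}

  // No default argument repeated here: default arguments on virtual
  // functions bind statically, so callers go through the base interface.
  OpId reifyResultDim(std::size_t dim, DimQueryListener *listener) override {
    if (listener) {
      OpId existing = listener->findExistingDim(operand, dim);
      if (existing != -1) return existing;  // Reuse instead of creating.
    }
    OpId op = (*createdCount)++;  // "Create" a new dim op.
    if (listener) listener->recordDim(operand, dim, op);
    return op;
  }

  ValueId operand;
  int *createdCount;
};
```

Defaulting the parameter to nullptr means implementations that ignore it, and callers that never pass it, behave exactly as before, so ops can be migrated one at a time.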