Open benvanik opened 1 year ago
Diffed and found the regression. With --iree-opt-const-eval=false
the model has these globals:
#map = affine_map<(d0) -> (d0)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map3 = affine_map<(d0, d1) -> (d0, d1)>
#map4 = affine_map<(d0, d1) -> (d0)>
#map5 = affine_map<(d0, d1) -> (d1)>
#map6 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
#map7 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map8 = affine_map<(d0, d1) -> (d1, d0)>
#map9 = affine_map<(d0, d1, d2) -> (d2)>
util.global private @hoisted_1 : tensor<768x30522xf32>
util.global private @hoisted_0 : tensor<1x512x768xf32>
util.global private @hoisted : tensor<1x512x768xf32>
flow.executable private @_initializer_0_dispatch_0 {
flow.executable.export public @_initializer_0_dispatch_0_generic_1x512x768_i32xf32 workgroups() -> (index, index, index) {
%x, %y, %z = flow.dispatch.workgroup_count_from_slice
flow.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @_initializer_0_dispatch_0_generic_1x512x768_i32xf32(%arg0: !flow.dispatch.tensor<readonly:tensor<512x768xf32>>, %arg1: !flow.dispatch.tensor<readonly:tensor<1x512xi32>>, %arg2: !flow.dispatch.tensor<writeonly:tensor<1x512x768xf32>>) {
%0 = flow.dispatch.tensor.load %arg0, offsets = [0, 0], sizes = [512, 768], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<512x768xf32>> -> tensor<512x768xf32>
%1 = flow.dispatch.tensor.load %arg1, offsets = [0, 0], sizes = [1, 512], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<1x512xi32>> -> tensor<1x512xi32>
%2 = tensor.empty() : tensor<1x512x768xf32>
%3 = linalg.generic {indexing_maps = [#map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1 : tensor<1x512xi32>) outs(%2 : tensor<1x512x768xf32>) {
^bb0(%in: i32, %out: f32):
%4 = arith.index_cast %in : i32 to index
%5 = linalg.index 2 : index
%extracted = tensor.extract %0[%4, %5] : tensor<512x768xf32>
linalg.yield %extracted : f32
} -> tensor<1x512x768xf32>
flow.dispatch.tensor.store %3, %arg2, offsets = [0, 0, 0], sizes = [1, 512, 768], strides = [1, 1, 1] : tensor<1x512x768xf32> -> !flow.dispatch.tensor<writeonly:tensor<1x512x768xf32>>
return
}
}
}
util.initializer attributes {iree.compiler.consteval} {
%cst = arith.constant dense_resource<__elided__> : tensor<1x512xi32>
%cst_0 = arith.constant dense_resource<__elided__> : tensor<512x768xf32>
%0 = flow.dispatch @_initializer_0_dispatch_0::@_initializer_0_dispatch_0_generic_1x512x768_i32xf32(%cst_0, %cst) : (tensor<512x768xf32>, tensor<1x512xi32>) -> tensor<1x512x768xf32>
util.global.store %0, @hoisted : tensor<1x512x768xf32>
util.initializer.return
}
flow.executable private @_initializer_1_dispatch_0 {
flow.executable.export public @_initializer_1_dispatch_0_generic_1x512x768_f32 workgroups() -> (index, index, index) {
%x, %y, %z = flow.dispatch.workgroup_count_from_slice
flow.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @_initializer_1_dispatch_0_generic_1x512x768_f32(%arg0: !flow.dispatch.tensor<readonly:tensor<2x768xf32>>, %arg1: !flow.dispatch.tensor<writeonly:tensor<1x512x768xf32>>) {
%c0 = arith.constant 0 : index
%0 = flow.dispatch.tensor.load %arg0, offsets = [0, 0], sizes = [2, 768], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2x768xf32>> -> tensor<2x768xf32>
%1 = tensor.empty() : tensor<1x512x768xf32>
%2 = linalg.generic {indexing_maps = [#map2], iterator_types = ["parallel", "parallel", "parallel"]} outs(%1 : tensor<1x512x768xf32>) {
^bb0(%out: f32):
%3 = linalg.index 2 : index
%extracted = tensor.extract %0[%c0, %3] : tensor<2x768xf32>
linalg.yield %extracted : f32
} -> tensor<1x512x768xf32>
flow.dispatch.tensor.store %2, %arg1, offsets = [0, 0, 0], sizes = [1, 512, 768], strides = [1, 1, 1] : tensor<1x512x768xf32> -> !flow.dispatch.tensor<writeonly:tensor<1x512x768xf32>>
return
}
}
}
util.initializer attributes {iree.compiler.consteval} {
%cst = arith.constant dense_resource<__elided__> : tensor<2x768xf32>
%0 = flow.dispatch @_initializer_1_dispatch_0::@_initializer_1_dispatch_0_generic_1x512x768_f32(%cst) : (tensor<2x768xf32>) -> tensor<1x512x768xf32>
util.global.store %0, @hoisted_0 : tensor<1x512x768xf32>
util.initializer.return
}
flow.executable private @_initializer_2_dispatch_0 {
flow.executable.export public @_initializer_2_dispatch_0_generic_768x30522_f32 workgroups() -> (index, index, index) {
%x, %y, %z = flow.dispatch.workgroup_count_from_slice
flow.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @_initializer_2_dispatch_0_generic_768x30522_f32(%arg0: !flow.dispatch.tensor<readonly:tensor<30522x768xf32>>, %arg1: !flow.dispatch.tensor<writeonly:tensor<768x30522xf32>>) {
%0 = flow.dispatch.tensor.load %arg0, offsets = [0, 0], sizes = [30522, 768], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<30522x768xf32>> -> tensor<30522x768xf32>
%1 = tensor.empty() : tensor<768x30522xf32>
%2 = linalg.generic {indexing_maps = [#map8, #map3], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<30522x768xf32>) outs(%1 : tensor<768x30522xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<768x30522xf32>
flow.dispatch.tensor.store %2, %arg1, offsets = [0, 0], sizes = [768, 30522], strides = [1, 1] : tensor<768x30522xf32> -> !flow.dispatch.tensor<writeonly:tensor<768x30522xf32>>
return
}
}
}
util.initializer attributes {iree.compiler.consteval} {
%cst = arith.constant dense_resource<__elided__> : tensor<30522x768xf32>
%0 = flow.dispatch @_initializer_2_dispatch_0::@_initializer_2_dispatch_0_generic_768x30522_f32(%cst) : (tensor<30522x768xf32>) -> tensor<768x30522xf32>
util.global.store %0, @hoisted_1 : tensor<768x30522xf32>
util.initializer.return
}
and when enabled these get folded:
util.global private @hoisted_1 = dense_resource<__elided__> : tensor<768x30522xf32>
util.global private @hoisted_0 = dense_resource<__elided__> : tensor<1x512x768xf32>
util.global private @hoisted = dense_resource<__elided__> : tensor<1x512x768xf32>
@hoisted and @hoisted_0 are sketchy - those feel like things that should be fused with consumers. @hoisted_0, for example, takes a 2x768 input and produces a 1x512x768 output - we definitely don't want to be baking that out into files when it should instead likely just be a broadcasting load fused with consumers for cheap.
The bulk of the size gain comes from @hoisted_1, though, which is just a transpose. It looks like there's a gather and a matmul that use the original constant, and the transpose comes from the matmul:
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map4 = affine_map<(d0, d1) -> (d0, d1)>
#map5 = affine_map<(d0, d1, d2) -> (d0, d1, 0)>
#map6 = affine_map<(d0, d1, d2) -> (d2)>
#map7 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>
#map8 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>
#map9 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
#map10 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>
#map11 = affine_map<(d0, d1) -> (d1, d0)>
%cst_48 = arith.constant dense_resource<__elided__> : tensor<30522x768xf32>
%0 = hal.tensor.import %arg0 "input 0" : !hal.buffer_view -> tensor<1x512xi32>
%11 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<1x512xi32>) outs(%8 : tensor<1x512x768xf32>) {
^bb0(%in: i32, %out: f32):
%1307 = arith.index_cast %in : i32 to index
%1308 = linalg.index 2 : index
%extracted = tensor.extract %cst_48[%1307, %1308] : tensor<30522x768xf32>
linalg.yield %extracted : f32
} -> tensor<1x512x768xf32>
%collapsed_553 = tensor.collapse_shape %1297 [[0, 1], [2]] : tensor<1x512x768xf32> into tensor<512x768xf32>
%1298 = tensor.empty() : tensor<768x30522xf32>
%1299 = linalg.generic {indexing_maps = [#map11, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_48 : tensor<30522x768xf32>) outs(%1298 : tensor<768x30522xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<768x30522xf32>
%1300 = tensor.empty() : tensor<512x30522xf32>
%1301 = linalg.fill ins(%cst_0 : f32) outs(%1300 : tensor<512x30522xf32>) -> tensor<512x30522xf32>
%1302 = linalg.matmul ins(%collapsed_553, %1299 : tensor<512x768xf32>, tensor<768x30522xf32>) outs(%1301 : tensor<512x30522xf32>) -> tensor<512x30522xf32>
This looks like something that some propagation of transposes and such should help with: if a tensor has multiple users and one doesn't care/can have its layout updated and another does, we could hoist the layout change up - above the gather should be fine if it's tensor.extract %cst_48[%1308, %1307], as it's just a map change.
I can look into this. I have learned about how const_eval works recently, and it looks interesting to me. I feel that we'd like to
#14550 caused a 22% size regression, meaning that there's likely some reduction that FusionOfTensors was handling that now isn't. Would be good to understand what's missed so we can track improving it.
Compilation flags: