benvanik opened this issue 1 year ago
I think that is addressed by setting the flag --iree-flow-enable-fuse-padding-into-linalg-consumer-ops, which should fuse the pad with its consumer. It is off by default because it only works on the CPU and SPIR-V backends, and I would also rather fuse the pad with its producers. #10184 was something I tried as an example for someone to pick up, but no one has picked it up so far, so it remains on my backlog.
I can close the old bug in favor of this one if that helps... but we have discussed this many times ;) I won't be able to push on this, and I'm really hoping someone can pick it up.
Yep, just trying to increase visibility and have something to point at for people asking what's up with the bubbles in the pipeline :) It may be useful to add some notes to #10184 about what still needs to be done. The original #2783 didn't have an example - we can dedupe this against that so that whoever ends up fixing this gets the honor of closing a 2+ year old issue :)
I need to resurrect that PR... Will do that and add some notes.
ESRGAN suffers from this as well and would benefit from padding propagation into consumers.
Example where, if the padding in %padded_923 were propagated upward, we could perform the elementwise ops in place and then consume them directly in the conv:
// and other elementwise ops producing %189/%192/etc
%198 = linalg.generic {indexing_maps = [#map2, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%197 : tensor<1x32x90x62xf32>) outs(%3 : tensor<1x32x90x62xf32>) {
^bb0(%in: f32, %out: f32):
%1177 = arith.cmpf ugt, %in, %cst_701 : f32
%1178 = arith.select %1177, %in, %cst_701 : f32
%1179 = arith.select %1177, %cst_701, %in : f32
%1180 = arith.truncf %cst_702 : f64 to f32
%1181 = arith.mulf %1179, %1180 : f32
%1182 = arith.addf %1178, %1181 : f32
linalg.yield %1182 : f32
} -> tensor<1x32x90x62xf32>
%inserted_slice_919 = tensor.insert_slice %189 into %15[0, 0, 0, 0] [1, 64, 90, 62] [1, 1, 1, 1] : tensor<1x64x90x62xf32> into tensor<1x160x90x62xf32>
%inserted_slice_920 = tensor.insert_slice %192 into %inserted_slice_919[0, 64, 0, 0] [1, 32, 90, 62] [1, 1, 1, 1] : tensor<1x32x90x62xf32> into tensor<1x160x90x62xf32>
%inserted_slice_921 = tensor.insert_slice %195 into %inserted_slice_920[0, 96, 0, 0] [1, 32, 90, 62] [1, 1, 1, 1] : tensor<1x32x90x62xf32> into tensor<1x160x90x62xf32>
%inserted_slice_922 = tensor.insert_slice %198 into %inserted_slice_921[0, 128, 0, 0] [1, 32, 90, 62] [1, 1, 1, 1] : tensor<1x32x90x62xf32> into tensor<1x160x90x62xf32>
%padded_923 = tensor.pad %inserted_slice_922 low[0, 0, 1, 1] high[0, 0, 1, 1] {
^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
tensor.yield %cst_701 : f32
} : tensor<1x160x90x62xf32> to tensor<1x160x92x64xf32>
%199 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cst_581 : tensor<32xf32>) outs(%3 : tensor<1x32x90x62xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x90x62xf32>
%200 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%padded_923, %cst_582 : tensor<1x160x92x64xf32>, tensor<32x160x3x3xf32>) outs(%199 : tensor<1x32x90x62xf32>) -> tensor<1x32x90x62xf32>
%201 = linalg.generic {indexing_maps = [#map2, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%200 : tensor<1x32x90x62xf32>) outs(%3 : tensor<1x32x90x62xf32>) {
^bb0(%in: f32, %out: f32):
%1177 = arith.cmpf ugt, %in, %cst_701 : f32
%1178 = arith.select %1177, %in, %cst_701 : f32
%1179 = arith.select %1177, %cst_701, %in : f32
%1180 = arith.truncf %cst_702 : f64 to f32
%1181 = arith.mulf %1179, %1180 : f32
%1182 = arith.addf %1178, %1181 : f32
linalg.yield %1182 : f32
} -> tensor<1x32x90x62xf32>
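The rewrite being asked for can be illustrated with a small numpy sketch (shapes shrunk for readability; the variable names are stand-ins for the SSA values above, not IREE APIs): writing the producers' results straight into a pre-allocated padded destination is equivalent to concatenating and then padding, but skips the extra full-tensor copy.

```python
import numpy as np

# Sketch on toy shapes: today the elementwise results are concatenated along
# the channel dim (the tensor.insert_slice chain) and then tensor.pad copies
# the whole thing into a fresh zero-filled buffer. If the pad is propagated
# to the producers, we instead allocate the padded buffer once and write each
# result at the pad-shifted offset.
rng = np.random.default_rng(0)
a = rng.standard_normal((1, 2, 4, 4)).astype(np.float32)  # stands in for %189
b = rng.standard_normal((1, 3, 4, 4)).astype(np.float32)  # stands in for %192

# Today: concat, then pad low/high [0, 0, 1, 1] (an extra full-tensor copy).
padded = np.pad(np.concatenate([a, b], axis=1),
                ((0, 0), (0, 0), (1, 1), (1, 1)))

# Proposed: write producers directly into the padded destination.
direct = np.zeros((1, 5, 6, 6), dtype=np.float32)
direct[:, 0:2, 1:5, 1:5] = a
direct[:, 2:5, 1:5, 1:5] = b

assert np.array_equal(padded, direct)
```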
Today we end up with this, which requires the %39 splat and the %41 copy dispatch:
%32 = flow.dispatch.workgroups[%c1, %c32, %c90, %c62](%31, %cst_4) : (tensor<1x160x92x64xf32>, tensor<32x160x3x3xf32>) -> tensor<1x32x90x62xf32> =
(%arg3: !flow.dispatch.tensor<readonly:tensor<1x160x92x64xf32>>, %arg4: !flow.dispatch.tensor<readonly:tensor<32x160x3x3xf32>>, %arg5: !flow.dispatch.tensor<writeonly:tensor<1x32x90x62xf32>>) {
%cst_351 = arith.constant 0.000000e+00 : f32
%cst_352 = arith.constant 0.199999988 : f32
%cst_353 = arith.constant dense<[[-0.00735217752, -0.029075671, -0.0011687536, -0.0265800748, -0.016661156, -0.0216491632, -0.0427877456, -0.0533559099, -0.0249305591, -0.0207087267, -0.0253318828, -0.0515014119, -0.0422265045, -0.0368615724, 0.00198965892, -0.0221594162, -0.0266306344, -0.0617676973, -0.0261138938, -0.00482901605, -0.0400608778, -0.0137573751, -0.00975679792, -0.0443469957, -0.0315653086, -0.0245542042, -0.0320154652, -6.253720e-02, -0.0274252892, 0.00514560752, -0.0166819859, -0.0136556849]]> : tensor<1x32xf32>
%2035 = flow.dispatch.tensor.load %arg3, offsets = [0, 0, 0, 0], sizes = [1, 160, 92, 64], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x160x92x64xf32>> -> tensor<1x160x92x64xf32>
%2036 = flow.dispatch.tensor.load %arg4, offsets = [0, 0, 0, 0], sizes = [32, 160, 3, 3], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<32x160x3x3xf32>> -> tensor<32x160x3x3xf32>
%2037 = tensor.empty() : tensor<1x32x90x62xf32>
%2038 = linalg.fill ins(%cst_351 : f32) outs(%2037 : tensor<1x32x90x62xf32>) -> tensor<1x32x90x62xf32>
%2039 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%2035, %2036 : tensor<1x160x92x64xf32>, tensor<32x160x3x3xf32>) outs(%2038 : tensor<1x32x90x62xf32>) -> tensor<1x32x90x62xf32>
%2040 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%2039, %cst_353 : tensor<1x32x90x62xf32>, tensor<1x32xf32>) outs(%2037 : tensor<1x32x90x62xf32>) {
^bb0(%in: f32, %in_354: f32, %out: f32):
%2041 = arith.addf %in, %in_354 : f32
%2042 = arith.cmpf ugt, %2041, %cst_351 : f32
%2043 = arith.select %2042, %2041, %cst_351 : f32
%2044 = arith.select %2042, %cst_351, %2041 : f32
%2045 = arith.mulf %2044, %cst_352 : f32
%2046 = arith.addf %2043, %2045 : f32
linalg.yield %2046 : f32
} -> tensor<1x32x90x62xf32>
flow.dispatch.tensor.store %2040, %arg5, offsets = [0, 0, 0, 0], sizes = [1, 32, 90, 62], strides = [1, 1, 1, 1] : tensor<1x32x90x62xf32> -> !flow.dispatch.tensor<writeonly:tensor<1x32x90x62xf32>>
flow.return
} count(%arg3: index, %arg4: index, %arg5: index, %arg6: index) -> (index, index, index) {
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg3, %arg4, %arg5, %arg6
flow.return %x, %y, %z : index, index, index
}
%33 = tensor.empty() : tensor<1x192x90x62xf32>
%34 = flow.tensor.update %4, %33[%c0, %c0, %c0, %c0] : tensor<1x64x90x62xf32> -> %33 as tensor<1x192x90x62xf32>
%35 = flow.tensor.update %8, %34[%c0, %c64, %c0, %c0] : tensor<1x32x90x62xf32> -> %34 as tensor<1x192x90x62xf32>
%36 = flow.tensor.update %15, %35[%c0, %c96, %c0, %c0] : tensor<1x32x90x62xf32> -> %35 as tensor<1x192x90x62xf32>
%37 = flow.tensor.update %23, %36[%c0, %c128, %c0, %c0] : tensor<1x32x90x62xf32> -> %36 as tensor<1x192x90x62xf32>
%38 = flow.tensor.update %32, %37[%c0, %c160, %c0, %c0] : tensor<1x32x90x62xf32> -> %37 as tensor<1x192x90x62xf32>
%39 = flow.tensor.splat %cst : tensor<1x192x92x64xf32>
%40 = flow.tensor.reshape %38 : tensor<1x192x90x62xf32> -> tensor<192x90x62xf32>
%41 = flow.dispatch.workgroups[%c192, %c90, %c62](%40, %39) : (tensor<192x90x62xf32>, tensor<1x192x92x64xf32>) -> %39 =
(%arg3: !flow.dispatch.tensor<readonly:tensor<192x90x62xf32>>, %arg4: !flow.dispatch.tensor<readwrite:tensor<1x192x92x64xf32>>) {
%2035 = flow.dispatch.tensor.load %arg3, offsets = [0, 0, 0], sizes = [192, 90, 62], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<192x90x62xf32>> -> tensor<192x90x62xf32>
flow.dispatch.tensor.store %2035, %arg4, offsets = [0, 0, 1, 1], sizes = [1, 192, 90, 62], strides = [1, 1, 1, 1] : tensor<192x90x62xf32> -> !flow.dispatch.tensor<readwrite:tensor<1x192x92x64xf32>>
flow.return
} count(%arg3: index, %arg4: index, %arg5: index) -> (index, index, index) {
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg3, %arg4, %arg5
flow.return %x, %y, %z : index, index, index
}
Looking at all the splats - from both this and #6972 - we're doing 2004094976 bytes (~2GB!!!!) of memset(0)s in ESRGAN.
As an example, in just the last few stages of ESRGAN we're memset(0)'ing 100MB:
I tried the --iree-flow-enable-fuse-padding-into-linalg-consumer-ops flag with ESRGAN and it got rid of the fills (so 1.1GB/invocation of memset(0)) and dropped transient memory usage from 74905600 bytes (~75MB) to 61424640 bytes (~61MB) - nice! (#7729 will drop it another 5-10MB)
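For a sense of scale, the per-pad cost is easy to compute by hand (illustrative arithmetic only, assuming f32 at 4 bytes per element; the helper below is mine, not an IREE API):

```python
# Back-of-envelope sizes for the tensors in the dump above.
def f32_bytes(*dims):
    """Bytes occupied by an f32 tensor of the given shape."""
    n = 4
    for d in dims:
        n *= d
    return n

# The %39 flow.tensor.splat that gets memset(0) before the copy dispatch:
splat = f32_bytes(1, 192, 92, 64)  # 4521984 bytes, ~4.3 MiB per pad

# Transient memory drop reported with the flag enabled:
saved = 74905600 - 61424640  # 13480960 bytes, ~12.9 MiB
```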
I haven't actually benchmarked the model but it's nice to know what overheads padding adds here.
Was looking at tests/e2e/models/resnet50_fake_weights.mlir and noticed that there are still a lot of fills/slow dispatch-based memcpys (~18 fills/dispatches, each with a unique executable because of the unique sizes). This adds quite a bit of latency to the system, as the fill -> dispatch that does just a memcpy -> actual consumer sequence is serialized. Thankfully we can run the fill concurrently with the producer, but that is a large additional transient value we need to allocate/keep live, and it's still an extra 33% baseline latency ([producer|fill] -> pad dispatch -> consumer vs. producer -> consumer). 23% of the dispatches we compile/store in the binary/execute at runtime are these pads, and a ~25% savings on that would be awesome. Now that we have some latency-sensitive models with convs (where I think we end up with the most pads), getting rid of this noise will help keep focus on the actual codegen improvements and not the dispatch scheduling.

I think #9194 was supposed to prevent this, but there's also a draft #10184 that may have intended to do it. Fixing this would let us finally close the old #2783. Feel free to close as a dupe or consider this a ping with an easily available reproducer :)
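One way to read the "extra 33% baseline latency" figure is as the pad step's share of the serialized critical path (a rough model, assuming each dispatch contributes comparable latency):

```python
# Rough critical-path model. The fill can run concurrently with the producer,
# so the padded pipeline serializes three stages:
#   max(producer, fill) -> pad-copy dispatch -> consumer
# while the fused pipeline serializes two:
#   producer -> consumer
stages_with_pad = 3
stages_fused = 2

# Share of the padded pipeline's latency spent on the pad step:
pad_share = (stages_with_pad - stages_fused) / stages_with_pad
print(f"{pad_share:.0%}")  # 33%
```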
What this looks like during execution is (with dispatch_6 as the serialized pad):
Ideally we'd just see dispatch_9 -> dispatch_7 (matmul -> conv) with no intervening ops.