iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

Fills/dispatches when padding not getting folded into consumers/producers. #11049

benvanik opened this issue 1 year ago (status: Open)

benvanik commented 1 year ago

I was looking at tests/e2e/models/resnet50_fake_weights.mlir and noticed that there are still a lot of fills and slow dispatch-based memcpys (~18 fills/dispatches, with a unique executable for each because of the unique sizes). This adds quite a bit of latency to the system because the fill, the dispatch that just does a memcpy, and the actual consumer are serialized. Thankfully we can run the fill concurrently with the producer, but that is still a large additional transient value we need to allocate and keep live, and it still adds an extra 33% baseline latency ([producer|fill] -> pad dispatch -> consumer vs. producer -> consumer). 23% of the dispatches we compile, store in the binary, and execute at runtime are these pads, and a ~25% savings on that would be awesome. Now that we have some latency-sensitive models with convs (which I think is where we end up with the most pads), getting rid of this noise will help keep the focus on the actual codegen improvements rather than on dispatch scheduling.

I think #9194 was supposed to prevent this, but there's also a draft #10184 that may have intended to do it. Fixing this would let us finally close the old #2783. Feel free to close as a dupe or consider this a ping with an easily available reproducer :)

Here is what this looks like during execution (dispatch_6 is the serialized pad):

stream.executable private @predict_dispatch_6 {
  stream.executable.export public @predict_dispatch_6 workgroups(%arg0: index, %arg1: index, %arg2: index) -> (index, index, index) {
    %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0, %arg1, %arg2
    stream.return %x, %y, %z : index, index, index
  }
  builtin.module {
    func.func @predict_dispatch_6(%arg0: !stream.binding, %arg1: !stream.binding) {
      %c0 = arith.constant 0 : index
      %0 = stream.binding.subspan %arg0[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<56x56x64xf32>>
      %1 = stream.binding.subspan %arg1[%c0] : !stream.binding -> !flow.dispatch.tensor<readwrite:tensor<1x58x58x64xf32>>
      %2 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [56, 56, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<56x56x64xf32>> -> tensor<56x56x64xf32>
      flow.dispatch.tensor.store %2, %1, offsets = [0, 1, 1, 0], sizes = [1, 56, 56, 64], strides = [1, 1, 1, 1] : tensor<56x56x64xf32> -> !flow.dispatch.tensor<readwrite:tensor<1x58x58x64xf32>>
      return
    }
  }
}

%17:2 = stream.async.concurrent with(%16 as %arg322: !stream.resource<transient>{%c3211264}, %arg50 as %arg323: !stream.resource<constant>{%c65536}, %arg51 as %arg324: !stream.resource<constant>{%c256}, %arg52 as %arg325: !stream.resource<constant>{%c256}, %arg53 as %arg326: !stream.resource<constant>{%c256}, %arg54 as %arg327: !stream.resource<constant>{%c256}, %arg55 as %arg328: !stream.resource<constant>{%c256}) -> (!stream.resource<transient>{%c802816}, !stream.resource<transient>{%c861184}) {
  %79 = stream.async.dispatch @predict_dispatch_9::@predict_dispatch_9_matmul_3136x64x256[%c3136, %c64](%arg322, %arg323, %arg324, %arg325, %arg326, %arg327, %arg328) : (!stream.resource<transient>{%c3211264}, !stream.resource<constant>{%c65536}, !stream.resource<constant>{%c256}, !stream.resource<constant>{%c256}, !stream.resource<constant>{%c256}, !stream.resource<constant>{%c256}, !stream.resource<constant>{%c256}) -> !stream.resource<transient>{%c802816}
  %80 = stream.async.splat %c0_i8 : i8 -> !stream.resource<transient>{%c861184}
  stream.yield %79, %80 : !stream.resource<transient>{%c802816}, !stream.resource<transient>{%c861184}
}
%18 = stream.async.dispatch @predict_dispatch_6::@predict_dispatch_6[%c56, %c56, %c64](%17#0, %17#1) : (!stream.resource<transient>{%c802816}, !stream.resource<transient>{%c861184}) -> %17#1{%c861184}
%19 = stream.async.dispatch @predict_dispatch_7::@predict_dispatch_7_conv_2d_nhwc_hwcf_1x56x56x64x3x3x64[%c1, %c56, %c56, %c64](%18, %arg56, %arg57, %arg58, %arg59, %arg60, %arg61) : (!stream.resource<transient>{%c861184}, !stream.resource<constant>{%c147456}, !stream.resource<constant>{%c256}, !stream.resource<constant>{%c256}, !stream.resource<constant>{%c256}, !stream.resource<constant>{%c256}, !stream.resource<constant>{%c256}) -> !stream.resource<transient>{%c802816}

Ideally we'd just see dispatch_9 -> dispatch_7 (matmul -> conv) with no intervening ops.
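The resource sizes in the IR above follow directly from the tensor shapes, which gives a rough feel for what the unfused pad costs. The Python below is a minimal sketch, assuming f32 elements (4 bytes) and dense layouts; `nbytes` is a hypothetical helper, and the counts model logical bytes read and written, not actual memory traffic after caching.

```python
from math import prod

F32_BYTES = 4  # assumption: f32 elements, dense row-major layout


def nbytes(shape, elem_bytes=F32_BYTES):
    """Hypothetical helper: size in bytes of a dense tensor."""
    return prod(shape) * elem_bytes


# dispatch_9 result (tensor<56x56x64xf32>): matches %c802816
producer_bytes = nbytes([56, 56, 64])
# splat + dispatch_6 output (tensor<1x58x58x64xf32>): matches %c861184
padded_bytes = nbytes([1, 58, 58, 64])

# Logical bytes touched only because the pad is not fused:
# the splat writes the whole padded buffer, then dispatch_6 reads the
# producer result and writes it into the interior of that buffer.
extra_bytes = padded_bytes + 2 * producer_bytes

print(f"producer={producer_bytes} padded={padded_bytes} extra={extra_bytes}")
```

With the pad fused into its producer or consumer, none of the `extra_bytes` traffic (or the 861184-byte transient) would be needed.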

MaheshRavishankar commented 1 year ago

I think that is addressed by setting the flag --iree-flow-enable-fuse-padding-into-linalg-consumer-ops, which should fuse the pad with its consumer. It is off by default because it only works on the CPU and SPIR-V backends, and I would also rather fuse the pad with its producers. #10184 was a draft I put up as an example for someone to pick up, but nobody has picked it up so far, so it remains on my backlog. I can close the old bug in favor of this one if that helps... but we have discussed this many times ;) I won't be able to push on this, and I'm really hoping someone can pick it up.

benvanik commented 1 year ago

Yep, just trying to increase visibility and have something to point at for people asking what's up with the bubbles in the pipeline :) It may be useful to add some notes to #10184 with what still needs to be done. The original #2783 didn't have an example - we can dedupe this against that so that whoever ends up fixing this gets the honor of closing a 2+yr old issue :)

MaheshRavishankar commented 1 year ago

I need to resurrect that PR... Will do that and add some notes.

benvanik commented 1 year ago

ESRGAN suffers from this as well and would benefit from padding propagation into consumers.

Here is an example: if the padding in %padded_923 were propagated upward, we could perform the elementwise ops in place and have the conv consume their results directly.

    // and other elementwise ops producing %189/%192/etc
    %198 = linalg.generic {indexing_maps = [#map2, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%197 : tensor<1x32x90x62xf32>) outs(%3 : tensor<1x32x90x62xf32>) {
    ^bb0(%in: f32, %out: f32):
      %1177 = arith.cmpf ugt, %in, %cst_701 : f32
      %1178 = arith.select %1177, %in, %cst_701 : f32
      %1179 = arith.select %1177, %cst_701, %in : f32
      %1180 = arith.truncf %cst_702 : f64 to f32
      %1181 = arith.mulf %1179, %1180 : f32
      %1182 = arith.addf %1178, %1181 : f32
      linalg.yield %1182 : f32
    } -> tensor<1x32x90x62xf32>
    %inserted_slice_919 = tensor.insert_slice %189 into %15[0, 0, 0, 0] [1, 64, 90, 62] [1, 1, 1, 1] : tensor<1x64x90x62xf32> into tensor<1x160x90x62xf32>
    %inserted_slice_920 = tensor.insert_slice %192 into %inserted_slice_919[0, 64, 0, 0] [1, 32, 90, 62] [1, 1, 1, 1] : tensor<1x32x90x62xf32> into tensor<1x160x90x62xf32>
    %inserted_slice_921 = tensor.insert_slice %195 into %inserted_slice_920[0, 96, 0, 0] [1, 32, 90, 62] [1, 1, 1, 1] : tensor<1x32x90x62xf32> into tensor<1x160x90x62xf32>
    %inserted_slice_922 = tensor.insert_slice %198 into %inserted_slice_921[0, 128, 0, 0] [1, 32, 90, 62] [1, 1, 1, 1] : tensor<1x32x90x62xf32> into tensor<1x160x90x62xf32>
    %padded_923 = tensor.pad %inserted_slice_922 low[0, 0, 1, 1] high[0, 0, 1, 1] {
    ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
      tensor.yield %cst_701 : f32
    } : tensor<1x160x90x62xf32> to tensor<1x160x92x64xf32>
    %199 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cst_581 : tensor<32xf32>) outs(%3 : tensor<1x32x90x62xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<1x32x90x62xf32>
    %200 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%padded_923, %cst_582 : tensor<1x160x92x64xf32>, tensor<32x160x3x3xf32>) outs(%199 : tensor<1x32x90x62xf32>) -> tensor<1x32x90x62xf32>
    %201 = linalg.generic {indexing_maps = [#map2, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%200 : tensor<1x32x90x62xf32>) outs(%3 : tensor<1x32x90x62xf32>) {
    ^bb0(%in: f32, %out: f32):
      %1177 = arith.cmpf ugt, %in, %cst_701 : f32
      %1178 = arith.select %1177, %in, %cst_701 : f32
      %1179 = arith.select %1177, %cst_701, %in : f32
      %1180 = arith.truncf %cst_702 : f64 to f32
      %1181 = arith.mulf %1179, %1180 : f32
      %1182 = arith.addf %1178, %1181 : f32
      linalg.yield %1182 : f32
    } -> tensor<1x32x90x62xf32>
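
A toy 2-D model of what propagating the pad upward would buy, in plain Python. The helpers (`leaky_relu`, `apply`, `pad2d`, `into_padded`) are hypothetical stand-ins for linalg.generic and tensor.pad, and the sketch assumes the pad value and %cst_701 are 0.0 and the negative slope is 0.2 (the dispatch for this stage uses 0.0 and 0.199999988). Writing the elementwise result straight into the interior of a pre-filled padded buffer gives the same tensor as elementwise-then-pad for any elementwise op; moving the pad above the op is additionally valid here only because leaky_relu(0.0) == 0.0.

```python
PAD = 0.0  # assumption: pad value and %cst_701 are both 0.0


def leaky_relu(x, alpha=0.2):
    # Same result as the cmpf/select/mulf/addf body: x if x > 0 else alpha * x.
    return x if x > 0.0 else alpha * x


def apply(fn, t):
    """Elementwise op on a 2-D list (stands in for linalg.generic)."""
    return [[fn(v) for v in row] for row in t]


def pad2d(t, n, value):
    """Pad all four sides of a 2-D list by n elements (stands in for tensor.pad)."""
    cols = len(t[0]) + 2 * n

    def edge():
        return [[value] * cols for _ in range(n)]

    body = [[value] * n + list(row) + [value] * n for row in t]
    return edge() + body + edge()


def into_padded(t, n, value, fn):
    """Hypothetical fused form: the producer writes fn(t) directly into the
    interior of a buffer that already holds the pad value, so no separate
    copy dispatch is needed."""
    out = [[value] * (len(t[0]) + 2 * n) for _ in range(len(t) + 2 * n)]
    for i, row in enumerate(t):
        for j, v in enumerate(row):
            out[i + n][j + n] = fn(v)
    return out


x = [[1.5, -2.0, 0.0], [-0.5, 3.0, -1.0]]

today = pad2d(apply(leaky_relu, x), 1, PAD)    # elementwise, then pad (copy)
fused = into_padded(x, 1, PAD, leaky_relu)     # elementwise written in place
hoisted = apply(leaky_relu, pad2d(x, 1, PAD))  # pad moved above the op

print(today == fused == hoisted)
```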

Today we end up with the following, which requires the %39 splat and the %41 copy dispatch:

  %32 = flow.dispatch.workgroups[%c1, %c32, %c90, %c62](%31, %cst_4) : (tensor<1x160x92x64xf32>, tensor<32x160x3x3xf32>) -> tensor<1x32x90x62xf32> =
      (%arg3: !flow.dispatch.tensor<readonly:tensor<1x160x92x64xf32>>, %arg4: !flow.dispatch.tensor<readonly:tensor<32x160x3x3xf32>>, %arg5: !flow.dispatch.tensor<writeonly:tensor<1x32x90x62xf32>>) {
    %cst_351 = arith.constant 0.000000e+00 : f32
    %cst_352 = arith.constant 0.199999988 : f32
    %cst_353 = arith.constant dense<[[-0.00735217752, -0.029075671, -0.0011687536, -0.0265800748, -0.016661156, -0.0216491632, -0.0427877456, -0.0533559099, -0.0249305591, -0.0207087267, -0.0253318828, -0.0515014119, -0.0422265045, -0.0368615724, 0.00198965892, -0.0221594162, -0.0266306344, -0.0617676973, -0.0261138938, -0.00482901605, -0.0400608778, -0.0137573751, -0.00975679792, -0.0443469957, -0.0315653086, -0.0245542042, -0.0320154652, -6.253720e-02, -0.0274252892, 0.00514560752, -0.0166819859, -0.0136556849]]> : tensor<1x32xf32>
    %2035 = flow.dispatch.tensor.load %arg3, offsets = [0, 0, 0, 0], sizes = [1, 160, 92, 64], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x160x92x64xf32>> -> tensor<1x160x92x64xf32>
    %2036 = flow.dispatch.tensor.load %arg4, offsets = [0, 0, 0, 0], sizes = [32, 160, 3, 3], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<32x160x3x3xf32>> -> tensor<32x160x3x3xf32>
    %2037 = tensor.empty() : tensor<1x32x90x62xf32>
    %2038 = linalg.fill ins(%cst_351 : f32) outs(%2037 : tensor<1x32x90x62xf32>) -> tensor<1x32x90x62xf32>
    %2039 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%2035, %2036 : tensor<1x160x92x64xf32>, tensor<32x160x3x3xf32>) outs(%2038 : tensor<1x32x90x62xf32>) -> tensor<1x32x90x62xf32>
    %2040 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%2039, %cst_353 : tensor<1x32x90x62xf32>, tensor<1x32xf32>) outs(%2037 : tensor<1x32x90x62xf32>) {
    ^bb0(%in: f32, %in_354: f32, %out: f32):
      %2041 = arith.addf %in, %in_354 : f32
      %2042 = arith.cmpf ugt, %2041, %cst_351 : f32
      %2043 = arith.select %2042, %2041, %cst_351 : f32
      %2044 = arith.select %2042, %cst_351, %2041 : f32
      %2045 = arith.mulf %2044, %cst_352 : f32
      %2046 = arith.addf %2043, %2045 : f32
      linalg.yield %2046 : f32
    } -> tensor<1x32x90x62xf32>
    flow.dispatch.tensor.store %2040, %arg5, offsets = [0, 0, 0, 0], sizes = [1, 32, 90, 62], strides = [1, 1, 1, 1] : tensor<1x32x90x62xf32> -> !flow.dispatch.tensor<writeonly:tensor<1x32x90x62xf32>>
    flow.return
  } count(%arg3: index, %arg4: index, %arg5: index, %arg6: index) -> (index, index, index) {
    %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg3, %arg4, %arg5, %arg6
    flow.return %x, %y, %z : index, index, index
  }
  %33 = tensor.empty() : tensor<1x192x90x62xf32>
  %34 = flow.tensor.update %4, %33[%c0, %c0, %c0, %c0] : tensor<1x64x90x62xf32> -> %33 as tensor<1x192x90x62xf32>
  %35 = flow.tensor.update %8, %34[%c0, %c64, %c0, %c0] : tensor<1x32x90x62xf32> -> %34 as tensor<1x192x90x62xf32>
  %36 = flow.tensor.update %15, %35[%c0, %c96, %c0, %c0] : tensor<1x32x90x62xf32> -> %35 as tensor<1x192x90x62xf32>
  %37 = flow.tensor.update %23, %36[%c0, %c128, %c0, %c0] : tensor<1x32x90x62xf32> -> %36 as tensor<1x192x90x62xf32>
  %38 = flow.tensor.update %32, %37[%c0, %c160, %c0, %c0] : tensor<1x32x90x62xf32> -> %37 as tensor<1x192x90x62xf32>
  %39 = flow.tensor.splat %cst : tensor<1x192x92x64xf32>
  %40 = flow.tensor.reshape %38 : tensor<1x192x90x62xf32> -> tensor<192x90x62xf32>
  %41 = flow.dispatch.workgroups[%c192, %c90, %c62](%40, %39) : (tensor<192x90x62xf32>, tensor<1x192x92x64xf32>) -> %39 =
      (%arg3: !flow.dispatch.tensor<readonly:tensor<192x90x62xf32>>, %arg4: !flow.dispatch.tensor<readwrite:tensor<1x192x92x64xf32>>) {
    %2035 = flow.dispatch.tensor.load %arg3, offsets = [0, 0, 0], sizes = [192, 90, 62], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<192x90x62xf32>> -> tensor<192x90x62xf32>
    flow.dispatch.tensor.store %2035, %arg4, offsets = [0, 0, 1, 1], sizes = [1, 192, 90, 62], strides = [1, 1, 1, 1] : tensor<192x90x62xf32> -> !flow.dispatch.tensor<readwrite:tensor<1x192x92x64xf32>>
    flow.return
  } count(%arg3: index, %arg4: index, %arg5: index) -> (index, index, index) {
    %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg3, %arg4, %arg5
    flow.return %x, %y, %z : index, index, index
  }
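
Sizing the two extra ops above shows how much of the splat is wasted. A minimal sketch, assuming f32 (4 bytes per element) and dense layouts: only the 1-element border of each 92x64 plane actually needs the pad value, and the rest of the memset is overwritten by the %41 copy.

```python
from math import prod

F32_BYTES = 4  # assumption: f32, dense layout

splat_elems = prod([1, 192, 92, 64])   # %39: whole padded buffer is memset
interior_elems = prod([192, 90, 62])   # %41: copy reads and writes these once
border_elems = splat_elems - interior_elems  # the only elements that need the pad value

splat_bytes = splat_elems * F32_BYTES
copy_bytes = interior_elems * F32_BYTES
wasted_fraction = interior_elems / splat_elems  # share of the memset later overwritten

print(f"splat={splat_bytes} B, copy={copy_bytes} B, "
      f"border={border_elems * F32_BYTES} B, "
      f"{100 * wasted_fraction:.1f}% of the splat is overwritten")
```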

benvanik commented 1 year ago

Looking at all the splats (from both this and #6972), we're doing 2004094976 bytes (2GB!!!!) of memset(0)s in ESRGAN.

As an example, in just the last few stages of ESRGAN we're memset(0)'ing 100MB:

[screenshot]

benvanik commented 1 year ago

I tried the --iree-flow-enable-fuse-padding-into-linalg-consumer-ops flag with ESRGAN. It got rid of the fills (1.1GB of memset(0) per invocation) and dropped transient memory usage from 74905600 bytes (74MB) to 61424640 bytes (61MB). Nice! (#7729 will drop it by another 5-10MB.) I haven't actually benchmarked the model yet, but it's nice to know how much overhead padding adds here.
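For scale, the transient-memory numbers above work out as follows. This is simple arithmetic on the reported byte counts, with MB meaning 10^6 bytes.

```python
before = 74_905_600  # transient bytes, default flags
after = 61_424_640   # transient bytes with the pad-into-consumer fusion flag

saved = before - after
pct = 100.0 * saved / before
print(f"saved {saved} bytes ({saved / 1e6:.1f} MB, {pct:.1f}% of transients)")
```

That is roughly an 18% reduction in peak transient memory from removing the pad splats and copies alone.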