iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0
2.51k stars 558 forks source link

Fuse multi-consumer insert_slices into dispatch regions as in-place operations. #11102

Open benvanik opened 1 year ago

benvanik commented 1 year ago

Certain models use a pattern where they produce a result and then insert it into multiple tensors. Today these inserts are lowered to transient allocations plus DMA copies, but for these basic tensor broadcast operations it would be better for the producer dispatch to perform multiple writes. Doing this would reduce transient memory requirements (the N copies of the output would be written directly into their destinations, eliminating the transient allocation) and break the serialized dependency chain (subsequent dispatches could begin executing earlier because they would not need to wait for the copies to complete).

This may be related to #10840 (multi-result fusion).

From ESRGAN, note that %189 is produced and then inserted into 4 different tensors:

    // ESRGAN excerpt: %189 is computed once and then inserted at channel offset 0
    // into four different destination tensors (%7, %11, %15, %19). Each destination
    // also receives the later activations at increasing channel offsets:
    //   %7  (1x96):  %189@0, %192@64
    //   %11 (1x128): %189@0, %192@64, %195@96
    //   %15 (1x160): %189@0, %192@64, %195@96, %198@128
    //   %19 (1x192): %189@0, %192@64, %195@96, %198@128, %201@160
    // %189: elementwise sum of %188 and %173.
    %189 = linalg.generic {indexing_maps = [#map2, #map2, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%188, %173 : tensor<1x64x90x62xf32>, tensor<1x64x90x62xf32>) outs(%0 : tensor<1x64x90x62xf32>) {
    ^bb0(%in: f32, %in_2018: f32, %out: f32):
      %1177 = arith.addf %in, %in_2018 : f32
      linalg.yield %1177 : f32
    } -> tensor<1x64x90x62xf32>
    // Pad H and W by 1 on each side with %cst_701 (value not shown; presumably 0.0)
    // so the 3x3 conv below produces a 90x62 output.
    %padded_911 = tensor.pad %189 low[0, 0, 1, 1] high[0, 0, 1, 1] {
    ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
      tensor.yield %cst_701 : f32
    } : tensor<1x64x90x62xf32> to tensor<1x64x92x64xf32>
    // Conv accumulator initialized from %cst_587 (tensor<32xf32>); presumably a
    // per-output-channel bias broadcast (#map is not shown here).
    %190 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cst_587 : tensor<32xf32>) outs(%3 : tensor<1x32x90x62xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<1x32x90x62xf32>
    // 3x3 conv: 64 input channels -> 32 output channels.
    %191 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%padded_911, %cst_588 : tensor<1x64x92x64xf32>, tensor<32x64x3x3xf32>) outs(%190 : tensor<1x32x90x62xf32>) -> tensor<1x32x90x62xf32>
    // Activation: where %in > %cst_701 yields %in + %cst_701 * slope, otherwise
    // %cst_701 + %in * slope (slope = truncf(%cst_702)). NOTE(review): this is a
    // leaky ReLU if %cst_701 == 0.0; constant values are not shown here.
    %192 = linalg.generic {indexing_maps = [#map2, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%191 : tensor<1x32x90x62xf32>) outs(%3 : tensor<1x32x90x62xf32>) {
    ^bb0(%in: f32, %out: f32):
      %1177 = arith.cmpf ugt, %in, %cst_701 : f32
      %1178 = arith.select %1177, %in, %cst_701 : f32
      %1179 = arith.select %1177, %cst_701, %in : f32
      %1180 = arith.truncf %cst_702 : f64 to f32
      %1181 = arith.mulf %1179, %1180 : f32
      %1182 = arith.addf %1178, %1181 : f32
      linalg.yield %1182 : f32
    } -> tensor<1x32x90x62xf32>
    // Concatenation #1 (1x96): first insertion of %189, into %7.
    %inserted_slice_912 = tensor.insert_slice %189 into %7[0, 0, 0, 0] [1, 64, 90, 62] [1, 1, 1, 1] : tensor<1x64x90x62xf32> into tensor<1x96x90x62xf32>
    %inserted_slice_913 = tensor.insert_slice %192 into %inserted_slice_912[0, 64, 0, 0] [1, 32, 90, 62] [1, 1, 1, 1] : tensor<1x32x90x62xf32> into tensor<1x96x90x62xf32>
    %padded_914 = tensor.pad %inserted_slice_913 low[0, 0, 1, 1] high[0, 0, 1, 1] {
    ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
      tensor.yield %cst_701 : f32
    } : tensor<1x96x90x62xf32> to tensor<1x96x92x64xf32>
    %193 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cst_585 : tensor<32xf32>) outs(%3 : tensor<1x32x90x62xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<1x32x90x62xf32>
    // 3x3 conv: 96 -> 32 channels.
    %194 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%padded_914, %cst_586 : tensor<1x96x92x64xf32>, tensor<32x96x3x3xf32>) outs(%193 : tensor<1x32x90x62xf32>) -> tensor<1x32x90x62xf32>
    %195 = linalg.generic {indexing_maps = [#map2, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%194 : tensor<1x32x90x62xf32>) outs(%3 : tensor<1x32x90x62xf32>) {
    ^bb0(%in: f32, %out: f32):
      %1177 = arith.cmpf ugt, %in, %cst_701 : f32
      %1178 = arith.select %1177, %in, %cst_701 : f32
      %1179 = arith.select %1177, %cst_701, %in : f32
      %1180 = arith.truncf %cst_702 : f64 to f32
      %1181 = arith.mulf %1179, %1180 : f32
      %1182 = arith.addf %1178, %1181 : f32
      linalg.yield %1182 : f32
    } -> tensor<1x32x90x62xf32>
    // Concatenation #2 (1x128): %189 is inserted again, into %11.
    %inserted_slice_915 = tensor.insert_slice %189 into %11[0, 0, 0, 0] [1, 64, 90, 62] [1, 1, 1, 1] : tensor<1x64x90x62xf32> into tensor<1x128x90x62xf32>
    %inserted_slice_916 = tensor.insert_slice %192 into %inserted_slice_915[0, 64, 0, 0] [1, 32, 90, 62] [1, 1, 1, 1] : tensor<1x32x90x62xf32> into tensor<1x128x90x62xf32>
    %inserted_slice_917 = tensor.insert_slice %195 into %inserted_slice_916[0, 96, 0, 0] [1, 32, 90, 62] [1, 1, 1, 1] : tensor<1x32x90x62xf32> into tensor<1x128x90x62xf32>
    %padded_918 = tensor.pad %inserted_slice_917 low[0, 0, 1, 1] high[0, 0, 1, 1] {
    ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
      tensor.yield %cst_701 : f32
    } : tensor<1x128x90x62xf32> to tensor<1x128x92x64xf32>
    %196 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cst_583 : tensor<32xf32>) outs(%3 : tensor<1x32x90x62xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<1x32x90x62xf32>
    // 3x3 conv: 128 -> 32 channels.
    %197 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%padded_918, %cst_584 : tensor<1x128x92x64xf32>, tensor<32x128x3x3xf32>) outs(%196 : tensor<1x32x90x62xf32>) -> tensor<1x32x90x62xf32>
    %198 = linalg.generic {indexing_maps = [#map2, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%197 : tensor<1x32x90x62xf32>) outs(%3 : tensor<1x32x90x62xf32>) {
    ^bb0(%in: f32, %out: f32):
      %1177 = arith.cmpf ugt, %in, %cst_701 : f32
      %1178 = arith.select %1177, %in, %cst_701 : f32
      %1179 = arith.select %1177, %cst_701, %in : f32
      %1180 = arith.truncf %cst_702 : f64 to f32
      %1181 = arith.mulf %1179, %1180 : f32
      %1182 = arith.addf %1178, %1181 : f32
      linalg.yield %1182 : f32
    } -> tensor<1x32x90x62xf32>
    // Concatenation #3 (1x160): %189 inserted a third time, into %15.
    %inserted_slice_919 = tensor.insert_slice %189 into %15[0, 0, 0, 0] [1, 64, 90, 62] [1, 1, 1, 1] : tensor<1x64x90x62xf32> into tensor<1x160x90x62xf32>
    %inserted_slice_920 = tensor.insert_slice %192 into %inserted_slice_919[0, 64, 0, 0] [1, 32, 90, 62] [1, 1, 1, 1] : tensor<1x32x90x62xf32> into tensor<1x160x90x62xf32>
    %inserted_slice_921 = tensor.insert_slice %195 into %inserted_slice_920[0, 96, 0, 0] [1, 32, 90, 62] [1, 1, 1, 1] : tensor<1x32x90x62xf32> into tensor<1x160x90x62xf32>
    %inserted_slice_922 = tensor.insert_slice %198 into %inserted_slice_921[0, 128, 0, 0] [1, 32, 90, 62] [1, 1, 1, 1] : tensor<1x32x90x62xf32> into tensor<1x160x90x62xf32>
    %padded_923 = tensor.pad %inserted_slice_922 low[0, 0, 1, 1] high[0, 0, 1, 1] {
    ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
      tensor.yield %cst_701 : f32
    } : tensor<1x160x90x62xf32> to tensor<1x160x92x64xf32>
    %199 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cst_581 : tensor<32xf32>) outs(%3 : tensor<1x32x90x62xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<1x32x90x62xf32>
    // 3x3 conv: 160 -> 32 channels.
    %200 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%padded_923, %cst_582 : tensor<1x160x92x64xf32>, tensor<32x160x3x3xf32>) outs(%199 : tensor<1x32x90x62xf32>) -> tensor<1x32x90x62xf32>
    %201 = linalg.generic {indexing_maps = [#map2, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%200 : tensor<1x32x90x62xf32>) outs(%3 : tensor<1x32x90x62xf32>) {
    ^bb0(%in: f32, %out: f32):
      %1177 = arith.cmpf ugt, %in, %cst_701 : f32
      %1178 = arith.select %1177, %in, %cst_701 : f32
      %1179 = arith.select %1177, %cst_701, %in : f32
      %1180 = arith.truncf %cst_702 : f64 to f32
      %1181 = arith.mulf %1179, %1180 : f32
      %1182 = arith.addf %1178, %1181 : f32
      linalg.yield %1182 : f32
    } -> tensor<1x32x90x62xf32>
    // Concatenation #4 (1x192): fourth insertion of %189, into %19.
    %inserted_slice_924 = tensor.insert_slice %189 into %19[0, 0, 0, 0] [1, 64, 90, 62] [1, 1, 1, 1] : tensor<1x64x90x62xf32> into tensor<1x192x90x62xf32>
    %inserted_slice_925 = tensor.insert_slice %192 into %inserted_slice_924[0, 64, 0, 0] [1, 32, 90, 62] [1, 1, 1, 1] : tensor<1x32x90x62xf32> into tensor<1x192x90x62xf32>
    %inserted_slice_926 = tensor.insert_slice %195 into %inserted_slice_925[0, 96, 0, 0] [1, 32, 90, 62] [1, 1, 1, 1] : tensor<1x32x90x62xf32> into tensor<1x192x90x62xf32>
    %inserted_slice_927 = tensor.insert_slice %198 into %inserted_slice_926[0, 128, 0, 0] [1, 32, 90, 62] [1, 1, 1, 1] : tensor<1x32x90x62xf32> into tensor<1x192x90x62xf32>
    %inserted_slice_928 = tensor.insert_slice %201 into %inserted_slice_927[0, 160, 0, 0] [1, 32, 90, 62] [1, 1, 1, 1] : tensor<1x32x90x62xf32> into tensor<1x192x90x62xf32>

Since all of the target tensors (%7, %11, %15, and %19) exist before any of this work happens, we could instead have the dispatch that produces %189 insert its result directly into all 4 of them (by taking them as read/write I/O).

benvanik commented 1 year ago

The original PyTorch code, from https://github.com/nod-ai/SHARK/pull/418/files#diff-f98c3ce6646546dce80ee6d6eca9f9537efcc70a300dd8df3b0f128e8bbb4316R40-R46. Note that `x` appears in every `torch.cat`, which is where the repeated insertions come from:

    def forward(self, x):
        """Forward pass of the quoted ESRGAN block.

        Each conv consumes the channel-wise (dim 1) concatenation of the block
        input ``x`` and every activation produced so far, so ``x`` is an input
        to all four ``torch.cat`` calls. NOTE(review): these concatenations
        appear to be the source of the tensor.insert_slice chains in the IR
        excerpt earlier in this issue; confirm against the lowering.

        Returns ``x5 * 0.2 + x``, i.e. the conv5 output scaled by 0.2 plus the
        block input.
        """
        x1 = self.lrelu(self.conv1(x))
        # Every following conv sees x plus all earlier activations.
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        # The final conv has no activation applied.
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x
benvanik commented 1 year ago

A hacky idea I could put together quickly: a Flow-level pass that runs after dispatch workgroup formation, adds new results to the dispatch, and replicates the flow.dispatch.tensor.store ops so that each copy writes to one of the new results. This wouldn't generalize to conditional stores (among other cases) and may make codegen bufferization unhappy (I'm not sure how much it relies on there being a single store for allocation placement).

e.g.:

    // Before: the dispatch has a single write-only result (%arg5 inside, %1529
    // outside). The flow.tensor.update ops at the end then copy that result
    // into three different destination tensors.
    %1529 = flow.dispatch.workgroups[%c1, %c32, %c90, %c62](%1528, %cst_262) : (tensor<1x96x92x64xf32>, tensor<32x96x3x3xf32>) -> tensor<1x32x90x62xf32> =
        (%arg3: !flow.dispatch.tensor<readonly:tensor<1x96x92x64xf32>>, %arg4: !flow.dispatch.tensor<readonly:tensor<32x96x3x3xf32>>, %arg5: !flow.dispatch.tensor<writeonly:tensor<1x32x90x62xf32>>) {
      %cst_351 = arith.constant 0.000000e+00 : f32
      %cst_352 = arith.constant 0.199999988 : f32
      %cst_353 = arith.constant dense<[[-0.104562163, -0.0794978737, -0.139019474, -0.0678559243, -0.0663451776, -0.0395833179, -0.139937162, -0.0967885255, -0.119102009, -0.138187289, -0.0833081305, -0.106967404, 0.0852515175, -0.0525256135, -0.090108551, -0.036612425, -0.113223538, -0.153768227, -0.13075842, -0.066075474, -0.129493013, -0.0539637394, -0.0388106443, -0.158874199, -0.0255966205, -0.13540563, -0.0318158343, -0.0988952666, -0.110548653, -0.160866946, -0.043045409, -0.114123404]]> : tensor<1x32xf32>
      %2035 = flow.dispatch.tensor.load %arg3, offsets = [0, 0, 0, 0], sizes = [1, 96, 92, 64], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x96x92x64xf32>> -> tensor<1x96x92x64xf32>
      %2036 = flow.dispatch.tensor.load %arg4, offsets = [0, 0, 0, 0], sizes = [32, 96, 3, 3], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<32x96x3x3xf32>> -> tensor<32x96x3x3xf32>
      // Conv accumulator, zero-filled (%cst_351 = 0.0).
      %2037 = tensor.empty() : tensor<1x32x90x62xf32>
      %2038 = linalg.fill ins(%cst_351 : f32) outs(%2037 : tensor<1x32x90x62xf32>) -> tensor<1x32x90x62xf32>
      %2039 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%2035, %2036 : tensor<1x96x92x64xf32>, tensor<32x96x3x3xf32>) outs(%2038 : tensor<1x32x90x62xf32>) -> tensor<1x32x90x62xf32>
      // Adds %cst_353 (32 values; presumably broadcast per output channel via
      // #map1, which is not shown here), then applies the activation:
      // v > 0 ? v : v * 0.2.
      %2040 = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%2039, %cst_353 : tensor<1x32x90x62xf32>, tensor<1x32xf32>) outs(%2037 : tensor<1x32x90x62xf32>) {
      ^bb0(%in: f32, %in_354: f32, %out: f32):
        %2041 = arith.addf %in, %in_354 : f32
        %2042 = arith.cmpf ugt, %2041, %cst_351 : f32
        %2043 = arith.select %2042, %2041, %cst_351 : f32
        %2044 = arith.select %2042, %cst_351, %2041 : f32
        %2045 = arith.mulf %2044, %cst_352 : f32
        %2046 = arith.addf %2043, %2045 : f32
        linalg.yield %2046 : f32
      } -> tensor<1x32x90x62xf32>
      // Single store of the full result into the only output binding.
      flow.dispatch.tensor.store %2040, %arg5, offsets = [0, 0, 0, 0], sizes = [1, 32, 90, 62], strides = [1, 1, 1, 1] : tensor<1x32x90x62xf32> -> !flow.dispatch.tensor<writeonly:tensor<1x32x90x62xf32>>
      flow.return
    } count(%arg3: index, %arg4: index, %arg5: index, %arg6: index) -> (index, index, index) {
      %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg3, %arg4, %arg5, %arg6
      flow.return %x, %y, %z : index, index, index
    }
    // The same %1529 is written at channel offset 96 into three different
    // tensors (1x128, 1x160, 1x192).
    %1532 = flow.tensor.update %1529, %1531[%c0, %c96, %c0, %c0] : tensor<1x32x90x62xf32> -> %1531 as tensor<1x128x90x62xf32>
    %1538 = flow.tensor.update %1529, %1537[%c0, %c96, %c0, %c0] : tensor<1x32x90x62xf32> -> %1537 as tensor<1x160x90x62xf32>
    %1545 = flow.tensor.update %1529, %1544[%c0, %c96, %c0, %c0] : tensor<1x32x90x62xf32> -> %1544 as tensor<1x192x90x62xf32>

->

    // After (proposed): the dispatch has one write-only result per destination
    // (%arg5, %arg6, %arg7 inside; %1529#0..#2 outside). The body is unchanged
    // except that the final store is replicated once per result.
    %1529:3 = flow.dispatch.workgroups[%c1, %c32, %c90, %c62](%1528, %cst_262) : (tensor<1x96x92x64xf32>, tensor<32x96x3x3xf32>) -> (tensor<1x32x90x62xf32>, tensor<1x32x90x62xf32>, tensor<1x32x90x62xf32>) =
        (%arg3: !flow.dispatch.tensor<readonly:tensor<1x96x92x64xf32>>, %arg4: !flow.dispatch.tensor<readonly:tensor<32x96x3x3xf32>>,
        %arg5: !flow.dispatch.tensor<writeonly:tensor<1x32x90x62xf32>>, %arg6: !flow.dispatch.tensor<writeonly:tensor<1x32x90x62xf32>>, %arg7: !flow.dispatch.tensor<writeonly:tensor<1x32x90x62xf32>>) {
      %cst_351 = arith.constant 0.000000e+00 : f32
      %cst_352 = arith.constant 0.199999988 : f32
      %cst_353 = arith.constant dense<[[-0.104562163, -0.0794978737, -0.139019474, -0.0678559243, -0.0663451776, -0.0395833179, -0.139937162, -0.0967885255, -0.119102009, -0.138187289, -0.0833081305, -0.106967404, 0.0852515175, -0.0525256135, -0.090108551, -0.036612425, -0.113223538, -0.153768227, -0.13075842, -0.066075474, -0.129493013, -0.0539637394, -0.0388106443, -0.158874199, -0.0255966205, -0.13540563, -0.0318158343, -0.0988952666, -0.110548653, -0.160866946, -0.043045409, -0.114123404]]> : tensor<1x32xf32>
      %2035 = flow.dispatch.tensor.load %arg3, offsets = [0, 0, 0, 0], sizes = [1, 96, 92, 64], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x96x92x64xf32>> -> tensor<1x96x92x64xf32>
      %2036 = flow.dispatch.tensor.load %arg4, offsets = [0, 0, 0, 0], sizes = [32, 96, 3, 3], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<32x96x3x3xf32>> -> tensor<32x96x3x3xf32>
      %2037 = tensor.empty() : tensor<1x32x90x62xf32>
      %2038 = linalg.fill ins(%cst_351 : f32) outs(%2037 : tensor<1x32x90x62xf32>) -> tensor<1x32x90x62xf32>
      %2039 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%2035, %2036 : tensor<1x96x92x64xf32>, tensor<32x96x3x3xf32>) outs(%2038 : tensor<1x32x90x62xf32>) -> tensor<1x32x90x62xf32>
      %2040 = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%2039, %cst_353 : tensor<1x32x90x62xf32>, tensor<1x32xf32>) outs(%2037 : tensor<1x32x90x62xf32>) {
      ^bb0(%in: f32, %in_354: f32, %out: f32):
        %2041 = arith.addf %in, %in_354 : f32
        %2042 = arith.cmpf ugt, %2041, %cst_351 : f32
        %2043 = arith.select %2042, %2041, %cst_351 : f32
        %2044 = arith.select %2042, %cst_351, %2041 : f32
        %2045 = arith.mulf %2044, %cst_352 : f32
        %2046 = arith.addf %2043, %2045 : f32
        linalg.yield %2046 : f32
      } -> tensor<1x32x90x62xf32>
      // The same value %2040 is stored once to each new write-only result.
      flow.dispatch.tensor.store %2040, %arg5, offsets = [0, 0, 0, 0], sizes = [1, 32, 90, 62], strides = [1, 1, 1, 1] : tensor<1x32x90x62xf32> -> !flow.dispatch.tensor<writeonly:tensor<1x32x90x62xf32>>
      flow.dispatch.tensor.store %2040, %arg6, offsets = [0, 0, 0, 0], sizes = [1, 32, 90, 62], strides = [1, 1, 1, 1] : tensor<1x32x90x62xf32> -> !flow.dispatch.tensor<writeonly:tensor<1x32x90x62xf32>>
      flow.dispatch.tensor.store %2040, %arg7, offsets = [0, 0, 0, 0], sizes = [1, 32, 90, 62], strides = [1, 1, 1, 1] : tensor<1x32x90x62xf32> -> !flow.dispatch.tensor<writeonly:tensor<1x32x90x62xf32>>
      flow.return
    } count(%arg3: index, %arg4: index, %arg5: index, %arg6: index) -> (index, index, index) {
      %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg3, %arg4, %arg5, %arg6
      flow.return %x, %y, %z : index, index, index
    }
    // Each update now consumes its own dispatch result. The intent is that each
    // result can be placed in-place inside its destination tensor instead of
    // being copied.
    %1532 = flow.tensor.update %1529#0, %1531[%c0, %c96, %c0, %c0] : tensor<1x32x90x62xf32> -> %1531 as tensor<1x128x90x62xf32>
    %1538 = flow.tensor.update %1529#1, %1537[%c0, %c96, %c0, %c0] : tensor<1x32x90x62xf32> -> %1537 as tensor<1x160x90x62xf32>
    %1545 = flow.tensor.update %1529#2, %1544[%c0, %c96, %c0, %c0] : tensor<1x32x90x62xf32> -> %1544 as tensor<1x192x90x62xf32>

In this form, the optimization that turns a flow.tensor.update of a write-only dispatch result into in-place storage would be able to place all of the allocations directly in their destination tensors.