iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

Suboptimal non-outer-dim tensor concat handling in dispatch region formation. #10836

Open · benvanik opened this issue 1 year ago

benvanik commented 1 year ago

Was noticing that some concats are getting turned into very inefficient sequences of operations today (effectively gathers that get split into one dispatch per gathered input). It'd be good to survey models in the benchmark suite to see what percentage of dispatches are from concats and how they end up getting lowered. I suspect models with FFTs/complex numbers/image data may be most impacted as that's usually where interleaving happens. We do properly handle outer dimension concats (they get turned into flow.tensor.update ops).
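For contrast, here's a minimal sketch of what the outer-dim case looks like at the tensor level (%a, %b, and the shapes are made up for illustration). Because each input covers a contiguous byte range of the result, each insert can become a flow.tensor.update (a plain buffer copy) instead of a dispatch:

  // Illustrative only: dim-0 concat of 2x2 and 3x2 into 5x2.
  %empty = tensor.empty() : tensor<5x2xi32>
  // Rows [0, 2) of the result - a contiguous range.
  %t0 = tensor.insert_slice %a into %empty[0, 0] [2, 2] [1, 1] : tensor<2x2xi32> into tensor<5x2xi32>
  // Rows [2, 5) of the result - also contiguous, so both inserts can
  // lower to flow.tensor.update ops rather than dispatches.
  %t1 = tensor.insert_slice %b into %t0[2, 0] [3, 2] [1, 1] : tensor<3x2xi32> into tensor<5x2xi32>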

I wanted to get this tracked as it's something that won't stand out as a single slow dispatch in a % breakdown but will instead add to the total dispatch count and latency. Models that have high latency, low utilization, and high dispatch counts should be checked to see if this is a cause. It'll be worse on GPUs where launch overhead is higher. Maybe the transform dialect will solve this - it'd be worth checking how these ops lower on that path.

Example from the xla_ops/concatenate.mlir test showing this:

%c0 = util.unfoldable_constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
%c1 = util.unfoldable_constant dense<[[5, 6, 7], [8, 9, 10]]> : tensor<2x3xi32>
%c2 = util.unfoldable_constant dense<[[11, 12], [13, 14]]> : tensor<2x2xi32>
%2 = "mhlo.concatenate"(%c0, %c1, %c2) {dimension = 1} : (tensor<2x2xi32>, tensor<2x3xi32>, tensor<2x2xi32>) -> tensor<2x7xi32>
check.expect_eq_const(%2, dense<[[1, 2, 5, 6, 7, 11, 12], [3, 4, 8, 9, 10, 13, 14]]> : tensor<2x7xi32>) : tensor<2x7xi32>

-> (after lowering the concat to tensor.insert_slice ops)

%3 = tensor.empty() : tensor<2x7xi32>
%inserted_slice = tensor.insert_slice %0 into %3[0, 0] [2, 2] [1, 1] : tensor<2x2xi32> into tensor<2x7xi32>
%inserted_slice_3 = tensor.insert_slice %1 into %inserted_slice[0, 2] [2, 3] [1, 1] : tensor<2x3xi32> into tensor<2x7xi32>
%inserted_slice_4 = tensor.insert_slice %2 into %inserted_slice_3[0, 5] [2, 2] [1, 1] : tensor<2x2xi32> into tensor<2x7xi32>

-> (after dispatch region formation)

  %3 = tensor.empty() : tensor<2x7xi32>
  %4 = flow.dispatch.workgroups[%c2, %c2](%0, %3) : (tensor<2x2xi32>, tensor<2x7xi32>) -> %3 =
      (%arg0: !flow.dispatch.tensor<readonly:2x2xi32>, %arg1: !flow.dispatch.tensor<readwrite:2x7xi32>) {
    %7 = flow.dispatch.tensor.load %arg0, offsets = [0, 0], sizes = [2, 2], strides = [1, 1] : !flow.dispatch.tensor<readonly:2x2xi32> -> tensor<2x2xi32>
    flow.dispatch.tensor.store %7, %arg1, offsets = [0, 0], sizes = [2, 2], strides = [1, 1] : tensor<2x2xi32> -> !flow.dispatch.tensor<readwrite:2x7xi32>
    flow.return
  } count(%arg0: index, %arg1: index) -> (index, index, index) {
    %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0, %arg1
    flow.return %x, %y, %z : index, index, index
  }
  %5 = flow.dispatch.workgroups[%c2, %c3](%1, %4) : (tensor<2x3xi32>, tensor<2x7xi32>) -> %4 =
      (%arg0: !flow.dispatch.tensor<readonly:2x3xi32>, %arg1: !flow.dispatch.tensor<readwrite:2x7xi32>) {
    %7 = flow.dispatch.tensor.load %arg0, offsets = [0, 0], sizes = [2, 3], strides = [1, 1] : !flow.dispatch.tensor<readonly:2x3xi32> -> tensor<2x3xi32>
    flow.dispatch.tensor.store %7, %arg1, offsets = [0, 2], sizes = [2, 3], strides = [1, 1] : tensor<2x3xi32> -> !flow.dispatch.tensor<readwrite:2x7xi32>
    flow.return
  } count(%arg0: index, %arg1: index) -> (index, index, index) {
    %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0, %arg1
    flow.return %x, %y, %z : index, index, index
  }
  %6 = flow.dispatch.workgroups[%c2, %c2](%2, %5) : (tensor<2x2xi32>, tensor<2x7xi32>) -> %5 =
      (%arg0: !flow.dispatch.tensor<readonly:2x2xi32>, %arg1: !flow.dispatch.tensor<readwrite:2x7xi32>) {
    %7 = flow.dispatch.tensor.load %arg0, offsets = [0, 0], sizes = [2, 2], strides = [1, 1] : !flow.dispatch.tensor<readonly:2x2xi32> -> tensor<2x2xi32>
    flow.dispatch.tensor.store %7, %arg1, offsets = [0, 5], sizes = [2, 2], strides = [1, 1] : tensor<2x2xi32> -> !flow.dispatch.tensor<readwrite:2x7xi32>
    flow.return
  } count(%arg0: index, %arg1: index) -> (index, index, index) {
    %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0, %arg1
    flow.return %x, %y, %z : index, index, index
  }

-> (after conversion to the stream dialect)

    %13 = stream.cmd.execute with(%4 as %arg0: !stream.resource<constant>{%9}, %6 as %arg1: !stream.resource<constant>{%10}, %8 as %arg2: !stream.resource<constant>{%11}, %12 as %arg3: !stream.resource<external>{%c56}) {
      stream.cmd.fill %c0_i8, %arg3[%c0 for %c56] : i8 -> !stream.resource<external>{%c56}
      stream.cmd.dispatch @_xla_concatenate_dispatch_0::@_xla_concatenate_dispatch_0[%c2, %c2] {
        ro %arg0[%c0 for %9] : !stream.resource<constant>{%9},
        rw %arg3[%c0 for %c56] : !stream.resource<external>{%c56}
      }
      stream.cmd.dispatch @_xla_concatenate_dispatch_1::@_xla_concatenate_dispatch_1[%c2, %c3] {
        ro %arg1[%c0 for %10] : !stream.resource<constant>{%10},
        rw %arg3[%c0 for %c56] : !stream.resource<external>{%c56}
      }
      stream.cmd.dispatch @_xla_concatenate_dispatch_2::@_xla_concatenate_dispatch_2[%c2, %c2] {
        ro %arg2[%c0 for %11] : !stream.resource<constant>{%11},
        rw %arg3[%c0 for %c56] : !stream.resource<external>{%c56}
      }
    } => !stream.timepoint

Because of the way the concat is turned into dispatches, we end up doing 4 sequential operations (with full barriers between them): first filling the whole result buffer with bytes that are immediately wasted, then writing over it multiple times.

The hope would be that these inserts fuse into producers but they don't today:

  %c0 = util.unfoldable_constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
  %c1 = util.unfoldable_constant dense<[[5, 6, 7], [8, 9, 10]]> : tensor<2x3xi32>
  %c2 = util.unfoldable_constant dense<[[11, 12], [13, 14]]> : tensor<2x2xi32>
  %c0a = "mhlo.add"(%c0, %c0) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
  %c1a = "mhlo.add"(%c1, %c1) : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
  %c2a = "mhlo.add"(%c2, %c2) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
  %2 = "mhlo.concatenate"(%c0a, %c1a, %c2a) {dimension = 1} : (tensor<2x2xi32>, tensor<2x3xi32>, tensor<2x2xi32>) -> tensor<2x7xi32>
  check.expect_eq_const(%2, dense<[[1, 2, 5, 6, 7, 11, 12], [3, 4, 8, 9, 10, 13, 14]]> : tensor<2x7xi32>) : tensor<2x7xi32>

-> (after dispatch region formation)

    %c0_i32 = arith.constant 0 : i32
    %c2 = arith.constant 2 : index
    %c3 = arith.constant 3 : index
    %cst = arith.constant dense<[[1, 2, 5, 6, 7, 11, 12], [3, 4, 8, 9, 10, 13, 14]]> : tensor<2x7xi32>
    %cst_0 = arith.constant dense<[[11, 12], [13, 14]]> : tensor<2x2xi32>
    %cst_1 = arith.constant dense<[[5, 6, 7], [8, 9, 10]]> : tensor<2x3xi32>
    %cst_2 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
    %0 = util.do_not_optimize(%cst_2) : tensor<2x2xi32>
    %1 = util.do_not_optimize(%cst_1) : tensor<2x3xi32>
    %2 = util.do_not_optimize(%cst_0) : tensor<2x2xi32>
    %3 = flow.dispatch.workgroups[%c2, %c2](%0) : (tensor<2x2xi32>) -> tensor<2x2xi32> =
        (%arg0: !flow.dispatch.tensor<readonly:2x2xi32>, %arg1: !flow.dispatch.tensor<writeonly:2x2xi32>) {
      %10 = flow.dispatch.tensor.load %arg0, offsets = [0, 0], sizes = [2, 2], strides = [1, 1] : !flow.dispatch.tensor<readonly:2x2xi32> -> tensor<2x2xi32>
      %11 = tensor.empty() : tensor<2x2xi32>
      %12 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%10 : tensor<2x2xi32>) outs(%11 : tensor<2x2xi32>) {
      ^bb0(%in: i32, %out: i32):
        %13 = arith.addi %in, %in : i32
        linalg.yield %13 : i32
      } -> tensor<2x2xi32>
      flow.dispatch.tensor.store %12, %arg1, offsets = [0, 0], sizes = [2, 2], strides = [1, 1] : tensor<2x2xi32> -> !flow.dispatch.tensor<writeonly:2x2xi32>
      flow.return
    } count(%arg0: index, %arg1: index) -> (index, index, index) {
      %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0, %arg1
      flow.return %x, %y, %z : index, index, index
    }
    %4 = flow.dispatch.workgroups[%c2, %c3](%1) : (tensor<2x3xi32>) -> tensor<2x3xi32> =
        (%arg0: !flow.dispatch.tensor<readonly:2x3xi32>, %arg1: !flow.dispatch.tensor<writeonly:2x3xi32>) {
      %10 = flow.dispatch.tensor.load %arg0, offsets = [0, 0], sizes = [2, 3], strides = [1, 1] : !flow.dispatch.tensor<readonly:2x3xi32> -> tensor<2x3xi32>
      %11 = tensor.empty() : tensor<2x3xi32>
      %12 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%10 : tensor<2x3xi32>) outs(%11 : tensor<2x3xi32>) {
      ^bb0(%in: i32, %out: i32):
        %13 = arith.addi %in, %in : i32
        linalg.yield %13 : i32
      } -> tensor<2x3xi32>
      flow.dispatch.tensor.store %12, %arg1, offsets = [0, 0], sizes = [2, 3], strides = [1, 1] : tensor<2x3xi32> -> !flow.dispatch.tensor<writeonly:2x3xi32>
      flow.return
    } count(%arg0: index, %arg1: index) -> (index, index, index) {
      %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0, %arg1
      flow.return %x, %y, %z : index, index, index
    }
    %5 = flow.dispatch.workgroups[%c2, %c2](%2) : (tensor<2x2xi32>) -> tensor<2x2xi32> =
        (%arg0: !flow.dispatch.tensor<readonly:2x2xi32>, %arg1: !flow.dispatch.tensor<writeonly:2x2xi32>) {
      %10 = flow.dispatch.tensor.load %arg0, offsets = [0, 0], sizes = [2, 2], strides = [1, 1] : !flow.dispatch.tensor<readonly:2x2xi32> -> tensor<2x2xi32>
      %11 = tensor.empty() : tensor<2x2xi32>
      %12 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%10 : tensor<2x2xi32>) outs(%11 : tensor<2x2xi32>) {
      ^bb0(%in: i32, %out: i32):
        %13 = arith.addi %in, %in : i32
        linalg.yield %13 : i32
      } -> tensor<2x2xi32>
      flow.dispatch.tensor.store %12, %arg1, offsets = [0, 0], sizes = [2, 2], strides = [1, 1] : tensor<2x2xi32> -> !flow.dispatch.tensor<writeonly:2x2xi32>
      flow.return
    } count(%arg0: index, %arg1: index) -> (index, index, index) {
      %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0, %arg1
      flow.return %x, %y, %z : index, index, index
    }
    %6 = flow.tensor.splat %c0_i32 : tensor<2x7xi32>
    %7 = flow.dispatch.workgroups[%c2, %c2](%3, %6) : (tensor<2x2xi32>, tensor<2x7xi32>) -> %6 =
        (%arg0: !flow.dispatch.tensor<readonly:2x2xi32>, %arg1: !flow.dispatch.tensor<readwrite:2x7xi32>) {
      %10 = flow.dispatch.tensor.load %arg0, offsets = [0, 0], sizes = [2, 2], strides = [1, 1] : !flow.dispatch.tensor<readonly:2x2xi32> -> tensor<2x2xi32>
      flow.dispatch.tensor.store %10, %arg1, offsets = [0, 0], sizes = [2, 2], strides = [1, 1] : tensor<2x2xi32> -> !flow.dispatch.tensor<readwrite:2x7xi32>
      flow.return
    } count(%arg0: index, %arg1: index) -> (index, index, index) {
      %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0, %arg1
      flow.return %x, %y, %z : index, index, index
    }
    %8 = flow.dispatch.workgroups[%c2, %c3](%4, %7) : (tensor<2x3xi32>, tensor<2x7xi32>) -> %7 =
        (%arg0: !flow.dispatch.tensor<readonly:2x3xi32>, %arg1: !flow.dispatch.tensor<readwrite:2x7xi32>) {
      %10 = flow.dispatch.tensor.load %arg0, offsets = [0, 0], sizes = [2, 3], strides = [1, 1] : !flow.dispatch.tensor<readonly:2x3xi32> -> tensor<2x3xi32>
      flow.dispatch.tensor.store %10, %arg1, offsets = [0, 2], sizes = [2, 3], strides = [1, 1] : tensor<2x3xi32> -> !flow.dispatch.tensor<readwrite:2x7xi32>
      flow.return
    } count(%arg0: index, %arg1: index) -> (index, index, index) {
      %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0, %arg1
      flow.return %x, %y, %z : index, index, index
    }
    %9 = flow.dispatch.workgroups[%c2, %c2](%5, %8) : (tensor<2x2xi32>, tensor<2x7xi32>) -> %8 =
        (%arg0: !flow.dispatch.tensor<readonly:2x2xi32>, %arg1: !flow.dispatch.tensor<readwrite:2x7xi32>) {
      %10 = flow.dispatch.tensor.load %arg0, offsets = [0, 0], sizes = [2, 2], strides = [1, 1] : !flow.dispatch.tensor<readonly:2x2xi32> -> tensor<2x2xi32>
      flow.dispatch.tensor.store %10, %arg1, offsets = [0, 5], sizes = [2, 2], strides = [1, 1] : tensor<2x2xi32> -> !flow.dispatch.tensor<readwrite:2x7xi32>
      flow.return
    } count(%arg0: index, %arg1: index) -> (index, index, index) {
      %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0, %arg1
      flow.return %x, %y, %z : index, index, index
    }
    check.expect_eq(%9, %cst) : tensor<2x7xi32>

-> (after conversion to the stream dialect)

    %13 = stream.cmd.execute await(%result_timepoint) => with(%4 as %arg0: !stream.resource<constant>{%9}, %6 as %arg1: !stream.resource<constant>{%10}, %8 as %arg2: !stream.resource<constant>{%11}, %12 as %arg3: !stream.resource<external>{%c56}, %result as %arg4: !stream.resource<transient>{%c128}) {
      stream.cmd.concurrent {
        stream.cmd.dispatch @_xla_concatenate_dispatch_0::@_xla_concatenate_dispatch_0_generic_2x2[%c2, %c2] {
          ro %arg0[%c0 for %9] : !stream.resource<constant>{%9},
          wo %arg4[%c0 for %c128] : !stream.resource<transient>{%c128}
        }
        stream.cmd.fill %c0_i8, %arg3[%c0 for %c56] : i8 -> !stream.resource<external>{%c56}
      }
      stream.cmd.concurrent {
        stream.cmd.dispatch @_xla_concatenate_dispatch_1::@_xla_concatenate_dispatch_1_generic_2x3[%c2, %c3] {
          ro %arg1[%c0 for %10] : !stream.resource<constant>{%10},
          wo %arg4[%c0 for %c128] : !stream.resource<transient>{%c128}
        }
        stream.cmd.dispatch @_xla_concatenate_dispatch_3::@_xla_concatenate_dispatch_3[%c2, %c2] {
          ro %arg4[%c0 for %c128] : !stream.resource<transient>{%c128},
          rw %arg3[%c0 for %c56] : !stream.resource<external>{%c56}
        }
      }
      stream.cmd.concurrent {
        stream.cmd.dispatch @_xla_concatenate_dispatch_0::@_xla_concatenate_dispatch_0_generic_2x2[%c2, %c2] {
          ro %arg2[%c0 for %11] : !stream.resource<constant>{%11},
          wo %arg4[%c0 for %c128] : !stream.resource<transient>{%c128}
        }
        stream.cmd.dispatch @_xla_concatenate_dispatch_4::@_xla_concatenate_dispatch_4[%c2, %c3] {
          ro %arg4[%c0 for %c128] : !stream.resource<transient>{%c128},
          rw %arg3[%c0 for %c56] : !stream.resource<external>{%c56}
        }
      }
      stream.cmd.dispatch @_xla_concatenate_dispatch_5::@_xla_concatenate_dispatch_5[%c2, %c2] {
        ro %arg4[%c0 for %c128] : !stream.resource<transient>{%c128},
        rw %arg3[%c0 for %c56] : !stream.resource<external>{%c56}
      }
    } => !stream.timepoint

So that's 6 dispatches and a fill with marginal concurrency for what should really be a single dispatch. In the past I've seen models that concatenate dozens of values and for us that'd explode to hundreds of really thin memcpy-like dispatches.

This case is a good one for algorithmic improvements at the linalg level: if we added a transpose on each concatenated value and transposed the result back, so that we were always doing non-interleaved (outer-dim) concats, we could propagate/fuse everything away. That's definitely a kind of optimization we'll want to do, but it's orthogonal to handling insert slices better.
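A minimal sketch of that idea on the example above, using the named linalg.transpose op for brevity (illustrative only - in practice this would be a rewrite pattern, not hand-written IR):

  // Transpose each operand so the concat dimension becomes the outer one.
  %e0 = tensor.empty() : tensor<2x2xi32>
  %a_t = linalg.transpose ins(%c0a : tensor<2x2xi32>) outs(%e0 : tensor<2x2xi32>) permutation = [1, 0]
  %e1 = tensor.empty() : tensor<3x2xi32>
  %b_t = linalg.transpose ins(%c1a : tensor<2x3xi32>) outs(%e1 : tensor<3x2xi32>) permutation = [1, 0]
  %e2 = tensor.empty() : tensor<2x2xi32>
  %c_t = linalg.transpose ins(%c2a : tensor<2x2xi32>) outs(%e2 : tensor<2x2xi32>) permutation = [1, 0]
  // Now the concat is outer-dim: every insert covers a contiguous range.
  %e3 = tensor.empty() : tensor<7x2xi32>
  %s0 = tensor.insert_slice %a_t into %e3[0, 0] [2, 2] [1, 1] : tensor<2x2xi32> into tensor<7x2xi32>
  %s1 = tensor.insert_slice %b_t into %s0[2, 0] [3, 2] [1, 1] : tensor<3x2xi32> into tensor<7x2xi32>
  %s2 = tensor.insert_slice %c_t into %s1[5, 0] [2, 2] [1, 1] : tensor<2x2xi32> into tensor<7x2xi32>
  // Transpose the result back; the transposes can then fuse into the
  // element-wise producers/consumers and be propagated away.
  %e4 = tensor.empty() : tensor<2x7xi32>
  %result = linalg.transpose ins(%s2 : tensor<7x2xi32>) outs(%e4 : tensor<2x7xi32>) permutation = [1, 0]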

mariecwhite commented 1 year ago

Some models that include concats at each block:

SqueezeNet: https://tfhub.dev/tensorflow/lite-model/squeezenet/1/default/1
SpaghettiNet: https://tfhub.dev/iree/lite-model/ssd_spaghettinet_edgetpu_large_320/fp32/default/1

What data were you looking to gather? I can look into profiling them with Tracy or providing the mlir files, etc.

benvanik commented 1 year ago

Hmm, good question :) Maybe we could do a regex scan over the mhlo/tosa intermediates - since we're looking for non-zero concat dimensions, these patterns should show whether a model in intermediate form is likely to trigger these cases:

  tosa\.concat.+axis = [1-9]+
  mhlo\.concatenate.+dimension = [1-9]+

mariecwhite commented 1 year ago

For SpaghettiNet, we have:

  %312 = "tosa.concat"(%305, %308, %311) {axis = 3 : i64} : (tensor<1x80x80x64xf32>, tensor<1x80x80x64xf32>, tensor<1x80x80x64xf32>) -> tensor<1x80x80x192xf32>
  %326 = "tosa.concat"(%316, %319, %322, %325) {axis = 3 : i64} : (tensor<1x40x40x96xf32>, tensor<1x40x40x96xf32>, tensor<1x40x40x96xf32>, tensor<1x40x40x96xf32>) -> tensor<1x40x40x384xf32>
  %340 = "tosa.concat"(%333, %336, %339, %330) {axis = 3 : i64} : (tensor<1x40x40x128xf32>, tensor<1x40x40x128xf32>, tensor<1x40x40x128xf32>, tensor<1x40x40x128xf32>) -> tensor<1x40x40x512xf32>
  %355 = "tosa.concat"(%345, %348, %351, %354) {axis = 3 : i64} : (tensor<1x40x40x128xf32>, tensor<1x40x40x128xf32>, tensor<1x40x40x128xf32>, tensor<1x40x40x128xf32>) -> tensor<1x40x40x512xf32>
  %370 = "tosa.concat"(%360, %363, %366, %369) {axis = 3 : i64} : (tensor<1x40x40x64xf32>, tensor<1x40x40x64xf32>, tensor<1x40x40x64xf32>, tensor<1x40x40x64xf32>) -> tensor<1x40x40x256xf32>
  %658 = "tosa.concat"(%522, %564, %612, %638, %657) {axis = 1 : i64} : (tensor<1x1200x1x4xf32>, tensor<1x600x1x4xf32>, tensor<1x150x1x4xf32>, tensor<1x54x1x4xf32>, tensor<1x54x1x4xf32>) -> tensor<1x2058x1x4xf32>
  %665 = "tosa.concat"(%527, %569, %617, %643, %664) {axis = 1 : i64} : (tensor<1x1200x91xf32>, tensor<1x600x91xf32>, tensor<1x150x91xf32>, tensor<1x54x91xf32>, tensor<1x54x91xf32>) -> tensor<1x2058x91xf32>

For SqueezeNet, we have:

  %61 = "tosa.concat"(%58, %60) {axis = 3 : i64} : (tensor<1x55x55x64xf32>, tensor<1x55x55x64xf32>) -> tensor<1x55x55x128xf32>
  %68 = "tosa.concat"(%65, %67) {axis = 3 : i64} : (tensor<1x55x55x64xf32>, tensor<1x55x55x64xf32>) -> tensor<1x55x55x128xf32>
  %75 = "tosa.concat"(%72, %74) {axis = 3 : i64} : (tensor<1x55x55x128xf32>, tensor<1x55x55x128xf32>) -> tensor<1x55x55x256xf32>
  %83 = "tosa.concat"(%80, %82) {axis = 3 : i64} : (tensor<1x27x27x128xf32>, tensor<1x27x27x128xf32>) -> tensor<1x27x27x256xf32>
  %90 = "tosa.concat"(%87, %89) {axis = 3 : i64} : (tensor<1x27x27x192xf32>, tensor<1x27x27x192xf32>) -> tensor<1x27x27x384xf32>
  %97 = "tosa.concat"(%94, %96) {axis = 3 : i64} : (tensor<1x27x27x192xf32>, tensor<1x27x27x192xf32>) -> tensor<1x27x27x384xf32>
  %104 = "tosa.concat"(%101, %103) {axis = 3 : i64} : (tensor<1x27x27x256xf32>, tensor<1x27x27x256xf32>) -> tensor<1x27x27x512xf32>
  %112 = "tosa.concat"(%109, %111) {axis = 3 : i64} : (tensor<1x13x13x256xf32>, tensor<1x13x13x256xf32>) -> tensor<1x13x13x512xf32>

benvanik commented 1 year ago

Cool - that's really useful! The first one is likely to be 7 memcpys and 22 serialized, slow, emulated-memcpy dispatches instead of 0, plus the transient memory required. Probably worth about 100-300us of latency on a GPU.

chongxing commented 1 year ago

Hi, @benvanik

"We do properly handle outer dimension concats (they get turned into flow.tensor.update ops)."

I have tried an outer-dim concat and it successfully turns into flow.tensor.update. The flow.tensor.update is eventually lowered to hal.command_buffer.copy_buffer, which becomes a memcpyD2D in CUDA. The explicit dispatch kernel for the concat is eliminated, but an implicit memcpyD2D kernel is still introduced.

Could we eliminate the copy_buffer? The concat sources could just reuse (write directly into) the concat destination buffer. Thanks.

My experiment case:

  func.func @add_concat() {
    %lhs = util.unfoldable_constant dense<1.0> : tensor<2x2xf32>
    %rhs = util.unfoldable_constant dense<1.0> : tensor<2x2xf32>
    %lhs2 = util.unfoldable_constant dense<1.0> : tensor<3x2xf32>
    %rhs2 = util.unfoldable_constant dense<1.0> : tensor<3x2xf32>
    %lhs3 = util.unfoldable_constant dense<1.0> : tensor<5x2xf32>
    %0 = mhlo.add %lhs, %rhs : tensor<2x2xf32>
    %1 = mhlo.add %lhs2, %rhs2 : tensor<3x2xf32>
    %2 = "mhlo.concatenate"(%0, %1) {dimension = 0 : i64} : (tensor<2x2xf32>, tensor<3x2xf32>) -> tensor<5x2xf32>
    %3 = mhlo.add %lhs3, %2 : tensor<5x2xf32>
    check.expect_almost_eq_const(%3, dense<3.0> : tensor<5x2xf32>) : tensor<5x2xf32>
    return
  }

benvanik commented 1 year ago

Nice! It's possible to eliminate the copy - it's an old TODO of mine that I should finally finish: #7729 Right now we even serialize those copies so it's pretty much as bad as you can get (well, still better than making a dispatch at least). I'll take a look at it ~next week!

chongxing commented 1 year ago

> Nice! It's possible to eliminate the copy - it's an old TODO of mine that I should finally finish: #7729 Right now we even serialize those copies so it's pretty much as bad as you can get (well, still better than making a dispatch at least). I'll take a look at it ~next week!

That's great. Looking forward to your update!

chongxing commented 1 year ago

Hi, @benvanik

I have seen that the copy is eliminated for outer-dim concat. Thanks for your work.

And for inner-dim concat, instead of doing a transpose to make it outer-dim (which requires an extra transpose kernel), have you thought about directly eliminating the inner-dim concat? Conceptually, it requires that the source ops of the concat do strided writes into the concat target buffer. To achieve this, the target buffer would have to be described as strided, and codegen would use this information to generate the memory writes. There are likely more challenges. Is it feasible?

Thanks.
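To make the idea concrete, here is a rough memref-level sketch of a strided destination view (purely illustrative - not how IREE codegen is structured today, and the function and value names are made up):

  // Illustrative: view columns [2, 5) of a 2x7 concat result as a strided
  // 2x3 buffer so a producer can write its slice in place.
  func.func @write_interleaved(%src: memref<2x3xi32>, %dst_full: memref<2x7xi32>) {
    %dst = memref.subview %dst_full[0, 2] [2, 3] [1, 1]
        : memref<2x7xi32> to memref<2x3xi32, strided<[7, 1], offset: 2>>
    // The producer stores straight into the strided view; no separate
    // concat copy kernel is needed.
    memref.copy %src, %dst : memref<2x3xi32> to memref<2x3xi32, strided<[7, 1], offset: 2>>
    return
  }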