iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0
2.86k stars 624 forks source link

Compiling ASR frontend with VMVX causes 'util.buffer.store' complex operand error #11002

Open phoenix-meadowlark opened 2 years ago

phoenix-meadowlark commented 2 years ago

What happened?

Compiling the ASR frontend with VMVX causes the following error:

/tmp/iree/libri/compute_frontend.mlir:89:10: error:
  'util.buffer.store' op operand #0 must be index or integer or floating-point, but got 'complex<f32>'
    %0 = "mhlo.fft"(%arg0) {fft_length = dense<256> : tensor<1xi64>, fft_type = #mhlo<fft_type RFFT>} : (tensor<1x10x1x256xf32>) -> tensor<1x10x1x129xcomplex<f32>>
         ^

(This has a lower priority than the other issues I've been posting since it only affects VMVX).

Steps to reproduce your issue

Simplified MLIR (no error is raised if the mhlo.convolution is elided in the Python):

module @jit_compute_frontend {
  func.func public @main(%arg0: tensor<1x1600x1xf32>) -> tensor<1x10x129x1xf32> {
    %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<160x160xi32>
    %1 = mhlo.constant dense<0> : tensor<i32>
    %2 = "mhlo.broadcast_in_dim"(%1) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<160x160xi32>
    %3 = mhlo.add %0, %2 : tensor<160x160xi32>
    %4 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<160x160xi32>
    %5 = mhlo.compare  EQ, %3, %4,  SIGNED : (tensor<160x160xi32>, tensor<160x160xi32>) -> tensor<160x160xi1>
    %6 = mhlo.convert %5 : (tensor<160x160xi1>) -> tensor<160x160xf32>
    %7 = mhlo.reshape %6 : (tensor<160x160xf32>) -> tensor<160x1x160xf32>
    %8 = mhlo.reshape %7 : (tensor<160x1x160xf32>) -> tensor<1x160x1x1x1x160xf32>
    %9 = mhlo.reshape %8 : (tensor<1x160x1x1x1x160xf32>) -> tensor<160x1x160xf32>
    %10 = mhlo.convolution(%arg0, %9) dim_numbers = [b, 0, f]x[o, i, 0]->[b, 0, f], window = {stride = [160], pad = [[0, 0]], lhs_dilate = [1], rhs_dilate = [1], reverse = [0]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : (tensor<1x1600x1xf32>, tensor<160x1x160xf32>) -> tensor<1x10x160xf32>
    %11 = mhlo.reshape %10 : (tensor<1x10x160xf32>) -> tensor<1x10x1x160xf32>
    %12 = "mhlo.transpose"(%11) {permutation = dense<[0, 1, 3, 2]> : tensor<4xi64>} : (tensor<1x10x1x160xf32>) -> tensor<1x10x160x1xf32>
    %13 = "mhlo.transpose"(%12) {permutation = dense<[0, 1, 3, 2]> : tensor<4xi64>} : (tensor<1x10x160x1xf32>) -> tensor<1x10x1x160xf32>
    %14 = mhlo.constant dense<0> : tensor<i32>
    %15 = call @_pad(%13, %14) : (tensor<1x10x1x160xf32>, tensor<i32>) -> tensor<1x10x1x256xf32>
    %16 = call @fft(%15) : (tensor<1x10x1x256xf32>) -> tensor<1x10x1x129xcomplex<f32>>
    %17 = "mhlo.transpose"(%16) {permutation = dense<[0, 1, 3, 2]> : tensor<4xi64>} : (tensor<1x10x1x129xcomplex<f32>>) -> tensor<1x10x129x1xcomplex<f32>>
    %18 = mhlo.abs %17 : (tensor<1x10x129x1xcomplex<f32>>) -> tensor<1x10x129x1xf32>
    return %18 : tensor<1x10x129x1xf32>
  }
  func.func private @_pad(%arg0: tensor<1x10x1x160xf32>, %arg1: tensor<i32>) -> tensor<1x10x1x256xf32> {
    %0 = "mhlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<4x2xi32>
    %1 = mhlo.convert %0 : (tensor<4x2xi32>) -> tensor<4x2xf32>
    %2 = mhlo.constant dense<0> : tensor<i32>
    %3 = "mhlo.broadcast_in_dim"(%2) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %4 = mhlo.constant dense<0> : tensor<i32>
    %5 = "mhlo.broadcast_in_dim"(%4) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %6 = "mhlo.concatenate"(%3, %5) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %7 = "mhlo.gather"(%1, %6) {dimension_numbers = #mhlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1]>, indices_are_sorted = true, slice_sizes = dense<1> : tensor<2xi64>} : (tensor<4x2xf32>, tensor<2xi32>) -> tensor<f32>
    %8 = "mhlo.pad"(%arg0, %7) {edge_padding_high = dense<0> : tensor<4xi64>, edge_padding_low = dense<0> : tensor<4xi64>, interior_padding = dense<0> : tensor<4xi64>} : (tensor<1x10x1x160xf32>, tensor<f32>) -> tensor<1x10x1x160xf32>
    %9 = mhlo.constant dense<0> : tensor<i32>
    %10 = "mhlo.broadcast_in_dim"(%9) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %11 = mhlo.constant dense<1> : tensor<i32>
    %12 = "mhlo.broadcast_in_dim"(%11) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %13 = "mhlo.concatenate"(%10, %12) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %14 = "mhlo.gather"(%1, %13) {dimension_numbers = #mhlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1]>, indices_are_sorted = true, slice_sizes = dense<1> : tensor<2xi64>} : (tensor<4x2xf32>, tensor<2xi32>) -> tensor<f32>
    %15 = "mhlo.pad"(%8, %14) {edge_padding_high = dense<0> : tensor<4xi64>, edge_padding_low = dense<0> : tensor<4xi64>, interior_padding = dense<0> : tensor<4xi64>} : (tensor<1x10x1x160xf32>, tensor<f32>) -> tensor<1x10x1x160xf32>
    %16 = mhlo.constant dense<1> : tensor<i32>
    %17 = "mhlo.broadcast_in_dim"(%16) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %18 = mhlo.constant dense<0> : tensor<i32>
    %19 = "mhlo.broadcast_in_dim"(%18) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %20 = "mhlo.concatenate"(%17, %19) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %21 = "mhlo.gather"(%1, %20) {dimension_numbers = #mhlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1]>, indices_are_sorted = true, slice_sizes = dense<1> : tensor<2xi64>} : (tensor<4x2xf32>, tensor<2xi32>) -> tensor<f32>
    %22 = "mhlo.pad"(%15, %21) {edge_padding_high = dense<0> : tensor<4xi64>, edge_padding_low = dense<0> : tensor<4xi64>, interior_padding = dense<0> : tensor<4xi64>} : (tensor<1x10x1x160xf32>, tensor<f32>) -> tensor<1x10x1x160xf32>
    %23 = mhlo.constant dense<1> : tensor<i32>
    %24 = "mhlo.broadcast_in_dim"(%23) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %25 = mhlo.constant dense<1> : tensor<i32>
    %26 = "mhlo.broadcast_in_dim"(%25) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %27 = "mhlo.concatenate"(%24, %26) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %28 = "mhlo.gather"(%1, %27) {dimension_numbers = #mhlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1]>, indices_are_sorted = true, slice_sizes = dense<1> : tensor<2xi64>} : (tensor<4x2xf32>, tensor<2xi32>) -> tensor<f32>
    %29 = "mhlo.pad"(%22, %28) {edge_padding_high = dense<0> : tensor<4xi64>, edge_padding_low = dense<0> : tensor<4xi64>, interior_padding = dense<0> : tensor<4xi64>} : (tensor<1x10x1x160xf32>, tensor<f32>) -> tensor<1x10x1x160xf32>
    %30 = mhlo.constant dense<2> : tensor<i32>
    %31 = "mhlo.broadcast_in_dim"(%30) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %32 = mhlo.constant dense<0> : tensor<i32>
    %33 = "mhlo.broadcast_in_dim"(%32) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %34 = "mhlo.concatenate"(%31, %33) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %35 = "mhlo.gather"(%1, %34) {dimension_numbers = #mhlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1]>, indices_are_sorted = true, slice_sizes = dense<1> : tensor<2xi64>} : (tensor<4x2xf32>, tensor<2xi32>) -> tensor<f32>
    %36 = "mhlo.pad"(%29, %35) {edge_padding_high = dense<0> : tensor<4xi64>, edge_padding_low = dense<0> : tensor<4xi64>, interior_padding = dense<0> : tensor<4xi64>} : (tensor<1x10x1x160xf32>, tensor<f32>) -> tensor<1x10x1x160xf32>
    %37 = mhlo.constant dense<2> : tensor<i32>
    %38 = "mhlo.broadcast_in_dim"(%37) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %39 = mhlo.constant dense<1> : tensor<i32>
    %40 = "mhlo.broadcast_in_dim"(%39) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %41 = "mhlo.concatenate"(%38, %40) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %42 = "mhlo.gather"(%1, %41) {dimension_numbers = #mhlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1]>, indices_are_sorted = true, slice_sizes = dense<1> : tensor<2xi64>} : (tensor<4x2xf32>, tensor<2xi32>) -> tensor<f32>
    %43 = "mhlo.pad"(%36, %42) {edge_padding_high = dense<0> : tensor<4xi64>, edge_padding_low = dense<0> : tensor<4xi64>, interior_padding = dense<0> : tensor<4xi64>} : (tensor<1x10x1x160xf32>, tensor<f32>) -> tensor<1x10x1x160xf32>
    %44 = mhlo.constant dense<3> : tensor<i32>
    %45 = "mhlo.broadcast_in_dim"(%44) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %46 = mhlo.constant dense<0> : tensor<i32>
    %47 = "mhlo.broadcast_in_dim"(%46) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %48 = "mhlo.concatenate"(%45, %47) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %49 = "mhlo.gather"(%1, %48) {dimension_numbers = #mhlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1]>, indices_are_sorted = true, slice_sizes = dense<1> : tensor<2xi64>} : (tensor<4x2xf32>, tensor<2xi32>) -> tensor<f32>
    %50 = "mhlo.pad"(%43, %49) {edge_padding_high = dense<0> : tensor<4xi64>, edge_padding_low = dense<0> : tensor<4xi64>, interior_padding = dense<0> : tensor<4xi64>} : (tensor<1x10x1x160xf32>, tensor<f32>) -> tensor<1x10x1x160xf32>
    %51 = mhlo.constant dense<3> : tensor<i32>
    %52 = "mhlo.broadcast_in_dim"(%51) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %53 = mhlo.constant dense<1> : tensor<i32>
    %54 = "mhlo.broadcast_in_dim"(%53) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %55 = "mhlo.concatenate"(%52, %54) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %56 = "mhlo.gather"(%1, %55) {dimension_numbers = #mhlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1]>, indices_are_sorted = true, slice_sizes = dense<1> : tensor<2xi64>} : (tensor<4x2xf32>, tensor<2xi32>) -> tensor<f32>
    %57 = "mhlo.pad"(%50, %56) {edge_padding_high = dense<[0, 0, 0, 96]> : tensor<4xi64>, edge_padding_low = dense<0> : tensor<4xi64>, interior_padding = dense<0> : tensor<4xi64>} : (tensor<1x10x1x160xf32>, tensor<f32>) -> tensor<1x10x1x256xf32>
    return %57 : tensor<1x10x1x256xf32>
  }
  func.func private @fft(%arg0: tensor<1x10x1x256xf32>) -> tensor<1x10x1x129xcomplex<f32>> {
    %0 = "mhlo.fft"(%arg0) {fft_length = dense<256> : tensor<1xi64>, fft_type = #mhlo<fft_type RFFT>} : (tensor<1x10x1x256xf32>) -> tensor<1x10x1x129xcomplex<f32>>
    return %0 : tensor<1x10x1x129xcomplex<f32>>
  }
}
iree-compile \
  --iree-hal-target-backends=vmvx \
  --iree-input-type=mhlo \
  /tmp/compute_frontend.mlir \
  -o /tmp/compute_frontend.vmfb
/tmp/iree/libri/compute_frontend.mlir:89:10: error: 'util.buffer.store' op operand #0 must be index or integer or floating-point, but got 'complex<f32>'
    %0 = "mhlo.fft"(%arg0) {fft_length = dense<256> : tensor<1xi64>, fft_type = #mhlo<fft_type RFFT>} : (tensor<1x10x1x256xf32>) -> tensor<1x10x1x129xcomplex<f32>>
         ^
/tmp/iree/libri/compute_frontend.mlir:20:11: note: called from
    %16 = "func.call"(%15) {callee = @fft} : (tensor<1x10x1x256xf32>) -> tensor<1x10x1x129xcomplex<f32>>
          ^
/tmp/iree/libri/compute_frontend.mlir:89:10: note: see current operation: "util.buffer.store"(%36, %23, %22, %39) : (complex<f32>, !util.buffer, index, index) -> ()
    %0 = "mhlo.fft"(%arg0) {fft_length = dense<256> : tensor<1xi64>, fft_type = #mhlo<fft_type RFFT>} : (tensor<1x10x1x256xf32>) -> tensor<1x10x1x129xcomplex<f32>>
         ^
/tmp/iree/libri/compute_frontend.mlir:89:10: error: failed to run translation of source executable to target executable for backend #hal.executable.target<"vmvx", "vmvx-bytecode-fb">
    %0 = "mhlo.fft"(%arg0) {fft_length = dense<256> : tensor<1xi64>, fft_type = #mhlo<fft_type RFFT>} : (tensor<1x10x1x256xf32>) -> tensor<1x10x1x129xcomplex<f32>>
         ^
/tmp/iree/libri/compute_frontend.mlir:20:11: note: called from
    %16 = "func.call"(%15) {callee = @fft} : (tensor<1x10x1x256xf32>) -> tensor<1x10x1x129xcomplex<f32>>
          ^
/tmp/iree/libri/compute_frontend.mlir:89:10: note: see current operation:
"hal.executable.variant"() ({
  "hal.executable.export"() ({
  ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index):
    %0 = "arith.constant"() {value = 3 : index} : () -> index
    %1 = "arith.constant"() {value = 5 : index} : () -> index
    %2 = "arith.constant"() {value = 1 : index} : () -> index
    "hal.return"(%0, %1, %2) : (index, index, index) -> ()
  }) {layout = #hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>, ordinal = 0 : index, sym_name = "_main_dispatch_13_generic_10x129", translation_info = #iree_codegen.translation_info<VMVXDefault>} : () -> ()
  "builtin.module"() ({
    "func.func"() ({
    ^bb0(%arg0: !util.buffer, %arg1: !util.buffer, %arg2: !util.list<!util.buffer>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32):
      %0 = "arith.constant"() {value = 0 : index} : () -> index
      %1 = "arith.constant"() {value = 1290 : index} : () -> index
      %2 = "arith.constant"() {value = 2 : index} : () -> index
      %3 = "arith.constant"() {value = 1 : index} : () -> index
      %4 = "arith.constant"() {value = 43 : index} : () -> index
      %5 = "arith.constant"() {value = 10240 : index} : () -> index
      %6 = "arith.constant"() {value = 20480 : index} : () -> index
      %7 = "arith.constant"() {value = 2560 : index} : () -> index
      %8 = "arith.constant"() {value = 0 : index} : () -> index
      %9 = "util.list.get"(%arg2, %8) : (!util.list<!util.buffer>, index) -> !util.buffer
      %10 = "arith.constant"() {value = 0 : index} : () -> index
      %11 = "util.list.get"(%arg2, %10) : (!util.list<!util.buffer>, index) -> !util.buffer
      %12 = "util.buffer.size"(%11) : (!util.buffer) -> index
      %13 = "arith.constant"() {value = 4 : index} : () -> index
      %14 = "arith.constant"() {value = 10240 : index} : () -> index
      %15 = "util.buffer.subspan"(%11, %12, %5, %14) : (!util.buffer, index, index, index) -> !util.buffer
      %16 = "arith.constant"() {value = 0 : index} : () -> index
      %17 = "util.list.get"(%arg2, %16) : (!util.list<!util.buffer>, index) -> !util.buffer
      %18 = "arith.constant"() {value = 1 : index} : () -> index
      %19 = "util.list.get"(%arg2, %18) : (!util.list<!util.buffer>, index) -> !util.buffer
      %20 = "util.buffer.size"(%19) : (!util.buffer) -> index
      %21 = "util.sizeof"() {sizedType = complex<f32>} : () -> index
      %22 = "arith.muli"(%21, %1) : (index, index) -> index
      %23 = "util.buffer.subspan"(%19, %20, %6, %22) : (!util.buffer, index, index, index) -> !util.buffer
      %24 = "arith.index_cast"(%arg3) : (i32) -> index
      %25 = "arith.index_cast"(%arg4) : (i32) -> index
      "scf.for"(%0, %2, %3) ({
      ^bb0(%arg12: index):
        "scf.for"(%0, %4, %3) ({
        ^bb0(%arg13: index):
          %26 = "affine.apply"(%arg12, %arg13, %25, %24) {map = affine_map<(d0, d1)[s0, s1] -> (d0 * 256 + d1 + s0 * 512 + s1 * 43)>} : (index, index, index, index) -> index
          %27 = "util.buffer.size"(%9) : (!util.buffer) -> index
          %28 = "arith.constant"() {value = 4 : index} : () -> index
          %29 = "arith.muli"(%28, %26) : (index, index) -> index
          %30 = "util.buffer.load"(%9, %27, %29) : (!util.buffer, index, index) -> f32
          %31 = "affine.apply"(%arg12, %arg13, %25, %24) {map = affine_map<(d0, d1)[s0, s1] -> (d0 * 256 + d1 + s0 * 512 + s1 * 43 + 2560)>} : (index, index, index, index) -> index
          %32 = "util.buffer.size"(%17) : (!util.buffer) -> index
          %33 = "arith.constant"() {value = 4 : index} : () -> index
          %34 = "arith.muli"(%33, %31) : (index, index) -> index
          %35 = "util.buffer.load"(%17, %32, %34) : (!util.buffer, index, index) -> f32
          %36 = "complex.create"(%30, %35) : (f32, f32) -> complex<f32>
          %37 = "affine.apply"(%arg12, %arg13, %25, %24) {map = affine_map<(d0, d1)[s0, s1] -> (d0 * 129 + d1 + s0 * 258 + s1 * 43)>} : (index, index, index, index) -> index
          %38 = "util.sizeof"() {sizedType = complex<f32>} : () -> index
          %39 = "arith.muli"(%38, %37) : (index, index) -> index
          "util.buffer.store"(%36, %23, %22, %39) : (complex<f32>, !util.buffer, index, index) -> ()
          "scf.yield"() : () -> ()
        }) : (index, index, index) -> ()
        "scf.yield"() : () -> ()
      }) : (index, index, index) -> ()
      "func.return"() : () -> ()
    }) {function_type = (!util.buffer, !util.buffer, !util.list<!util.buffer>, i32, i32, i32, i32, i32, i32, i32, i32, i32) -> (), sym_name = "_main_dispatch_13_generic_10x129"} : () -> ()
  }) : () -> ()
  "hal.executable.variant_end"() : () -> ()
}) {sym_name = "vmvx_bytecode_fb", target = #hal.executable.target<"vmvx", "vmvx-bytecode-fb">} : () -> ()

... the rest seemed fairly redundant

What component(s) does this issue relate to?

Compiler

Version information

27960a3246e41acfa79b1101f625fc0a42b404ed

Additional context

No response

benvanik commented 2 years ago

Cool, this is probably the first complex usage down to this layer. It doesn't look like there's a transform that does a complex tensor -> flattened 2xf32 tensor yet — just an mhlo-level pass at compiler/src/iree/compiler/InputConversion/MHLO/ConvertComplexToReal.cpp. This would need to run on pre-util.buffer IR during codegen to insert the x2 dimension, as once we've type-erased with the !util.buffer we can't change the indexing like that (element N becomes N*2, but only in the original interpretation of the tensor).