iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0
2.56k stars 572 forks source link

VMVX fails to serialize executable for simple DNN training #10838

Closed phoenix-meadowlark closed 1 year ago

phoenix-meadowlark commented 1 year ago

What happened?

VMVX fails to compile the training step for a simple two-layer DNN model created with IREE-JAX and PAX. Because this model is so simple, it needs to work before IREE can train any other PAX models, including hotword and ASR.

I'm not entirely sure which parts of the compilation error below are most relevant as they all look somewhat generic.

Steps to reproduce your issue

The error can be reproduced by running the following command on hotword_raw_dnn.mlir:

iree-compile \
  --iree-hal-target-backends=vmvx \
  --iree-input-type=mhlo \
  /tmp/hotword_raw_dnn.mlir \
  -o /tmp/hotword_raw_dnn.vmfb

/tmp/iree/hotword/hotword_raw_dnn.mlir:87:22: error: operand #1 does not dominate this use
    %25 = mhlo.reduce(%23 init: %24) across dimensions = [0, 1] : (tensor<100x1xf32>, tensor<f32>) -> tensor<f32>
                     ^
/tmp/iree/hotword/hotword_raw_dnn.mlir:87:22: note: see current operation: %17 = "util.buffer.load"(%7, %10, %16) : (!util.buffer, index, index) -> f32
/tmp/iree/hotword/hotword_raw_dnn.mlir:100:11: note: operand defined here (op in a parent region)
    %30 = mhlo.divide %25, %29 : tensor<f32>
          ^
/tmp/iree/hotword/hotword_raw_dnn.mlir:100:11: error: failed to run translation of source executable to target executable for backend #hal.executable.target<"vmvx", "vmvx-bytecode-fb">
    %30 = mhlo.divide %25, %29 : tensor<f32>
          ^
/tmp/iree/hotword/hotword_raw_dnn.mlir:100:11: note: see current operation:
"hal.executable.variant"() ({
  "hal.executable.export"() ({
  ^bb0(%arg0: !hal.device):
    %0 = "arith.constant"() {value = 1 : index} : () -> index
    "hal.return"(%0, %0, %0) : (index, index, index) -> ()
  }) {layout = #hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>, ordinal = 0 : index, sym_name = "jit__train_step$main_dispatch_11_generic_100", translation_info = #iree_codegen.translation_info<VMVXDefault>} : () -> ()
  "builtin.module"() ({
    "func.func"() ({
    ^bb0(%arg0: !util.buffer, %arg1: !util.buffer, %arg2: !util.list<!util.buffer>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32):
      %0 = "arith.constant"() {value = 64 : index} : () -> index
      %1 = "arith.constant"() {value = 35200 : index} : () -> index
      %2 = "arith.constant"() {value = 4 : index} : () -> index
      %3 = "arith.constant"() {value = 0 : index} : () -> index
      %4 = "arith.constant"() {value = 100 : index} : () -> index
      %5 = "arith.constant"() {value = 1 : index} : () -> index
      %6 = "arith.constant"() {value = 0.000000e+00 : f32} : () -> f32
      %7 = "util.list.get"(%arg2, %3) : (!util.list<!util.buffer>, index) -> !util.buffer
      %8 = "util.list.get"(%arg2, %5) : (!util.list<!util.buffer>, index) -> !util.buffer
      %9 = "util.buffer.size"(%8) : (!util.buffer) -> index
      "util.buffer.store"(%6, %8, %9, %0) : (f32, !util.buffer, index, index) -> ()
      "scf.for"(%3, %4, %5) ({
      ^bb0(%arg12: index):
        %14 = "arith.constant"() {value = 7120 : index} : () -> index
        %15 = "arith.addi"(%arg12, %14) : (index, index) -> index
        %16 = "arith.muli"(%15, %2) : (index, index) -> index
        %17 = "util.buffer.load"(%7, %10, %16) : (!util.buffer, index, index) -> f32
        %18 = "util.buffer.load"(%8, %9, %0) : (!util.buffer, index, index) -> f32
        %19 = "arith.negf"(%17) : (f32) -> f32
        %20 = "arith.addf"(%18, %19) : (f32, f32) -> f32
        "util.buffer.store"(%20, %8, %9, %0) : (f32, !util.buffer, index, index) -> ()
        "scf.yield"() : () -> ()
      }) : (index, index, index) -> ()
      %10 = "util.buffer.size"(%7) : (!util.buffer) -> index
      %11 = "util.buffer.load"(%7, %10, %1) : (!util.buffer, index, index) -> f32
      %12 = "util.buffer.load"(%8, %9, %0) : (!util.buffer, index, index) -> f32
      %13 = "arith.divf"(%12, %11) : (f32, f32) -> f32
      "util.buffer.store"(%13, %8, %9, %0) : (f32, !util.buffer, index, index) -> ()
      "func.return"() : () -> ()
    }) {function_type = (!util.buffer, !util.buffer, !util.list<!util.buffer>, i32, i32, i32, i32, i32, i32, i32, i32, i32) -> (), sym_name = "jit__train_step$main_dispatch_11_generic_100"} : () -> ()
  }) : () -> ()
  "hal.executable.variant_end"() : () -> ()
}) {sym_name = "vmvx_bytecode_fb", target = #hal.executable.target<"vmvx", "vmvx-bytecode-fb">} : () -> ()
/tmp/iree/hotword/hotword_raw_dnn.mlir:100:11: error: failed to serialize executables
    %30 = mhlo.divide %25, %29 : tensor<f32>
          ^
/tmp/iree/hotword/hotword_raw_dnn.mlir:100:11: note: see current operation:
"hal.executable"() ({
  "hal.executable.variant"() ({
    "hal.executable.export"() ({
    ^bb0(%arg0: !hal.device):
      %0 = "arith.constant"() {value = 1 : index} : () -> index
      "hal.return"(%0, %0, %0) : (index, index, index) -> ()
    }) {layout = #hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>, ordinal = 0 : index, sym_name = "jit__train_step$main_dispatch_11_generic_100", translation_info = #iree_codegen.translation_info<VMVXDefault>} : () -> ()
    "builtin.module"() ({
      "func.func"() ({
      ^bb0(%arg0: !util.buffer, %arg1: !util.buffer, %arg2: !util.list<!util.buffer>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32):
        %0 = "arith.constant"() {value = 64 : index} : () -> index
        %1 = "arith.constant"() {value = 35200 : index} : () -> index
        %2 = "arith.constant"() {value = 4 : index} : () -> index
        %3 = "arith.constant"() {value = 0 : index} : () -> index
        %4 = "arith.constant"() {value = 100 : index} : () -> index
        %5 = "arith.constant"() {value = 1 : index} : () -> index
        %6 = "arith.constant"() {value = 0.000000e+00 : f32} : () -> f32
        %7 = "util.list.get"(%arg2, %3) : (!util.list<!util.buffer>, index) -> !util.buffer
        %8 = "util.list.get"(%arg2, %5) : (!util.list<!util.buffer>, index) -> !util.buffer
        %9 = "util.buffer.size"(%8) : (!util.buffer) -> index
        "util.buffer.store"(%6, %8, %9, %0) : (f32, !util.buffer, index, index) -> ()
        "scf.for"(%3, %4, %5) ({
        ^bb0(%arg12: index):
          %14 = "arith.constant"() {value = 7120 : index} : () -> index
          %15 = "arith.addi"(%arg12, %14) : (index, index) -> index
          %16 = "arith.muli"(%15, %2) : (index, index) -> index
          %17 = "util.buffer.load"(%7, %10, %16) : (!util.buffer, index, index) -> f32
          %18 = "util.buffer.load"(%8, %9, %0) : (!util.buffer, index, index) -> f32
          %19 = "arith.negf"(%17) : (f32) -> f32
          %20 = "arith.addf"(%18, %19) : (f32, f32) -> f32
          "util.buffer.store"(%20, %8, %9, %0) : (f32, !util.buffer, index, index) -> ()
          "scf.yield"() : () -> ()
        }) : (index, index, index) -> ()
        %10 = "util.buffer.size"(%7) : (!util.buffer) -> index
        %11 = "util.buffer.load"(%7, %10, %1) : (!util.buffer, index, index) -> f32
        %12 = "util.buffer.load"(%8, %9, %0) : (!util.buffer, index, index) -> f32
        %13 = "arith.divf"(%12, %11) : (f32, f32) -> f32
        "util.buffer.store"(%13, %8, %9, %0) : (f32, !util.buffer, index, index) -> ()
        "func.return"() : () -> ()
      }) {function_type = (!util.buffer, !util.buffer, !util.list<!util.buffer>, i32, i32, i32, i32, i32, i32, i32, i32, i32) -> (), sym_name = "jit__train_step$main_dispatch_11_generic_100"} : () -> ()
    }) : () -> ()
    "hal.executable.variant_end"() : () -> ()
  }) {sym_name = "vmvx_bytecode_fb", target = #hal.executable.target<"vmvx", "vmvx-bytecode-fb">} : () -> ()
  "hal.executable_end"() : () -> ()
}) {sym_name = "jit__train_step$main_dispatch_11", sym_visibility = "private"} : () -> ()
compilation failed

What component(s) does this issue relate to?

MLIR, Compiler

Version information

4d1b61368ce94b254b57caba66b0cca5b3bda032

Additional context

No response

benvanik commented 1 year ago

Looks like the %10 = "util.buffer.size"(%7) : (!util.buffer) -> index op is getting put in the wrong place (maybe the wrong builder is being used, so it ends up in the parent region instead of at the insertion point inside the scf.for).

benvanik commented 1 year ago

Huh, it looks like ConvertAffineToStandard may be doing this. Before that pass the IR looks correct, and the dominance error first appears while it runs:

// -----// IR Dump After CSE (cse) //----- //
hal.executable.variant public @vmvx_bytecode_fb, target = <"vmvx", "vmvx-bytecode-fb"> {
  hal.executable.export public @jit__train_step$main_dispatch_11_generic_100 ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>) attributes {translation_info = #iree_codegen.translation_info<VMVXDefault>} {
  ^bb0(%arg0: !hal.device):
    %c1 = arith.constant 1 : index
    hal.return %c1, %c1, %c1 : index, index, index
  }
  builtin.module {
    func.func @jit__train_step$main_dispatch_11_generic_100(%arg0: !util.buffer, %arg1: !util.buffer, %arg2: !util.list<!util.buffer>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32) {
      %c64 = arith.constant 64 : index
      %c35200 = arith.constant 35200 : index
      %c4 = arith.constant 4 : index
      %c0 = arith.constant 0 : index
      %c100 = arith.constant 100 : index
      %c1 = arith.constant 1 : index
      %cst = arith.constant 0.000000e+00 : f32
      %0 = util.list.get %arg2[%c0] : !util.list<!util.buffer>
      %1 = util.list.get %arg2[%c1] : !util.list<!util.buffer>
      %buffer_size = util.buffer.size %1 : !util.buffer
      util.buffer.store %cst, %1[%c64] : f32 -> !util.buffer{%buffer_size}
      scf.for %arg12 = %c0 to %c100 step %c1 {
        %5 = affine.apply affine_map<()[s0] -> (s0 + 7120)>()[%arg12]
        %buffer_size_1 = util.buffer.size %0 : !util.buffer
        %6 = arith.muli %5, %c4 : index
        %7 = util.buffer.load %0[%6] : !util.buffer{%buffer_size_1} -> f32
        %8 = util.buffer.load %1[%c64] : !util.buffer{%buffer_size} -> f32
        %9 = arith.negf %7 : f32
        %10 = arith.addf %8, %9 : f32
        util.buffer.store %10, %1[%c64] : f32 -> !util.buffer{%buffer_size}
      }
      %buffer_size_0 = util.buffer.size %0 : !util.buffer
      %2 = util.buffer.load %0[%c35200] : !util.buffer{%buffer_size_0} -> f32
      %3 = util.buffer.load %1[%c64] : !util.buffer{%buffer_size} -> f32
      %4 = arith.divf %3, %2 : f32
      util.buffer.store %4, %1[%c64] : f32 -> !util.buffer{%buffer_size}
      return
    }
  }
}

D:\Dev\iree/../iree-tmp/hotword_raw_dnn.mlir:87:22: error: operand #1 does not dominate this use
    %25 = mhlo.reduce(%23 init: %24) across dimensions = [0, 1] : (tensor<100x1xf32>, tensor<f32>) -> tensor<f32>
                     ^
D:\Dev\iree/../iree-tmp/hotword_raw_dnn.mlir:87:22: note: see current operation: %17 = "util.buffer.load"(%7, %10, %16) : (!util.buffer, index, index) -> f32
D:\Dev\iree/../iree-tmp/hotword_raw_dnn.mlir:100:11: note: operand defined here (op in a parent region)
    %30 = mhlo.divide %25, %29 : tensor<f32>
          ^
// -----// IR Dump After ConvertAffineToStandard Failed (lower-affine) //----- //
"func.func"() ({
^bb0(%arg0: !util.buffer, %arg1: !util.buffer, %arg2: !util.list<!util.buffer>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32):
  %0 = "arith.constant"() {value = 64 : index} : () -> index
  %1 = "arith.constant"() {value = 35200 : index} : () -> index
  %2 = "arith.constant"() {value = 4 : index} : () -> index
  %3 = "arith.constant"() {value = 0 : index} : () -> index
  %4 = "arith.constant"() {value = 100 : index} : () -> index
  %5 = "arith.constant"() {value = 1 : index} : () -> index
  %6 = "arith.constant"() {value = 0.000000e+00 : f32} : () -> f32
  %7 = "util.list.get"(%arg2, %3) : (!util.list<!util.buffer>, index) -> !util.buffer
  %8 = "util.list.get"(%arg2, %5) : (!util.list<!util.buffer>, index) -> !util.buffer
  %9 = "util.buffer.size"(%8) : (!util.buffer) -> index
  "util.buffer.store"(%6, %8, %9, %0) : (f32, !util.buffer, index, index) -> ()
  "scf.for"(%3, %4, %5) ({
  ^bb0(%arg12: index):
    %14 = "arith.constant"() {value = 7120 : index} : () -> index
    %15 = "arith.addi"(%arg12, %14) : (index, index) -> index
    %16 = "arith.muli"(%15, %2) : (index, index) -> index
    %17 = "util.buffer.load"(%7, %10, %16) : (!util.buffer, index, index) -> f32
    %18 = "util.buffer.load"(%8, %9, %0) : (!util.buffer, index, index) -> f32
    %19 = "arith.negf"(%17) : (f32) -> f32
    %20 = "arith.addf"(%18, %19) : (f32, f32) -> f32
    "util.buffer.store"(%20, %8, %9, %0) : (f32, !util.buffer, index, index) -> ()
    "scf.yield"() : () -> ()
  }) : (index, index, index) -> ()
  %10 = "util.buffer.size"(%7) : (!util.buffer) -> index
  %11 = "util.buffer.load"(%7, %10, %1) : (!util.buffer, index, index) -> f32
  %12 = "util.buffer.load"(%8, %9, %0) : (!util.buffer, index, index) -> f32
  %13 = "arith.divf"(%12, %11) : (f32, f32) -> f32
  "util.buffer.store"(%13, %8, %9, %0) : (f32, !util.buffer, index, index) -> ()
  "func.return"() : () -> ()
}) {function_type = (!util.buffer, !util.buffer, !util.list<!util.buffer>, i32, i32, i32, i32, i32, i32, i32, i32, i32) -> (), sym_name = "jit__train_step$main_dispatch_11_generic_100"} : () -> ()

benvanik commented 1 year ago

The issue is in the BufferSizeOp folder, which runs during ConvertAffineToStandard because that pass uses the conversion framework.

benvanik commented 1 year ago

Repro:

// Minimal reproducer: util.buffer.size is queried on the same %buffer twice,
// once inside the scf.for body and once after the loop in the function body.
func.func @FoldNestedBufferSizeOp(%list: !util.list<!util.buffer>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c128 = arith.constant 128 : index
  %buffer = util.list.get %list[%c0] : !util.list<!util.buffer>
  scf.for %i = %c0 to %c128 step %c1 {
    // Size query nested in the loop region; it textually precedes the outer
    // query below, so any use in here must not be rewired to that later value.
    %buffer_size_inner = util.buffer.size %buffer : !util.buffer
    %inner = util.buffer.load %buffer[%i] : !util.buffer{%buffer_size_inner} -> i8
    // NOTE(review): presumably keeps the load from being optimized away - confirm.
    util.do_not_optimize(%inner) : i8
  }
  // Second size query on the same buffer, defined after the loop.
  // NOTE(review): presumably the BufferSizeOp folder replaces the inner size
  // with this one, producing the "does not dominate this use" error - confirm.
  %buffer_size_outer = util.buffer.size %buffer : !util.buffer
  %outer = util.buffer.load %buffer[%c128] : !util.buffer{%buffer_size_outer} -> i8
  util.do_not_optimize(%outer) : i8
  return
}