phoenix-meadowlark closed this issue 1 year ago.
It looks like the %10 = "util.buffer.size"(%7) : (!util.buffer) -> index
is being placed in the wrong location (possibly the wrong builder is being used, so the op is inserted into the parent region instead of at the insertion point inside the scf.for).
Hm — it looks like ConvertAffineToStandard may be responsible. Before that pass the IR looks correct, and the dominance violation appears while it runs:
// -----// IR Dump After CSE (cse) //----- //
hal.executable.variant public @vmvx_bytecode_fb, target = <"vmvx", "vmvx-bytecode-fb"> {
hal.executable.export public @jit__train_step$main_dispatch_11_generic_100 ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>) attributes {translation_info = #iree_codegen.translation_info<VMVXDefault>} {
^bb0(%arg0: !hal.device):
%c1 = arith.constant 1 : index
hal.return %c1, %c1, %c1 : index, index, index
}
builtin.module {
func.func @jit__train_step$main_dispatch_11_generic_100(%arg0: !util.buffer, %arg1: !util.buffer, %arg2: !util.list<!util.buffer>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32) {
%c64 = arith.constant 64 : index
%c35200 = arith.constant 35200 : index
%c4 = arith.constant 4 : index
%c0 = arith.constant 0 : index
%c100 = arith.constant 100 : index
%c1 = arith.constant 1 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = util.list.get %arg2[%c0] : !util.list<!util.buffer>
%1 = util.list.get %arg2[%c1] : !util.list<!util.buffer>
%buffer_size = util.buffer.size %1 : !util.buffer
util.buffer.store %cst, %1[%c64] : f32 -> !util.buffer{%buffer_size}
scf.for %arg12 = %c0 to %c100 step %c1 {
%5 = affine.apply affine_map<()[s0] -> (s0 + 7120)>()[%arg12]
%buffer_size_1 = util.buffer.size %0 : !util.buffer
%6 = arith.muli %5, %c4 : index
%7 = util.buffer.load %0[%6] : !util.buffer{%buffer_size_1} -> f32
%8 = util.buffer.load %1[%c64] : !util.buffer{%buffer_size} -> f32
%9 = arith.negf %7 : f32
%10 = arith.addf %8, %9 : f32
util.buffer.store %10, %1[%c64] : f32 -> !util.buffer{%buffer_size}
}
%buffer_size_0 = util.buffer.size %0 : !util.buffer
%2 = util.buffer.load %0[%c35200] : !util.buffer{%buffer_size_0} -> f32
%3 = util.buffer.load %1[%c64] : !util.buffer{%buffer_size} -> f32
%4 = arith.divf %3, %2 : f32
util.buffer.store %4, %1[%c64] : f32 -> !util.buffer{%buffer_size}
return
}
}
}
D:\Dev\iree/../iree-tmp/hotword_raw_dnn.mlir:87:22: error: operand #1 does not dominate this use
%25 = mhlo.reduce(%23 init: %24) across dimensions = [0, 1] : (tensor<100x1xf32>, tensor<f32>) -> tensor<f32>
^
D:\Dev\iree/../iree-tmp/hotword_raw_dnn.mlir:87:22: note: see current operation: %17 = "util.buffer.load"(%7, %10, %16) : (!util.buffer, index, index) -> f32
D:\Dev\iree/../iree-tmp/hotword_raw_dnn.mlir:100:11: note: operand defined here (op in a parent region)
%30 = mhlo.divide %25, %29 : tensor<f32>
^
// -----// IR Dump After ConvertAffineToStandard Failed (lower-affine) //----- //
"func.func"() ({
^bb0(%arg0: !util.buffer, %arg1: !util.buffer, %arg2: !util.list<!util.buffer>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32):
%0 = "arith.constant"() {value = 64 : index} : () -> index
%1 = "arith.constant"() {value = 35200 : index} : () -> index
%2 = "arith.constant"() {value = 4 : index} : () -> index
%3 = "arith.constant"() {value = 0 : index} : () -> index
%4 = "arith.constant"() {value = 100 : index} : () -> index
%5 = "arith.constant"() {value = 1 : index} : () -> index
%6 = "arith.constant"() {value = 0.000000e+00 : f32} : () -> f32
%7 = "util.list.get"(%arg2, %3) : (!util.list<!util.buffer>, index) -> !util.buffer
%8 = "util.list.get"(%arg2, %5) : (!util.list<!util.buffer>, index) -> !util.buffer
%9 = "util.buffer.size"(%8) : (!util.buffer) -> index
"util.buffer.store"(%6, %8, %9, %0) : (f32, !util.buffer, index, index) -> ()
"scf.for"(%3, %4, %5) ({
^bb0(%arg12: index):
%14 = "arith.constant"() {value = 7120 : index} : () -> index
%15 = "arith.addi"(%arg12, %14) : (index, index) -> index
%16 = "arith.muli"(%15, %2) : (index, index) -> index
%17 = "util.buffer.load"(%7, %10, %16) : (!util.buffer, index, index) -> f32
%18 = "util.buffer.load"(%8, %9, %0) : (!util.buffer, index, index) -> f32
%19 = "arith.negf"(%17) : (f32) -> f32
%20 = "arith.addf"(%18, %19) : (f32, f32) -> f32
"util.buffer.store"(%20, %8, %9, %0) : (f32, !util.buffer, index, index) -> ()
"scf.yield"() : () -> ()
}) : (index, index, index) -> ()
%10 = "util.buffer.size"(%7) : (!util.buffer) -> index
%11 = "util.buffer.load"(%7, %10, %1) : (!util.buffer, index, index) -> f32
%12 = "util.buffer.load"(%8, %9, %0) : (!util.buffer, index, index) -> f32
%13 = "arith.divf"(%12, %11) : (f32, f32) -> f32
"util.buffer.store"(%13, %8, %9, %0) : (f32, !util.buffer, index, index) -> ()
"func.return"() : () -> ()
}) {function_type = (!util.buffer, !util.buffer, !util.list<!util.buffer>, i32, i32, i32, i32, i32, i32, i32, i32, i32) -> (), sym_name = "jit__train_step$main_dispatch_11_generic_100"} : () -> ()
The issue is in the BufferSizeOp folder, which runs during ConvertAffineToStandard because that pass uses the dialect conversion framework.
Repro:
func.func @FoldNestedBufferSizeOp(%list: !util.list<!util.buffer>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c128 = arith.constant 128 : index
%buffer = util.list.get %list[%c0] : !util.list<!util.buffer>
scf.for %i = %c0 to %c128 step %c1 {
%buffer_size_inner = util.buffer.size %buffer : !util.buffer
%inner = util.buffer.load %buffer[%i] : !util.buffer{%buffer_size_inner} -> i8
util.do_not_optimize(%inner) : i8
}
%buffer_size_outer = util.buffer.size %buffer : !util.buffer
%outer = util.buffer.load %buffer[%c128] : !util.buffer{%buffer_size_outer} -> i8
util.do_not_optimize(%outer) : i8
return
}
What happened?
VMVX fails to compile the training step for a simple two-layer DNN model created with IREE-JAX and PAX. Being as simple as it is, this model would need to work in order for IREE to be able to train any other PAX models, including hotword and ASR.
I'm not entirely sure which parts of the compilation error below are most relevant as they all look somewhat generic.
Steps to reproduce your issue
The error can be reproduced by running the compilation command below on
hotword_raw_dnn.mlir.
What component(s) does this issue relate to?
MLIR, Compiler
Version information
4d1b61368ce94b254b57caba66b0cca5b3bda032
Additional context
No response