Open pdhirajkumarprasad opened 1 month ago
@pashu123 looks like the scf.forall resolution is not kicking in somewhere?
@pashu123 looks like the scf.forall resolution is not kicking in somewhere?
It's forming a dispatch with dag_root: https://gist.github.com/pashu123/12dd8d3771a1a5cfd99a52ee5b001a98#file-module_main_graph-async_dispatch_3-mlir-L5 — looking into it.
Try top of tree. Quinn fixed a few builtins.
Try top of tree. Quinn fixed a few builtins.
I am trying the top of main. It's not coming from the builtins; it is being formed during dispatch region creation. Full dump: https://gist.github.com/pashu123/fb9a9d29b9f199d6f10bfb3c2d55ed49 (line 47735)
%27 = flow.dispatch.region[%23, %24, %25, %26] -> (tensor<?x?x128x384xi1>{%20, %21}) {
%30 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%22 : tensor<?x?x128x384xi1>) {
^bb0(%out: i1):
linalg.yield %false : i1
} -> tensor<?x?x128x384xi1>
flow.return %30 : tensor<?x?x128x384xi1>
} count(%arg8: index, %arg9: index, %arg10: index, %arg11: index) -> (index, index, index) {
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg8, %arg9, %arg10, %arg11
flow.return %x, %y, %z : index, index, index
}
https://github.com/iree-org/iree/blob/fa752ae1e491a1f8fde8967bf04473c6a6c1ca18/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp#L583 It's forming here.
So, rather than creating `count_from_dag_root`, should it create a `count_from_slice` here? @MaheshRavishankar
Yes! I wonder why that is happening.
Yes! I wonder why that is happening.
It is choosing `count_from_dag_root` here because the size of the tensor is data-dependent. Here is the dump before FormDispatchRegions (see the op at the end of the dump):
// -----// IR Dump After FormScalarDispatchesPass (iree-dispatch-creation-form-scalar-dispatches) //----- //
util.func public @main_graph$async(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view, %arg3: !hal.buffer_view, %arg4: !hal.buffer_view, %arg5: !hal.buffer_view, %arg6: !hal.fence, %arg7: !hal.fence) -> !hal.buffer_view attributes {inlining_policy = #util.inline.never, iree.abi.model = "coarse-fences", iree.abi.stub} {
%false = arith.constant false
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c0_i64 = arith.constant 0 : i64
%cst = arith.constant dense<8> : tensor<i64>
%cst_0 = arith.constant dense<0> : tensor<3xi64>
%0 = hal.buffer_view.dim<%arg1 : !hal.buffer_view>[0] : index
%1 = hal.buffer_view.dim<%arg1 : !hal.buffer_view>[1] : index
%2 = hal.tensor.import wait(%arg6) => %arg4 : !hal.buffer_view -> tensor<4xi64>
%3 = hal.tensor.import wait(%arg6) => %arg5 : !hal.buffer_view -> tensor<2xi64>
%extracted_slice = tensor.extract_slice %2[0] [1] [1] : tensor<4xi64> to tensor<i64>
%expanded = tensor.expand_shape %extracted_slice [] output_shape [1] : tensor<i64> into tensor<1xi64>
%extracted = tensor.extract %expanded[%c0] : tensor<1xi64>
%extracted_slice_1 = tensor.extract_slice %2[1] [1] [1] : tensor<4xi64> to tensor<i64>
%expanded_2 = tensor.expand_shape %extracted_slice_1 [] output_shape [1] : tensor<i64> into tensor<1xi64>
%extracted_3 = tensor.extract %expanded_2[%c0] : tensor<1xi64>
%extracted_slice_4 = tensor.extract_slice %2[2] [1] [1] : tensor<4xi64> to tensor<i64>
%expanded_5 = tensor.expand_shape %extracted_slice_4 [] output_shape [1] : tensor<i64> into tensor<1xi64>
%extracted_6 = tensor.extract %expanded_5[%c0] : tensor<1xi64>
%extracted_slice_7 = tensor.extract_slice %2[3] [1] [1] : tensor<4xi64> to tensor<i64>
%expanded_8 = tensor.expand_shape %extracted_slice_7 [] output_shape [1] : tensor<i64> into tensor<1xi64>
%extracted_9 = tensor.extract %expanded_8[%c0] : tensor<1xi64>
%4 = arith.index_cast %extracted : i64 to index
%5 = arith.index_cast %extracted_6 : i64 to index
%6 = arith.index_cast %extracted_3 : i64 to index
%7 = arith.index_cast %extracted_9 : i64 to index
%8 = tensor.empty() : tensor<i64>
%9 = flow.dispatch.region -> (tensor<i64>) {
%26 = linalg.generic {indexing_maps = [affine_map<() -> ()>], iterator_types = []} outs(%8 : tensor<i64>) {
^bb0(%out: i64):
%extracted_20 = tensor.extract %3[%c1] : tensor<2xi64>
%27 = arith.divsi %extracted_20, %c0_i64 : i64
linalg.yield %27 : i64
} -> tensor<i64>
flow.return %26 : tensor<i64>
} count() -> (index, index, index) {
%c1_20 = arith.constant 1 : index
flow.return %c1_20, %c1_20, %c1_20 : index, index, index
}
%inserted_slice = tensor.insert_slice %9 into %cst_0[1] [1] [1] : tensor<i64> into tensor<3xi64>
%inserted_slice_10 = tensor.insert_slice %cst into %inserted_slice[2] [1] [1] : tensor<i64> into tensor<3xi64>
%extracted_slice_11 = tensor.extract_slice %inserted_slice_10[0] [1] [1] : tensor<3xi64> to tensor<i64>
%expanded_12 = tensor.expand_shape %extracted_slice_11 [] output_shape [1] : tensor<i64> into tensor<1xi64>
%extracted_13 = tensor.extract %expanded_12[%c0] : tensor<1xi64>
%10 = arith.cmpi eq, %extracted_13, %c0_i64 : i64
%11 = arith.addi %4, %5 : index
%12 = arith.addi %11, %0 : index
%13 = arith.index_cast %12 : index to i64
%14 = flow.dispatch.region -> (tensor<i64>) {
%26 = linalg.generic {indexing_maps = [affine_map<() -> ()>], iterator_types = []} outs(%8 : tensor<i64>) {
^bb0(%out: i64):
%27 = arith.select %10, %13, %extracted_13 : i64
linalg.yield %27 : i64
} -> tensor<i64>
flow.return %26 : tensor<i64>
} count() -> (index, index, index) {
%c1_20 = arith.constant 1 : index
flow.return %c1_20, %c1_20, %c1_20 : index, index, index
}
%extracted_14 = tensor.extract %14[] : tensor<i64>
%extracted_slice_15 = tensor.extract_slice %inserted_slice_10[1] [1] [1] : tensor<3xi64> to tensor<i64>
%expanded_16 = tensor.expand_shape %extracted_slice_15 [] output_shape [1] : tensor<i64> into tensor<1xi64>
%extracted_17 = tensor.extract %expanded_16[%c0] : tensor<1xi64>
%15 = arith.cmpi eq, %extracted_17, %c0_i64 : i64
%16 = arith.addi %6, %7 : index
%17 = arith.addi %16, %1 : index
%18 = arith.index_cast %17 : index to i64
%19 = flow.dispatch.region -> (tensor<i64>) {
%26 = linalg.generic {indexing_maps = [affine_map<() -> ()>], iterator_types = []} outs(%8 : tensor<i64>) {
^bb0(%out: i64):
%27 = arith.select %15, %18, %extracted_17 : i64
linalg.yield %27 : i64
} -> tensor<i64>
flow.return %26 : tensor<i64>
} count() -> (index, index, index) {
%c1_20 = arith.constant 1 : index
flow.return %c1_20, %c1_20, %c1_20 : index, index, index
}
%extracted_18 = tensor.extract %19[] : tensor<i64>
%20 = arith.index_cast %extracted_14 : i64 to index
%21 = arith.index_cast %extracted_18 : i64 to index
// The init operand here is dependent on `%extracted_14` and `%extracted_18`
%22 = tensor.empty(%20, %21) : tensor<?x?x128x384xi1>
%23 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%22 : tensor<?x?x128x384xi1>) {
^bb0(%out: i1):
linalg.yield %false : i1
} -> tensor<?x?x128x384xi1>
%24 = hal.tensor.barrier join(%23 : tensor<?x?x128x384xi1>) => %arg7 : !hal.fence
%dim = tensor.dim %24, %c0 : tensor<?x?x128x384xi1>
%dim_19 = tensor.dim %24, %c1 : tensor<?x?x128x384xi1>
%25 = hal.tensor.export %24 : tensor<?x?x128x384xi1>{%dim, %dim_19} -> !hal.buffer_view
util.return %25 : !hal.buffer_view
}
According to the comments in `RegionOpUtils.cpp`, it seems this has to stay as `count_from_dag_root`. I'm guessing the unrealized conversion casts are coming from:
%20 = arith.index_cast %extracted_14 : i64 to index
%21 = arith.index_cast %extracted_18 : i64 to index
I'll look at what is happening with these ops.
@MaheshRavishankar do we need to add a lowering for `flow.dispatch.workgroup_count_from_dag_root`?
@MaheshRavishankar do we need to add a lowering for `flow.dispatch.workgroup_count_from_dag_root`?
No, that really can't be handled very well with `scf.forall`. But that's not the issue. Hold on.
@pdhirajkumarprasad this is again an issue where none of this has any real computation. The whole code is just index computation. We need to get better at triaging this (as a team). I know it fails in codegen, but it is not a codegen issue. Labeling it as codegen just increases latency.
Not pointing fingers. I didn't see the actual input at all. I just saw the error message too... Just pointing out that we spent two days looking somewhere else :)
@zjgarvey I think this is yours :D
Currently we have the following 4 models failing with the above error:
model--long-t5-tglobal-base-16384-book-summary--pszemraj
model--long-t5-tglobal-base-16384-booksum-V11-big_patent-V2--pszemraj
model--long-t5-tglobal-base-16384-booksum-V12--pszemraj
migraphx_bert__bertsquad-12
%19 = flow.dispatch.region -> (tensor<i64>) {
%26 = linalg.generic {indexing_maps = [affine_map<() -> ()>], iterator_types = []} outs(%8 : tensor<i64>) {
^bb0(%out: i64):
%27 = arith.select %15, %18, %extracted_17 : i64
linalg.yield %27 : i64
} -> tensor<i64>
flow.return %26 : tensor<i64>
} count() -> (index, index, index) {
%c1_20 = arith.constant 1 : index
flow.return %c1_20, %c1_20, %c1_20 : index, index, index
}
%extracted_18 = tensor.extract %19[] : tensor<i64>
This is really bad IR - we should be able to compile it, but this is really bad.
All of that could be `%extracted_18 = arith.select %15, %18, %extracted_17 : i64` and run literally 10000x faster.
Good as a test of the compiler, but this should be P1 to fix at whatever level is best (I think linalg/detensorizing should have handled this).
This is really bad IR - we should be able to compile it, but this is really bad. All of that could be
%extracted_18 = arith.select %15, %18, %extracted_17 : i64
and run literally 10000x faster. Good as a test of the compiler, but this should be P1 to fix at whatever level is best (I think linalg/detensorizing should have handled this).
We shouldn't be getting that kind of IR in the full model, since those `arith.select` ops should get scalarized in the full model.
The problem is that we can't scalarize from a `func.func` input, so a bit more IR would be necessary for an accurate reproducer (specifically the producers for `%arg4` and `%arg5`).
Currently we have the following 4 models failing with the above error:
model--long-t5-tglobal-base-16384-book-summary--pszemraj model--long-t5-tglobal-base-16384-booksum-V11-big_patent-V2--pszemraj model--long-t5-tglobal-base-16384-booksum-V12--pszemraj migraphx_bert__bertsquad-12
At least the model `model--long-t5-tglobal-base-16384-book-summary--pszemraj` seems to be failing this conversion (`i64` to `index`) on an `onnx.Einsum` op, so it might be good to update the issue with a better reproducer.
Here is a new reduced IR from the model, without much modification, so that the data flow is kept intact:
module {
func.func @tf2onnx(%arg0: !torch.vtensor<[?],si64>, %arg1: !torch.vtensor<[?,256],si64>, %arg2: !torch.vtensor<[?,256],si64>, %arg3: !torch.vtensor<[?,256],si64>) -> (!torch.vtensor<[?,256,768],f32> ) attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "tf2onnx", torch.onnx_meta.producer_version = "1.5.2"} {
%4 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
%5 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
%6 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
%7 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
%8 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
%9 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
%13 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<2xsi64>} : () -> !torch.vtensor<[2],si64>
%14 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<2xsi64>} : () -> !torch.vtensor<[2],si64>
%18 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<2x768xf32>} : () -> !torch.vtensor<[2,768],f32>
%483 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<30522x768xf32>} : () -> !torch.vtensor<[30522,768],f32>
%484 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<2x768xf32>} : () -> !torch.vtensor<[2,768],f32>
%485 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<512x768xf32>} : () -> !torch.vtensor<[512,768],f32>
%486 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<f32>} : () -> !torch.vtensor<[],f32>
%487 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<f32>} : () -> !torch.vtensor<[],f32>
%488 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_bert_embeddings_one_hot_depth_0> : tensor<si32>} : () -> !torch.vtensor<[],si32>
%489 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_bert_embeddings_Reshape_4_shape_0> : tensor<3xsi32>} : () -> !torch.vtensor<[3],si32>
%490 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_bert_embeddings_Reshape_3_shape_2_0> : tensor<si32>} : () -> !torch.vtensor<[],si32>
%491 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_bert_embeddings_Reshape_3_shape_1_0> : tensor<si32>} : () -> !torch.vtensor<[],si32>
%492 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_bert_embeddings_Reshape_2_shape_0> : tensor<1xsi32>} : () -> !torch.vtensor<[1],si32>
%493 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_bert_embeddings_Reshape_1_shape_2_0> : tensor<si32>} : () -> !torch.vtensor<[],si32>
%494 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_bert_embeddings_Reshape_1_shape_1_0> : tensor<si32>} : () -> !torch.vtensor<[],si32>
%495 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_bert_embeddings_Reshape_shape_0> : tensor<1xsi32>} : () -> !torch.vtensor<[1],si32>
%499 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_bert_embeddings_ExpandDims__48> : tensor<3xsi64>} : () -> !torch.vtensor<[3],si64>
%500 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_Reshape_1_shape_2_0> : tensor<si32>} : () -> !torch.vtensor<[],si32>
%501 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_Reshape_1_shape_1_0> : tensor<si32>} : () -> !torch.vtensor<[],si32>
%502 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_Reshape_shape_1_0> : tensor<si32>} : () -> !torch.vtensor<[],si32>
%811 = torch.operator "onnx.Slice"(%485, %14, %13) : (!torch.vtensor<[512,768],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[256,768],f32>
%812 = torch.operator "onnx.Cast"(%489) {torch.onnx.to = 7 : si64} : (!torch.vtensor<[3],si32>) -> !torch.vtensor<[3],si64>
%813 = torch.operator "onnx.Reshape"(%811, %812) : (!torch.vtensor<[256,768],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[1,256,768],f32>
%814 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
%815 = torch.operator "onnx.Unsqueeze"(%490, %814) : (!torch.vtensor<[],si32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si32>
%816 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
%817 = torch.operator "onnx.Unsqueeze"(%491, %816) : (!torch.vtensor<[],si32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si32>
%818 = torch.operator "onnx.Cast"(%492) {torch.onnx.to = 7 : si64} : (!torch.vtensor<[1],si32>) -> !torch.vtensor<[1],si64>
%819 = torch.operator "onnx.Reshape"(%arg1, %818) : (!torch.vtensor<[?,256],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[?],si64>
%820 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
%821 = torch.operator "onnx.Unsqueeze"(%493, %820) : (!torch.vtensor<[],si32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si32>
%822 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
%823 = torch.operator "onnx.Unsqueeze"(%494, %822) : (!torch.vtensor<[],si32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si32>
%824 = torch.operator "onnx.Cast"(%495) {torch.onnx.to = 7 : si64} : (!torch.vtensor<[1],si32>) -> !torch.vtensor<[1],si64>
%825 = torch.operator "onnx.Reshape"(%arg3, %499) : (!torch.vtensor<[?,256],si64>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[?,256,1],si64>
%826 = torch.operator "onnx.Shape"(%825) : (!torch.vtensor<[?,256,1],si64>) -> !torch.vtensor<[3],si64>
%827 = torch.operator "onnx.Cast"(%826) {torch.onnx.to = 1 : si64} : (!torch.vtensor<[3],si64>) -> !torch.vtensor<[3],f32>
%828 = torch.operator "onnx.Slice"(%827, %9, %8, %7) : (!torch.vtensor<[3],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],f32>
%829 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
%830 = torch.operator "onnx.Squeeze"(%828, %829) : (!torch.vtensor<[1],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[],f32>
%831 = torch.operator "onnx.Cast"(%830) {torch.onnx.to = 6 : si64} : (!torch.vtensor<[],f32>) -> !torch.vtensor<[],si32>
%832 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
%833 = torch.operator "onnx.Unsqueeze"(%831, %832) : (!torch.vtensor<[],si32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si32>
%834 = torch.operator "onnx.Concat"(%833, %823, %821) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[1],si32>, !torch.vtensor<[1],si32>, !torch.vtensor<[1],si32>) -> !torch.vtensor<[3],si32>
%835 = torch.operator "onnx.Cast"(%834) {torch.onnx.to = 7 : si64} : (!torch.vtensor<[3],si32>) -> !torch.vtensor<[3],si64>
%836 = torch.operator "onnx.Reshape"(%825, %824) : (!torch.vtensor<[?,256,1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[?],si64>
%837 = torch.operator "onnx.Gather"(%483, %836) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[30522,768],f32>, !torch.vtensor<[?],si64>) -> !torch.vtensor<[?,768],f32>
%838 = torch.operator "onnx.Reshape"(%837, %835) : (!torch.vtensor<[?,768],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[?,256,768],f32>
%839 = torch.operator "onnx.Shape"(%838) : (!torch.vtensor<[?,256,768],f32>) -> !torch.vtensor<[3],si64>
%840 = torch.operator "onnx.Cast"(%839) {torch.onnx.to = 1 : si64} : (!torch.vtensor<[3],si64>) -> !torch.vtensor<[3],f32>
%841 = torch.operator "onnx.Slice"(%840, %6, %5, %4) : (!torch.vtensor<[3],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],f32>
%842 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
%843 = torch.operator "onnx.Squeeze"(%841, %842) : (!torch.vtensor<[1],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[],f32>
%844 = torch.operator "onnx.Cast"(%843) {torch.onnx.to = 6 : si64} : (!torch.vtensor<[],f32>) -> !torch.vtensor<[],si32>
%845 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
%846 = torch.operator "onnx.Unsqueeze"(%844, %845) : (!torch.vtensor<[],si32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si32>
%847 = torch.operator "onnx.Concat"(%846, %817, %815) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[1],si32>, !torch.vtensor<[1],si32>, !torch.vtensor<[1],si32>) -> !torch.vtensor<[3],si32>
%848 = torch.operator "onnx.Cast"(%847) {torch.onnx.to = 7 : si64} : (!torch.vtensor<[3],si32>) -> !torch.vtensor<[3],si64>
%849 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
%850 = torch.operator "onnx.Unsqueeze"(%487, %849) : (!torch.vtensor<[],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],f32>
%851 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
%852 = torch.operator "onnx.Unsqueeze"(%486, %851) : (!torch.vtensor<[],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],f32>
%853 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
%854 = torch.operator "onnx.Unsqueeze"(%488, %853) : (!torch.vtensor<[],si32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si32>
%862 = torch.operator "onnx.Concat"(%850, %852) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[2],f32>
%863 = torch.operator "onnx.OneHot"(%819, %854, %862) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[?],si64>, !torch.vtensor<[1],si32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[?,?],f32>
%864 = torch.operator "onnx.MatMul"(%863, %484) : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[2,768],f32>) -> !torch.vtensor<[?,768],f32>
%865 = torch.operator "onnx.Reshape"(%864, %848) : (!torch.vtensor<[?,768],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[?,256,768],f32>
%866 = torch.operator "onnx.Add"(%838, %865) : (!torch.vtensor<[?,256,768],f32>, !torch.vtensor<[?,256,768],f32>) -> !torch.vtensor<[?,256,768],f32>
%867 = torch.operator "onnx.Add"(%866, %813) : (!torch.vtensor<[?,256,768],f32>, !torch.vtensor<[1,256,768],f32>) -> !torch.vtensor<[?,256,768],f32>
return %867: !torch.vtensor<[?,256,768],f32>
}
}
{-#
dialect_resources: {
builtin: {
_bert_embeddings_one_hot_on_value_0: "0x080000000000803F",
_bert_embeddings_one_hot_off_value_0: "0x0800000000000000",
_bert_embeddings_one_hot_depth_0: "0x0800000002000000",
_bert_embeddings_Reshape_4_shape_0: "0x08000000010000000001000000030000",
_bert_embeddings_Reshape_3_shape_2_0: "0x0800000000030000",
_bert_embeddings_Reshape_3_shape_1_0: "0x0800000000010000",
_bert_embeddings_Reshape_2_shape_0: "0x08000000FFFFFFFF",
_bert_embeddings_Reshape_1_shape_2_0: "0x0800000000030000",
_bert_embeddings_Reshape_1_shape_1_0: "0x0800000000010000",
_bert_embeddings_Reshape_shape_0: "0x08000000FFFFFFFF",
_bert_embeddings_LayerNorm_batchnorm_add_y_0: "0x08000000CCBC8C2B",
_bert_embeddings_ExpandDims__48: "0x08000000FFFFFFFFFFFFFFFF00010000000000000100000000000000",
_Reshape_1_shape_2_0: "0x0800000002000000",
_Reshape_1_shape_1_0: "0x0800000000010000",
_Reshape_shape_1_0: "0x0800000000030000",
_: "0x080000000000803F"
}
}
#-}
Thanks @pdhirajkumarprasad , let me verify the scalarization is working properly here.
Nice, it looks like there are just a few casts interrupting the scalarization. I can definitely fix this quickly.
What happened?
For the given IR (IREE compiler version 20241024.1057 @ 9c5b57a8b9e6981e300df02c41a296bd49e07c99)
Getting the error as
Steps to reproduce your issue
Command
Detail log:
dump.log
What component(s) does this issue relate to?
Compiler
Version information
No response
Additional context
No response