iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

[cpu] failed to legalize unresolved materialization from ('i64') to 'index' that remained live after conversion #18899

Open pdhirajkumarprasad opened 1 month ago

pdhirajkumarprasad commented 1 month ago

What happened?

For the given IR (IREE compiler version 20241024.1057 @ 9c5b57a8b9e6981e300df02c41a296bd49e07c99):

module {
  func.func @main_graph(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>, %arg3: !torch.vtensor<[?,?,?],si64>, %arg4: !torch.vtensor<[4],si64>, %arg5: !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?,128,384],i1>   attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
    %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<si64>} : () -> !torch.vtensor<[],si64> 
    %2 = torch.operator "onnx.Pad"(%arg1, %arg4, %1) {torch.onnx.mode = "constant"} : (!torch.vtensor<[?,?],si64>, !torch.vtensor<[4],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[?,?],si64> 
    %3 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor<si64>} : () -> !torch.vtensor<[],si64> 
    %4 = torch.operator "onnx.Gather"(%arg5, %3) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[2],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64> 
    %5 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<si64>} : () -> !torch.vtensor<[],si64> 
    %6 = torch.operator "onnx.Div"(%4, %5) : (!torch.vtensor<[],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64> 
    %7 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %8 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %9 = torch.operator "onnx.Unsqueeze"(%6, %8) : (!torch.vtensor<[],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si64> 
    %10 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<8> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %11 = torch.operator "onnx.Concat"(%7, %9, %10) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3],si64> 
    %12 = torch.operator "onnx.Reshape"(%2, %11) {torch.onnx.allowzero = 0 : si64} : (!torch.vtensor<[?,?],si64>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[?,?,?],si64> 
    %13 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %14 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<2> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %15 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %16 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %17 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<-1> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %18 = torch.operator "onnx.Unsqueeze"(%12, %17) : (!torch.vtensor<[?,?,?],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[?,?,?,1],si64> 
    %19 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<2> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %20 = torch.operator "onnx.Unsqueeze"(%arg3, %19) : (!torch.vtensor<[?,?,?],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[?,?,1,?],si64> 
    %21 = torch.operator "onnx.Cast"(%18) {torch.onnx.to = 9 : si64} : (!torch.vtensor<[?,?,?,1],si64>) -> !torch.vtensor<[?,?,?,1],i1> 
    %22 = torch.operator "onnx.Cast"(%20) {torch.onnx.to = 9 : si64} : (!torch.vtensor<[?,?,1,?],si64>) -> !torch.vtensor<[?,?,1,?],i1> 
    %23 = torch.operator "onnx.And"(%21, %22) : (!torch.vtensor<[?,?,?,1],i1>, !torch.vtensor<[?,?,1,?],i1>) -> !torch.vtensor<[?,?,?,?],i1> 
    %24 = torch.operator "onnx.Cast"(%23) {torch.onnx.to = 9 : si64} : (!torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[?,?,?,?],i1> 
    %25 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1x1x128x384xi1>} : () -> !torch.vtensor<[1,1,128,384],i1> 
    %26 = torch.operator "onnx.And"(%24, %25) : (!torch.vtensor<[?,?,?,?],i1>, !torch.vtensor<[1,1,128,384],i1>) -> !torch.vtensor<[?,?,128,384],i1> 
    return %26 : !torch.vtensor<[?,?,128,384],i1>
  }
}

Getting the following error:

<unknown>:0: error: failed to legalize unresolved materialization from ('i64') to 'index' that remained live after conversion
<unknown>:0: note: see current operation: %6 = "builtin.unrealized_conversion_cast"(%5) : (i64) -> index
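
For context, this diagnostic means the dialect-conversion framework inserted a temporary builtin.unrealized_conversion_cast while converting an i64 value to index, and nothing later replaced it. A rough, hand-written illustration of the pattern (hypothetical names, not taken from the dump):

  // Before conversion: an i64 extracted from a tensor is cast to index and
  // used as a data-dependent dimension.
  %d = arith.index_cast %extracted : i64 to index
  %empty = tensor.empty(%d) : tensor<?xi1>

  // During conversion, if no pattern legalizes that value as 'index', the
  // framework materializes a cast that nothing resolves, and the error fires:
  %d_cast = "builtin.unrealized_conversion_cast"(%extracted) : (i64) -> index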

Steps to reproduce your issue

Command

iree-compile model.torch_onnx.mlir --iree-hal-target-backends=llvm-cpu -o comp.vmfb --iree-llvmcpu-target-cpu=host

Detailed log:

dump.log

What component(s) does this issue relate to?

Compiler

Version information

No response

Additional context

No response

MaheshRavishankar commented 1 month ago

@pashu123 looks like the scf.forall resolution is not kicking in somewhere?

pashu123 commented 1 month ago

@pashu123 looks like the scf.forall resolution is not kicking in somewhere?

It's forming a dispatch with dag_root https://gist.github.com/pashu123/12dd8d3771a1a5cfd99a52ee5b001a98#file-module_main_graph-async_dispatch_3-mlir-L5 looking into it.

MaheshRavishankar commented 1 month ago

Try top of tree. Quinn fixes a few builtins

pashu123 commented 1 month ago

Try top of tree. Quinn fixes a few builtins

I am trying the top of main. It's not coming from the builtins; it's forming during dispatch region creation. Full dump: https://gist.github.com/pashu123/fb9a9d29b9f199d6f10bfb3c2d55ed49, line 47735.

pashu123 commented 1 month ago
%27 = flow.dispatch.region[%23, %24, %25, %26] -> (tensor<?x?x128x384xi1>{%20, %21}) {
  %30 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%22 : tensor<?x?x128x384xi1>) {
  ^bb0(%out: i1):
    linalg.yield %false : i1
  } -> tensor<?x?x128x384xi1>
  flow.return %30 : tensor<?x?x128x384xi1>
} count(%arg8: index, %arg9: index, %arg10: index, %arg11: index) -> (index, index, index) {
  %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg8, %arg9, %arg10, %arg11
  flow.return %x, %y, %z : index, index, index
}

It's forming here: https://github.com/iree-org/iree/blob/fa752ae1e491a1f8fde8967bf04473c6a6c1ca18/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp#L583

pashu123 commented 1 month ago

So, rather than creating count_from_dag_root, should it create a count_from_slice here? @MaheshRavishankar
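
For reference, a rough sketch of what that would look like, assuming count_from_slice takes the forwarded workload values the same way the dag_root form above does (exact operand tying may differ):

  count(%arg8: index, %arg9: index, %arg10: index, %arg11: index) -> (index, index, index) {
    %x, %y, %z = flow.dispatch.workgroup_count_from_slice %arg8, %arg9, %arg10, %arg11
    flow.return %x, %y, %z : index, index, index
  }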

MaheshRavishankar commented 1 month ago

Yes! I wonder why that is happening.

Max191 commented 1 month ago

Yes! I wonder why that is happening.

It is choosing count_from_dag_root here because the size of the tensor is data-dependent. Here is the dump before FormDispatchRegions (see the op at the end of the dump):

// -----// IR Dump After FormScalarDispatchesPass (iree-dispatch-creation-form-scalar-dispatches) //----- //
util.func public @main_graph$async(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view, %arg3: !hal.buffer_view, %arg4: !hal.buffer_view, %arg5: !hal.buffer_view, %arg6: !hal.fence, %arg7: !hal.fence) -> !hal.buffer_view attributes {inlining_policy = #util.inline.never, iree.abi.model = "coarse-fences", iree.abi.stub} {
  %false = arith.constant false
  %c1 = arith.constant 1 : index
  %c0 = arith.constant 0 : index
  %c0_i64 = arith.constant 0 : i64
  %cst = arith.constant dense<8> : tensor<i64>
  %cst_0 = arith.constant dense<0> : tensor<3xi64>
  %0 = hal.buffer_view.dim<%arg1 : !hal.buffer_view>[0] : index
  %1 = hal.buffer_view.dim<%arg1 : !hal.buffer_view>[1] : index
  %2 = hal.tensor.import wait(%arg6) => %arg4 : !hal.buffer_view -> tensor<4xi64>
  %3 = hal.tensor.import wait(%arg6) => %arg5 : !hal.buffer_view -> tensor<2xi64>
  %extracted_slice = tensor.extract_slice %2[0] [1] [1] : tensor<4xi64> to tensor<i64>
  %expanded = tensor.expand_shape %extracted_slice [] output_shape [1] : tensor<i64> into tensor<1xi64>
  %extracted = tensor.extract %expanded[%c0] : tensor<1xi64>
  %extracted_slice_1 = tensor.extract_slice %2[1] [1] [1] : tensor<4xi64> to tensor<i64>
  %expanded_2 = tensor.expand_shape %extracted_slice_1 [] output_shape [1] : tensor<i64> into tensor<1xi64>
  %extracted_3 = tensor.extract %expanded_2[%c0] : tensor<1xi64>
  %extracted_slice_4 = tensor.extract_slice %2[2] [1] [1] : tensor<4xi64> to tensor<i64>
  %expanded_5 = tensor.expand_shape %extracted_slice_4 [] output_shape [1] : tensor<i64> into tensor<1xi64>
  %extracted_6 = tensor.extract %expanded_5[%c0] : tensor<1xi64>
  %extracted_slice_7 = tensor.extract_slice %2[3] [1] [1] : tensor<4xi64> to tensor<i64>
  %expanded_8 = tensor.expand_shape %extracted_slice_7 [] output_shape [1] : tensor<i64> into tensor<1xi64>
  %extracted_9 = tensor.extract %expanded_8[%c0] : tensor<1xi64>
  %4 = arith.index_cast %extracted : i64 to index
  %5 = arith.index_cast %extracted_6 : i64 to index
  %6 = arith.index_cast %extracted_3 : i64 to index
  %7 = arith.index_cast %extracted_9 : i64 to index
  %8 = tensor.empty() : tensor<i64>
  %9 = flow.dispatch.region -> (tensor<i64>) {
    %26 = linalg.generic {indexing_maps = [affine_map<() -> ()>], iterator_types = []} outs(%8 : tensor<i64>) {
    ^bb0(%out: i64):
      %extracted_20 = tensor.extract %3[%c1] : tensor<2xi64>
      %27 = arith.divsi %extracted_20, %c0_i64 : i64
      linalg.yield %27 : i64
    } -> tensor<i64>
    flow.return %26 : tensor<i64>
  } count() -> (index, index, index) {
    %c1_20 = arith.constant 1 : index
    flow.return %c1_20, %c1_20, %c1_20 : index, index, index
  }
  %inserted_slice = tensor.insert_slice %9 into %cst_0[1] [1] [1] : tensor<i64> into tensor<3xi64>
  %inserted_slice_10 = tensor.insert_slice %cst into %inserted_slice[2] [1] [1] : tensor<i64> into tensor<3xi64>
  %extracted_slice_11 = tensor.extract_slice %inserted_slice_10[0] [1] [1] : tensor<3xi64> to tensor<i64>
  %expanded_12 = tensor.expand_shape %extracted_slice_11 [] output_shape [1] : tensor<i64> into tensor<1xi64>
  %extracted_13 = tensor.extract %expanded_12[%c0] : tensor<1xi64>
  %10 = arith.cmpi eq, %extracted_13, %c0_i64 : i64
  %11 = arith.addi %4, %5 : index
  %12 = arith.addi %11, %0 : index
  %13 = arith.index_cast %12 : index to i64
  %14 = flow.dispatch.region -> (tensor<i64>) {
    %26 = linalg.generic {indexing_maps = [affine_map<() -> ()>], iterator_types = []} outs(%8 : tensor<i64>) {
    ^bb0(%out: i64):
      %27 = arith.select %10, %13, %extracted_13 : i64
      linalg.yield %27 : i64
    } -> tensor<i64>
    flow.return %26 : tensor<i64>
  } count() -> (index, index, index) {
    %c1_20 = arith.constant 1 : index
    flow.return %c1_20, %c1_20, %c1_20 : index, index, index
  }
  %extracted_14 = tensor.extract %14[] : tensor<i64>
  %extracted_slice_15 = tensor.extract_slice %inserted_slice_10[1] [1] [1] : tensor<3xi64> to tensor<i64>
  %expanded_16 = tensor.expand_shape %extracted_slice_15 [] output_shape [1] : tensor<i64> into tensor<1xi64>
  %extracted_17 = tensor.extract %expanded_16[%c0] : tensor<1xi64>
  %15 = arith.cmpi eq, %extracted_17, %c0_i64 : i64
  %16 = arith.addi %6, %7 : index
  %17 = arith.addi %16, %1 : index
  %18 = arith.index_cast %17 : index to i64
  %19 = flow.dispatch.region -> (tensor<i64>) {
    %26 = linalg.generic {indexing_maps = [affine_map<() -> ()>], iterator_types = []} outs(%8 : tensor<i64>) {
    ^bb0(%out: i64):
      %27 = arith.select %15, %18, %extracted_17 : i64
      linalg.yield %27 : i64
    } -> tensor<i64>
    flow.return %26 : tensor<i64>
  } count() -> (index, index, index) {
    %c1_20 = arith.constant 1 : index
    flow.return %c1_20, %c1_20, %c1_20 : index, index, index
  }
  %extracted_18 = tensor.extract %19[] : tensor<i64>
  %20 = arith.index_cast %extracted_14 : i64 to index
  %21 = arith.index_cast %extracted_18 : i64 to index

  // The init operand here is dependent on `%extracted_14` and `%extracted_18`
  %22 = tensor.empty(%20, %21) : tensor<?x?x128x384xi1>
  %23 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%22 : tensor<?x?x128x384xi1>) {
  ^bb0(%out: i1):
    linalg.yield %false : i1
  } -> tensor<?x?x128x384xi1>
  %24 = hal.tensor.barrier join(%23 : tensor<?x?x128x384xi1>) => %arg7 : !hal.fence
  %dim = tensor.dim %24, %c0 : tensor<?x?x128x384xi1>
  %dim_19 = tensor.dim %24, %c1 : tensor<?x?x128x384xi1>
  %25 = hal.tensor.export %24 : tensor<?x?x128x384xi1>{%dim, %dim_19} -> !hal.buffer_view
  util.return %25 : !hal.buffer_view
}

According to the comments in RegionOpUtils.cpp, it seems this has to stay as count_from_dag_root. I'm guessing the unrealized conversion casts are coming from

  %20 = arith.index_cast %extracted_14 : i64 to index
  %21 = arith.index_cast %extracted_18 : i64 to index

I'll look at what is happening with these ops.
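
To summarize the shape flow (hand-excerpted by combining the two dumps above, not a standalone reproducer): the data-dependent dims come out of scalar dispatches as i64, get cast to index, and then feed the workload and result shape of the flow.dispatch.region, which is where the i64 -> index materialization has to be resolved during conversion:

  %extracted_14 = tensor.extract %14[] : tensor<i64>    // scalar dispatch result
  %20 = arith.index_cast %extracted_14 : i64 to index   // data-dependent dim
  %22 = tensor.empty(%20, %21) : tensor<?x?x128x384xi1>
  %27 = flow.dispatch.region[%23, %24, %25, %26] -> (tensor<?x?x128x384xi1>{%20, %21}) { ... }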

Max191 commented 1 month ago

@MaheshRavishankar do we need to add a lowering for flow.dispatch.workgroup_count_from_dag_root?

MaheshRavishankar commented 1 month ago

@MaheshRavishankar do we need to add a lowering for flow.dispatch.workgroup_count_from_dag_root?

No, that really can't be handled very well with scf.forall. But that's not the issue. Hold on.

MaheshRavishankar commented 1 month ago

@pdhirajkumarprasad this is again an issue where none of this has any real computation. This is just index computation for the whole code. We need to get better at triaging this (as a team). I know it fails in codegen, but it is not a codegen issue. Labeling it as codegen just increases latency.

Not pointing fingers. I didn't see the actual input at all. I just saw the error message too... Just pointing out that we spent two days looking somewhere else :)

@zjgarvey I think this is yours :D

pdhirajkumarprasad commented 3 weeks ago

Currently we have the following 4 models failing with the above error:

model--long-t5-tglobal-base-16384-book-summary--pszemraj
model--long-t5-tglobal-base-16384-booksum-V11-big_patent-V2--pszemraj
model--long-t5-tglobal-base-16384-booksum-V12--pszemraj
migraphx_bert__bertsquad-12
benvanik commented 3 weeks ago
  %19 = flow.dispatch.region -> (tensor<i64>) {
    %26 = linalg.generic {indexing_maps = [affine_map<() -> ()>], iterator_types = []} outs(%8 : tensor<i64>) {
    ^bb0(%out: i64):
      %27 = arith.select %15, %18, %extracted_17 : i64
      linalg.yield %27 : i64
    } -> tensor<i64>
    flow.return %26 : tensor<i64>
  } count() -> (index, index, index) {
    %c1_20 = arith.constant 1 : index
    flow.return %c1_20, %c1_20, %c1_20 : index, index, index
  }
  %extracted_18 = tensor.extract %19[] : tensor<i64>

This is really bad IR - we should be able to compile it, but this is really bad. All of that could be %extracted_18 = arith.select %15, %18, %extracted_17 : i64 and run literally 10000x faster. Good as a test of the compiler, but this should be P1 to fix at whatever level is best (I think linalg/detensorizing should have handled this).
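
In other words, a minimal sketch of the folded form (assuming %15, %18, and %extracted_17 are already available as scalars in the enclosing function, as they are in the dump earlier in this thread):

  // The whole dispatch region plus the trailing tensor.extract collapse to one scalar op:
  %extracted_18 = arith.select %15, %18, %extracted_17 : i64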

zjgarvey commented 3 weeks ago

This is really bad IR - we should be able to compile it, but this is really bad. All of that could be %extracted_18 = arith.select %15, %18, %extracted_17 : i64 and run literally 10000x faster. Good as a test of the compiler, but this should be P1 to fix at whatever level is best (I think linalg/detensorizing should have handled this).

We shouldn't be getting that kind of IR in the full model, since those arith.select ops should get scalarized there.

The problem is that we can't scalarize from a func.func input, so a bit more IR would be necessary for an accurate reproducer (specifically the producers for %arg4 and %arg5).

zjgarvey commented 3 weeks ago

Currently we have the following 4 models failing with the above error:

model--long-t5-tglobal-base-16384-book-summary--pszemraj
model--long-t5-tglobal-base-16384-booksum-V11-big_patent-V2--pszemraj
model--long-t5-tglobal-base-16384-booksum-V12--pszemraj
migraphx_bert__bertsquad-12

At least the model model--long-t5-tglobal-base-16384-book-summary--pszemraj seems to be failing this i64-to-index conversion on an onnx.Einsum op, so it might be good to update with a better reproducer.

pdhirajkumarprasad commented 3 weeks ago

Here is the new reduced IR from the model, without much modification, so that the data flow stays intact:

module {
  func.func @tf2onnx(%arg0: !torch.vtensor<[?],si64>, %arg1: !torch.vtensor<[?,256],si64>, %arg2: !torch.vtensor<[?,256],si64>, %arg3: !torch.vtensor<[?,256],si64>) -> (!torch.vtensor<[?,256,768],f32> ) attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "tf2onnx", torch.onnx_meta.producer_version = "1.5.2"} {
    %4 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %5 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %6 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %7 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %8 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %9 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %13 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<2xsi64>} : () -> !torch.vtensor<[2],si64> 
    %14 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<2xsi64>} : () -> !torch.vtensor<[2],si64> 
    %18 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<2x768xf32>} : () -> !torch.vtensor<[2,768],f32> 
    %483 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<30522x768xf32>} : () -> !torch.vtensor<[30522,768],f32> 
    %484 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<2x768xf32>} : () -> !torch.vtensor<[2,768],f32> 
    %485 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<512x768xf32>} : () -> !torch.vtensor<[512,768],f32> 
    %486 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<f32>} : () -> !torch.vtensor<[],f32> 
    %487 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<f32>} : () -> !torch.vtensor<[],f32> 
    %488 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_bert_embeddings_one_hot_depth_0> : tensor<si32>} : () -> !torch.vtensor<[],si32> 
    %489 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_bert_embeddings_Reshape_4_shape_0> : tensor<3xsi32>} : () -> !torch.vtensor<[3],si32> 
    %490 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_bert_embeddings_Reshape_3_shape_2_0> : tensor<si32>} : () -> !torch.vtensor<[],si32> 
    %491 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_bert_embeddings_Reshape_3_shape_1_0> : tensor<si32>} : () -> !torch.vtensor<[],si32> 
    %492 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_bert_embeddings_Reshape_2_shape_0> : tensor<1xsi32>} : () -> !torch.vtensor<[1],si32> 
    %493 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_bert_embeddings_Reshape_1_shape_2_0> : tensor<si32>} : () -> !torch.vtensor<[],si32> 
    %494 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_bert_embeddings_Reshape_1_shape_1_0> : tensor<si32>} : () -> !torch.vtensor<[],si32> 
    %495 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_bert_embeddings_Reshape_shape_0> : tensor<1xsi32>} : () -> !torch.vtensor<[1],si32> 
    %499 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_bert_embeddings_ExpandDims__48> : tensor<3xsi64>} : () -> !torch.vtensor<[3],si64> 
    %500 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_Reshape_1_shape_2_0> : tensor<si32>} : () -> !torch.vtensor<[],si32> 
    %501 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_Reshape_1_shape_1_0> : tensor<si32>} : () -> !torch.vtensor<[],si32> 
    %502 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_Reshape_shape_1_0> : tensor<si32>} : () -> !torch.vtensor<[],si32> 
    %811 = torch.operator "onnx.Slice"(%485, %14, %13) : (!torch.vtensor<[512,768],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[256,768],f32> 
    %812 = torch.operator "onnx.Cast"(%489) {torch.onnx.to = 7 : si64} : (!torch.vtensor<[3],si32>) -> !torch.vtensor<[3],si64> 
    %813 = torch.operator "onnx.Reshape"(%811, %812) : (!torch.vtensor<[256,768],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[1,256,768],f32> 
    %814 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %815 = torch.operator "onnx.Unsqueeze"(%490, %814) : (!torch.vtensor<[],si32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si32> 
    %816 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %817 = torch.operator "onnx.Unsqueeze"(%491, %816) : (!torch.vtensor<[],si32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si32> 
    %818 = torch.operator "onnx.Cast"(%492) {torch.onnx.to = 7 : si64} : (!torch.vtensor<[1],si32>) -> !torch.vtensor<[1],si64> 
    %819 = torch.operator "onnx.Reshape"(%arg1, %818) : (!torch.vtensor<[?,256],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[?],si64> 
    %820 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %821 = torch.operator "onnx.Unsqueeze"(%493, %820) : (!torch.vtensor<[],si32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si32> 
    %822 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %823 = torch.operator "onnx.Unsqueeze"(%494, %822) : (!torch.vtensor<[],si32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si32> 
    %824 = torch.operator "onnx.Cast"(%495) {torch.onnx.to = 7 : si64} : (!torch.vtensor<[1],si32>) -> !torch.vtensor<[1],si64> 
    %825 = torch.operator "onnx.Reshape"(%arg3, %499) : (!torch.vtensor<[?,256],si64>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[?,256,1],si64> 
    %826 = torch.operator "onnx.Shape"(%825) : (!torch.vtensor<[?,256,1],si64>) -> !torch.vtensor<[3],si64> 
    %827 = torch.operator "onnx.Cast"(%826) {torch.onnx.to = 1 : si64} : (!torch.vtensor<[3],si64>) -> !torch.vtensor<[3],f32> 
    %828 = torch.operator "onnx.Slice"(%827, %9, %8, %7) : (!torch.vtensor<[3],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],f32> 
    %829 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %830 = torch.operator "onnx.Squeeze"(%828, %829) : (!torch.vtensor<[1],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[],f32> 
    %831 = torch.operator "onnx.Cast"(%830) {torch.onnx.to = 6 : si64} : (!torch.vtensor<[],f32>) -> !torch.vtensor<[],si32> 
    %832 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %833 = torch.operator "onnx.Unsqueeze"(%831, %832) : (!torch.vtensor<[],si32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si32> 
    %834 = torch.operator "onnx.Concat"(%833, %823, %821) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[1],si32>, !torch.vtensor<[1],si32>, !torch.vtensor<[1],si32>) -> !torch.vtensor<[3],si32> 
    %835 = torch.operator "onnx.Cast"(%834) {torch.onnx.to = 7 : si64} : (!torch.vtensor<[3],si32>) -> !torch.vtensor<[3],si64> 
    %836 = torch.operator "onnx.Reshape"(%825, %824) : (!torch.vtensor<[?,256,1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[?],si64> 
    %837 = torch.operator "onnx.Gather"(%483, %836) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[30522,768],f32>, !torch.vtensor<[?],si64>) -> !torch.vtensor<[?,768],f32> 
    %838 = torch.operator "onnx.Reshape"(%837, %835) : (!torch.vtensor<[?,768],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[?,256,768],f32> 
    %839 = torch.operator "onnx.Shape"(%838) : (!torch.vtensor<[?,256,768],f32>) -> !torch.vtensor<[3],si64> 
    %840 = torch.operator "onnx.Cast"(%839) {torch.onnx.to = 1 : si64} : (!torch.vtensor<[3],si64>) -> !torch.vtensor<[3],f32> 
    %841 = torch.operator "onnx.Slice"(%840, %6, %5, %4) : (!torch.vtensor<[3],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],f32> 
    %842 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %843 = torch.operator "onnx.Squeeze"(%841, %842) : (!torch.vtensor<[1],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[],f32> 
    %844 = torch.operator "onnx.Cast"(%843) {torch.onnx.to = 6 : si64} : (!torch.vtensor<[],f32>) -> !torch.vtensor<[],si32> 
    %845 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %846 = torch.operator "onnx.Unsqueeze"(%844, %845) : (!torch.vtensor<[],si32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si32> 
    %847 = torch.operator "onnx.Concat"(%846, %817, %815) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[1],si32>, !torch.vtensor<[1],si32>, !torch.vtensor<[1],si32>) -> !torch.vtensor<[3],si32> 
    %848 = torch.operator "onnx.Cast"(%847) {torch.onnx.to = 7 : si64} : (!torch.vtensor<[3],si32>) -> !torch.vtensor<[3],si64> 
    %849 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %850 = torch.operator "onnx.Unsqueeze"(%487, %849) : (!torch.vtensor<[],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],f32> 
    %851 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %852 = torch.operator "onnx.Unsqueeze"(%486, %851) : (!torch.vtensor<[],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],f32> 
    %853 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %854 = torch.operator "onnx.Unsqueeze"(%488, %853) : (!torch.vtensor<[],si32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si32> 
    %862 = torch.operator "onnx.Concat"(%850, %852) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[2],f32> 
    %863 = torch.operator "onnx.OneHot"(%819, %854, %862) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[?],si64>, !torch.vtensor<[1],si32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[?,?],f32> 
    %864 = torch.operator "onnx.MatMul"(%863, %484) : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[2,768],f32>) -> !torch.vtensor<[?,768],f32> 
    %865 = torch.operator "onnx.Reshape"(%864, %848) : (!torch.vtensor<[?,768],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[?,256,768],f32> 
    %866 = torch.operator "onnx.Add"(%838, %865) : (!torch.vtensor<[?,256,768],f32>, !torch.vtensor<[?,256,768],f32>) -> !torch.vtensor<[?,256,768],f32> 
    %867 = torch.operator "onnx.Add"(%866, %813) : (!torch.vtensor<[?,256,768],f32>, !torch.vtensor<[1,256,768],f32>) -> !torch.vtensor<[?,256,768],f32> 
    return %867: !torch.vtensor<[?,256,768],f32>
  }
}

{-#
  dialect_resources: {
    builtin: {
      _bert_embeddings_one_hot_on_value_0: "0x080000000000803F",
      _bert_embeddings_one_hot_off_value_0: "0x0800000000000000",
      _bert_embeddings_one_hot_depth_0: "0x0800000002000000",
      _bert_embeddings_Reshape_4_shape_0: "0x08000000010000000001000000030000",
      _bert_embeddings_Reshape_3_shape_2_0: "0x0800000000030000",
      _bert_embeddings_Reshape_3_shape_1_0: "0x0800000000010000",
      _bert_embeddings_Reshape_2_shape_0: "0x08000000FFFFFFFF",
      _bert_embeddings_Reshape_1_shape_2_0: "0x0800000000030000",
      _bert_embeddings_Reshape_1_shape_1_0: "0x0800000000010000",
      _bert_embeddings_Reshape_shape_0: "0x08000000FFFFFFFF",
      _bert_embeddings_LayerNorm_batchnorm_add_y_0: "0x08000000CCBC8C2B",
      _bert_embeddings_ExpandDims__48: "0x08000000FFFFFFFFFFFFFFFF00010000000000000100000000000000",
      _Reshape_1_shape_2_0: "0x0800000002000000",
      _Reshape_1_shape_1_0: "0x0800000000010000",
      _Reshape_shape_1_0: "0x0800000000030000",
      _: "0x080000000000803F"
    }
  }
#-}
zjgarvey commented 3 weeks ago

Thanks @pdhirajkumarprasad, let me verify the scalarization is working properly here.

zjgarvey commented 3 weeks ago

Nice, it looks like there are just a few casts interrupting the scalarization. I can definitely fix this quickly.
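
For reference, the kind of chain in the reproducer that presumably interrupts the scalarization: the shape value is routed through f32 before coming back to an integer shape component (hand-excerpted from the IR above, operand types elided):

    %826 = torch.operator "onnx.Shape"(%825) : ... -> !torch.vtensor<[3],si64>
    %827 = torch.operator "onnx.Cast"(%826) {torch.onnx.to = 1 : si64} : ... -> !torch.vtensor<[3],f32>
    %828 = torch.operator "onnx.Slice"(%827, %9, %8, %7) : ... -> !torch.vtensor<[1],f32>
    %830 = torch.operator "onnx.Squeeze"(%828, %829) : ... -> !torch.vtensor<[],f32>
    %831 = torch.operator "onnx.Cast"(%830) {torch.onnx.to = 6 : si64} : ... -> !torch.vtensor<[],si32>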