microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

op.SequenceEmpty(dtype=xxx) cannot be set to float16. #16846

Open xiaowuhu opened 1 year ago

xiaowuhu commented 1 year ago

Describe the issue

op.SequenceEmpty(dtype=xxx) cannot be set to float16: the dtype argument is ignored and the resulting sequence has element type float32.

To reproduce

from onnxscript import FLOAT, FLOAT16, opset18 as op

seq = op.SequenceEmpty(dtype=FLOAT16.dtype)
x = op.Cast(op.Constant(value_float=1.0), to=FLOAT16.dtype)
op.SequenceInsert(seq, x)  # This fails

Error

Traceback (most recent call last):
  File "/home/justinchu/anaconda3/envs/pytorch/lib/python3.10/site-packages/onnxscript/evaluator.py", line 446, in _call_ort
    session = ort.InferenceSession(
  File "/home/justinchu/anaconda3/envs/pytorch/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 360, in __init__
    self._create_inference_session(providers, provider_options, disabled_optimizers)
  File "/home/justinchu/anaconda3/envs/pytorch/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 399, in _create_inference_session
    sess = C.InferenceSession(session_options, self._model_bytes, False, self._read_config_from_model)
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Node () Op (SequenceInsert) [TypeInferenceError] Input Sequence and Tensor are expected to have the same elem type. Sequence=1 Tensor=10
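The numbers in this message are onnx.TensorProto.DataType enum values, so Sequence=1 Tensor=10 means the sequence holds FLOAT (float32) tensors while the inserted tensor is FLOAT16. A small stdlib-only sketch for decoding such messages; the table covers only a subset of the enum, and decode_elem_types is a hypothetical helper.

```python
import re

# Subset of onnx.TensorProto.DataType values (as defined in onnx.proto).
ELEM_TYPE_NAMES = {
    1: "FLOAT",
    6: "INT32",
    7: "INT64",
    9: "BOOL",
    10: "FLOAT16",
    11: "DOUBLE",
    16: "BFLOAT16",
}


def decode_elem_types(message):
    """Map every `key=<int>` pair in an ONNX/ORT error message to a dtype name."""
    decoded = {}
    for key, code in re.findall(r"(\w+)=(\d+)", message):
        decoded[key] = ELEM_TYPE_NAMES.get(int(code), f"UNKNOWN({code})")
    return decoded


msg = ("[TypeInferenceError] Input Sequence and Tensor are expected to have "
       "the same elem type. Sequence=1 Tensor=10")
print(decode_elem_types(msg))  # {'Sequence': 'FLOAT', 'Tensor': 'FLOAT16'}
```

With onnx installed, onnx.TensorProto.DataType.Name(10) returns 'FLOAT16' directly, which avoids maintaining a table by hand.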

Urgency

No response

Platform

Windows

OS Version

11

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.14

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Default CPU

Execution Provider Library Version

No response

xadupre commented 1 year ago

What is the error message? Does it work with x = op.Cast(op.Constant(value_floats=[1.0]), to=FLOAT16.dtype)?

pranavsharma commented 1 year ago

Is this an issue with onnxruntime or onnxscript? Please file the issue in the appropriate repository.

justinchuby commented 1 year ago

Updated the issue with the error message. I think this is an issue with ONNX shape inference. @gramalingam I wonder if this is covered by the fix we have had since 1.14?

justinchuby commented 1 year ago

The error manifests itself in tests as follows:

Summary

ONNX Runtime raises [ONNXRuntimeError] : 1 : FAIL : Node (_0x7b6efe0_n21) Op (Loop) [TypeInferenceError] Graph attribute inferencing failed: Node (_0x7b6efe0_n12) Op (If) [TypeInferenceError] Mismatched tensor element type: source=10 target=1 when executing test ops_test.TestOutputConsistencyFullGraphCPU.test_output_match_opinfo__ops_aten_embedding_bag_cpu_float16 in ONNX Script TorchLib.

To recreate this report, use

CREATE_REPRODUCTION_REPORT=1 python -m pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k test_output_match_opinfo__ops_aten_embedding_bag_cpu_float16

To reproduce

import google.protobuf.text_format
import numpy as np
from numpy import array, float16, float32, float64, int32, int64
import onnx
import onnxruntime as ort

# Run n times
N = 1

onnx_model_text = """
ir_version: 8
producer_name: "pytorch"
producer_version: "2.1.0"
graph {
  node {
    output: "_val_3"
    name: "Constant_0"
    op_type: "Constant"
    attribute {
      name: "value"
      t {
        dims: 1
        data_type: 7
        raw_data: "\377\377\377\377\377\377\377\377"
      }
      type: TENSOR
    }
    doc_string: ""
  }
  node {
    input: "input_1"
    input: "_val_3"
    output: "_val_4"
    name: "Reshape_1"
    op_type: "Reshape"
    attribute {
      name: "allowzero"
      i: 0
      type: INT
    }
    doc_string: ""
  }
  node {
    input: "_val_4"
    output: "_val_5"
    name: "Shape_2"
    op_type: "Shape"
    attribute {
      name: "start"
      i: 0
      type: INT
    }
    doc_string: ""
  }
  node {
    output: "_val_6"
    name: "Constant_3"
    op_type: "Constant"
    attribute {
      name: "value"
      t {
        data_type: 7
        raw_data: "\001\000\000\000\000\000\000\000"
      }
      type: TENSOR
    }
    doc_string: ""
  }
  node {
    input: "_val_6"
    input: "_val_5"
    output: "_val_7"
    name: "Expand_4"
    op_type: "Expand"
    doc_string: ""
  }
  node {
    input: "_val_7"
    input: "input_0"
    output: "_val_8"
    name: "CastLike_5"
    op_type: "CastLike"
    doc_string: ""
  }
  node {
    input: "input_0"
    input: "input_1"
    input: "offsets"
    input: "_val_8"
    output: "_val_9"
    name: "_aten_embedding_bag_onnx_6"
    op_type: "_aten_embedding_bag_onnx"
    attribute {
      name: "include_last_offset"
      i: 1
      type: INT
    }
    attribute {
      name: "mode"
      i: 0
      type: INT
    }
    doc_string: ""
    domain: "pkg.onnxscript.torch_lib"
  }
  node {
    input: "input_1"
    output: "_val_10"
    name: "Shape_7"
    op_type: "Shape"
    attribute {
      name: "end"
      i: 0
      type: INT
    }
    attribute {
      name: "start"
      i: 0
      type: INT
    }
    doc_string: ""
  }
  node {
    input: "offsets"
    output: "_val_11"
    name: "Shape_8"
    op_type: "Shape"
    attribute {
      name: "start"
      i: 0
      type: INT
    }
    doc_string: ""
  }
  node {
    output: "_val_12"
    name: "Constant_9"
    op_type: "Constant"
    attribute {
      name: "value"
      t {
        data_type: 7
        raw_data: "\000\000\000\000\000\000\000\000"
      }
      type: TENSOR
    }
    doc_string: ""
  }
  node {
    input: "_val_12"
    input: "_val_11"
    output: "_val_13"
    name: "Expand_10"
    op_type: "Expand"
    doc_string: ""
  }
  node {
    input: "offsets"
    output: "_val_14"
    name: "Shape_11"
    op_type: "Shape"
    attribute {
      name: "start"
      i: 0
      type: INT
    }
    doc_string: ""
  }
  node {
    output: "_val_15"
    name: "Constant_12"
    op_type: "Constant"
    attribute {
      name: "value"
      t {
        data_type: 7
        raw_data: "\000\000\000\000\000\000\000\000"
      }
      type: TENSOR
    }
    doc_string: ""
  }
  node {
    input: "_val_15"
    input: "_val_14"
    output: "_val_16"
    name: "Expand_13"
    op_type: "Expand"
    doc_string: ""
  }
  name: "torch_jit"
  input {
    name: "input_0"
    type {
      tensor_type {
        elem_type: 10
        shape {
          dim {
            dim_value: 10
          }
          dim {
            dim_value: 5
          }
        }
      }
    }
  }
  input {
    name: "input_1"
    type {
      tensor_type {
        elem_type: 7
        shape {
          dim {
            dim_value: 5
          }
        }
      }
    }
  }
  input {
    name: "offsets"
    type {
      tensor_type {
        elem_type: 7
        shape {
          dim {
            dim_value: 3
          }
        }
      }
    }
  }
  output {
    name: "_val_9"
    type {
      tensor_type {
        elem_type: 10
        shape {
          dim {
            dim_value: 2
          }
          dim {
            dim_value: 5
          }
        }
      }
    }
  }
  output {
    name: "_val_10"
    type {
      tensor_type {
        elem_type: 7
        shape {
          dim {
            dim_value: 0
          }
        }
      }
    }
  }
  output {
    name: "_val_13"
    type {
      tensor_type {
        elem_type: 7
        shape {
          dim {
            dim_value: 3
          }
        }
      }
    }
  }
  output {
    name: "_val_16"
    type {
      tensor_type {
        elem_type: 7
        shape {
          dim {
            dim_value: 3
          }
        }
      }
    }
  }
  value_info {
    name: "_val_3"
    type {
      tensor_type {
        elem_type: 7
        shape {
          dim {
            dim_value: 1
          }
        }
      }
    }
  }
  value_info {
    name: "_val_4"
    type {
      tensor_type {
        elem_type: 7
        shape {
          dim {
            dim_value: 5
          }
        }
      }
    }
  }
  value_info {
    name: "_val_5"
    type {
      tensor_type {
        elem_type: 7
        shape {
          dim {
            dim_value: 1
          }
        }
      }
    }
  }
  value_info {
    name: "_val_6"
    type {
      tensor_type {
        elem_type: 7
        shape {
        }
      }
    }
  }
  value_info {
    name: "_val_7"
    type {
      tensor_type {
        elem_type: 7
        shape {
          dim {
            dim_value: 5
          }
        }
      }
    }
  }
  value_info {
    name: "_val_8"
    type {
      tensor_type {
        elem_type: 10
        shape {
          dim {
            dim_value: 5
          }
        }
      }
    }
  }
  value_info {
    name: "_val_11"
    type {
      tensor_type {
        elem_type: 7
        shape {
          dim {
            dim_value: 1
          }
        }
      }
    }
  }
  value_info {
    name: "_val_12"
    type {
      tensor_type {
        elem_type: 7
        shape {
        }
      }
    }
  }
  value_info {
    name: "_val_14"
    type {
      tensor_type {
        elem_type: 7
        shape {
          dim {
            dim_value: 1
          }
        }
      }
    }
  }
  value_info {
    name: "_val_15"
    type {
      tensor_type {
        elem_type: 7
        shape {
        }
      }
    }
  }
}
opset_import {
  domain: "pkg.onnxscript.torch_lib"
  version: 1
}
opset_import {
  domain: ""
  version: 18
}
functions {
  name: "_aten_embedding_bag_onnx"
  input: "weight"
  input: "indices"
  input: "offsets"
  input: "per_sample_weights"
  output: "return_val"
  attribute: "mode"
  attribute: "include_last_offset"
  node {
    output: "neg_1"
    name: "n0"
    op_type: "Constant"
    attribute {
      name: "value_ints"
      ints: -1
      type: INTS
    }
    domain: ""
  }
  node {
    input: "indices"
    input: "neg_1"
    output: "indices_1d"
    name: "n1"
    op_type: "Reshape"
    domain: ""
  }
  node {
    input: "weight"
    input: "indices_1d"
    output: "new_weight"
    name: "n2"
    op_type: "Gather"
    domain: ""
  }
  node {
    output: "int64_1"
    name: "n3"
    op_type: "Constant"
    attribute {
      name: "value"
      t {
        data_type: 7
        int64_data: 1
        name: "int64_1"
      }
      type: TENSOR
    }
    domain: ""
  }
  node {
    input: "per_sample_weights"
    input: "int64_1"
    output: "tmp"
    name: "n4"
    op_type: "Unsqueeze"
    domain: ""
  }
  node {
    input: "new_weight"
    input: "tmp"
    output: "new_weight_0"
    name: "n5"
    op_type: "Mul"
    domain: ""
  }
  node {
    input: "weight"
    output: "tmp_1"
    name: "n6"
    op_type: "Shape"
    attribute {
      name: "start"
      i: 1
      type: INT
    }
    domain: ""
  }
  node {
    input: "tmp_1"
    input: "neg_1"
    output: "weight_dim_1"
    name: "n7"
    op_type: "Reshape"
    domain: ""
  }
  node {
    input: "indices_1d"
    output: "indices_size"
    name: "n8"
    op_type: "Shape"
    domain: ""
  }
  node {
    input: "offsets"
    output: "tmp_2"
    name: "n9"
    op_type: "Size"
    domain: ""
  }
  node {
    input: "tmp_2"
    input: "neg_1"
    output: "num_bag"
    name: "n10"
    op_type: "Reshape"
    domain: ""
  }
  node {
    output: "include_last_offset"
    name: "n11"
    op_type: "Constant"
    attribute {
      name: "value_int"
      type: INT
      ref_attr_name: "include_last_offset"
    }
    domain: ""
  }
  node {
    input: "include_last_offset"
    output: "include_last_offset_as_bool"
    name: "n12"
    op_type: "Cast"
    attribute {
      name: "to"
      i: 9
      type: INT
    }
    domain: ""
  }
  node {
    output: "int64_True"
    name: "n13"
    op_type: "Constant"
    attribute {
      name: "value"
      t {
        data_type: 9
        int32_data: 1
        name: "int64_True"
      }
      type: TENSOR
    }
    domain: ""
  }
  node {
    input: "include_last_offset_as_bool"
    input: "int64_True"
    output: "cond"
    name: "n14"
    op_type: "Equal"
    domain: ""
  }
  node {
    input: "cond"
    output: "offsets_8"
    output: "num_bag_9"
    name: "n15"
    op_type: "If"
    attribute {
      name: "then_branch"
      g {
        node {
          output: "int64_1_3"
          name: "n0"
          op_type: "Constant"
          attribute {
            name: "value"
            t {
              data_type: 7
              int64_data: 1
              name: "int64_1_3"
            }
            type: TENSOR
          }
          domain: ""
        }
        node {
          input: "int64_1_3"
          input: "num_bag"
          output: "int64_1_3_cast"
          name: "n1"
          op_type: "CastLike"
          domain: ""
        }
        node {
          input: "num_bag"
          input: "int64_1_3_cast"
          output: "num_bag_4"
          name: "n2"
          op_type: "Sub"
          domain: ""
        }
        node {
          input: "offsets"
          output: "offsets_5"
          name: "n3"
          op_type: "Identity"
          domain: ""
        }
        name: "thenGraph_23"
        output {
          name: "offsets_5"
          type {
          }
        }
        output {
          name: "num_bag_4"
          type {
          }
        }
      }
      type: GRAPH
    }
    attribute {
      name: "else_branch"
      g {
        node {
          input: "offsets"
          input: "indices_size"
          output: "offsets_6"
          name: "n0"
          op_type: "Concat"
          attribute {
            name: "axis"
            i: 0
            type: INT
          }
          domain: ""
        }
        node {
          input: "num_bag"
          output: "num_bag_7"
          name: "n1"
          op_type: "Identity"
          domain: ""
        }
        name: "elseGraph_23"
        output {
          name: "offsets_6"
          type {
          }
        }
        output {
          name: "num_bag_7"
          type {
          }
        }
      }
      type: GRAPH
    }
    domain: ""
  }
  node {
    output: "result"
    name: "n16"
    op_type: "SequenceEmpty"
    domain: ""
  }
  node {
    output: "tmp_10"
    name: "n17"
    op_type: "Constant"
    attribute {
      name: "value_int"
      i: 0
      type: INT
    }
    domain: ""
  }
  node {
    input: "tmp_10"
    input: "neg_1"
    output: "index_tensor"
    name: "n18"
    op_type: "Reshape"
    domain: ""
  }
  node {
    input: "index_tensor"
    input: "num_bag_9"
    output: "cond_11"
    name: "n19"
    op_type: "Less"
    domain: ""
  }
  node {
    output: "true"
    name: "n20"
    op_type: "Constant"
    attribute {
      name: "value"
      t {
        data_type: 9
        int32_data: 1
        name: "true"
      }
      type: TENSOR
    }
    domain: ""
  }
  node {
    input: ""
    input: "true"
    input: "result"
    input: "index_tensor"
    output: "result_59"
    output: "index_tensor_60"
    name: "n21"
    op_type: "Loop"
    attribute {
      name: "body"
      g {
        node {
          output: "int64_1_14"
          name: "n0"
          op_type: "Constant"
          attribute {
            name: "value"
            t {
              data_type: 7
              int64_data: 1
              name: "int64_1_14"
            }
            type: TENSOR
          }
          domain: ""
        }
        node {
          input: "int64_1_14"
          input: "index_tensor_13"
          output: "int64_1_14_cast"
          name: "n1"
          op_type: "CastLike"
          domain: ""
        }
        node {
          input: "index_tensor_13"
          input: "int64_1_14_cast"
          output: "tmp_15"
          name: "n2"
          op_type: "Add"
          domain: ""
        }
        node {
          input: "offsets_8"
          input: "index_tensor_13"
          input: "tmp_15"
          output: "start"
          name: "n3"
          op_type: "Slice"
          domain: ""
        }
        node {
          output: "int64_1_16"
          name: "n4"
          op_type: "Constant"
          attribute {
            name: "value"
            t {
              data_type: 7
              int64_data: 1
              name: "int64_1_16"
            }
            type: TENSOR
          }
          domain: ""
        }
        node {
          input: "int64_1_16"
          input: "index_tensor_13"
          output: "int64_1_16_cast"
          name: "n5"
          op_type: "CastLike"
          domain: ""
        }
        node {
          input: "index_tensor_13"
          input: "int64_1_16_cast"
          output: "tmp_17"
          name: "n6"
          op_type: "Add"
          domain: ""
        }
        node {
          output: "int64_2"
          name: "n7"
          op_type: "Constant"
          attribute {
            name: "value"
            t {
              data_type: 7
              int64_data: 2
              name: "int64_2"
            }
            type: TENSOR
          }
          domain: ""
        }
        node {
          input: "int64_2"
          input: "index_tensor_13"
          output: "int64_2_cast"
          name: "n8"
          op_type: "CastLike"
          domain: ""
        }
        node {
          input: "index_tensor_13"
          input: "int64_2_cast"
          output: "tmp_18"
          name: "n9"
          op_type: "Add"
          domain: ""
        }
        node {
          input: "offsets_8"
          input: "tmp_17"
          input: "tmp_18"
          output: "end"
          name: "n10"
          op_type: "Slice"
          domain: ""
        }
        node {
          input: "start"
          input: "end"
          output: "cond_19"
          name: "n11"
          op_type: "Equal"
          domain: ""
        }
        node {
          input: "cond_19"
          output: "row_result_54"
          name: "n12"
          op_type: "If"
          attribute {
            name: "then_branch"
            g {
              node {
                output: "tmp_20"
                name: "n0"
                op_type: "Constant"
                attribute {
                  name: "value_floats"
                  floats: 0.0
                  type: FLOATS
                }
                domain: ""
              }
              node {
                output: "tmp_21"
                name: "n1"
                op_type: "Constant"
                attribute {
                  name: "value_ints"
                  ints: 1
                  type: INTS
                }
                domain: ""
              }
              node {
                input: "tmp_21"
                input: "weight_dim_1"
                output: "tmp_22"
                name: "n2"
                op_type: "Concat"
                attribute {
                  name: "axis"
                  i: 0
                  type: INT
                }
                domain: ""
              }
              node {
                input: "tmp_20"
                input: "tmp_22"
                output: "row_result"
                name: "n3"
                op_type: "Expand"
                domain: ""
              }
              name: "thenGraph_40"
              output {
                name: "row_result"
                type {
                }
              }
            }
            type: GRAPH
          }
          attribute {
            name: "else_branch"
            g {
              node {
                output: "mode"
                name: "n0"
                op_type: "Constant"
                attribute {
                  name: "value_int"
                  type: INT
                  ref_attr_name: "mode"
                }
                domain: ""
              }
              node {
                output: "int64_0"
                name: "n1"
                op_type: "Constant"
                attribute {
                  name: "value"
                  t {
                    data_type: 7
                    int64_data: 0
                    name: "int64_0"
                  }
                  type: TENSOR
                }
                domain: ""
              }
              node {
                input: "mode"
                input: "int64_0"
                output: "cond_23"
                name: "n2"
                op_type: "Equal"
                domain: ""
              }
              node {
                input: "cond_23"
                output: "row_result_53"
                name: "n3"
                op_type: "If"
                attribute {
                  name: "then_branch"
                  g {
                    node {
                      input: "new_weight_0"
                      input: "start"
                      input: "end"
                      output: "weight_rows"
                      name: "n0"
                      op_type: "Slice"
                      domain: ""
                    }
                    node {
                      output: "int64_0_1d"
                      name: "n1"
                      op_type: "Constant"
                      attribute {
                        name: "value"
                        t {
                          dims: 1
                          data_type: 7
                          int64_data: 0
                          name: "int64_0_1d"
                        }
                        type: TENSOR
                      }
                      domain: ""
                    }
                    node {
                      input: "weight_rows"
                      input: "int64_0_1d"
                      output: "row_result_24"
                      name: "n2"
                      op_type: "ReduceSum"
                      domain: ""
                    }
                    name: "thenGraph_46"
                    output {
                      name: "row_result_24"
                      type {
                      }
                    }
                  }
                  type: GRAPH
                }
                attribute {
                  name: "else_branch"
                  g {
                    node {
                      output: "mode_25"
                      name: "n0"
                      op_type: "Constant"
                      attribute {
                        name: "value_int"
                        type: INT
                        ref_attr_name: "mode"
                      }
                      domain: ""
                    }
                    node {
                      output: "int64_1_26"
                      name: "n1"
                      op_type: "Constant"
                      attribute {
                        name: "value"
                        t {
                          data_type: 7
                          int64_data: 1
                          name: "int64_1_26"
                        }
                        type: TENSOR
                      }
                      domain: ""
                    }
                    node {
                      input: "mode_25"
                      input: "int64_1_26"
                      output: "cond_27"
                      name: "n2"
                      op_type: "Equal"
                      domain: ""
                    }
                    node {
                      input: "cond_27"
                      output: "row_result_52"
                      name: "n3"
                      op_type: "If"
                      attribute {
                        name: "then_branch"
                        g {
                          node {
                            input: "new_weight_0"
                            input: "start"
                            input: "end"
                            output: "weight_rows_28"
                            name: "n0"
                            op_type: "Slice"
                            domain: ""
                          }
                          node {
                            output: "int64_1_29"
                            name: "n1"
                            op_type: "Constant"
                            attribute {
                              name: "value"
                              t {
                                data_type: 7
                                int64_data: 1
                                name: "int64_1_29"
                              }
                              type: TENSOR
                            }
                            domain: ""
                          }
                          node {
                            input: "int64_1_29"
                            input: "num_bag_9"
                            output: "int64_1_29_cast"
                            name: "n2"
                            op_type: "CastLike"
                            domain: ""
                          }
                          node {
                            input: "num_bag_9"
                            input: "int64_1_29_cast"
                            output: "tmp_30"
                            name: "n3"
                            op_type: "Sub"
                            domain: ""
                          }
                          node {
                            input: "index_tensor_13"
                            input: "tmp_30"
                            output: "cond_31"
                            name: "n4"
                            op_type: "Equal"
                            domain: ""
                          }
                          node {
                            input: "cond_31"
                            output: "row_result_43"
                            name: "n5"
                            op_type: "If"
                            attribute {
                              name: "then_branch"
                              g {
                                node {
                                  output: "int64_0_1d_32"
                                  name: "n0"
                                  op_type: "Constant"
                                  attribute {
                                    name: "value"
                                    t {
                                      dims: 1
                                      data_type: 7
                                      int64_data: 0
                                      name: "int64_0_1d_32"
                                    }
                                    type: TENSOR
                                  }
                                  domain: ""
                                }
                                node {
                                  input: "weight_rows_28"
                                  input: "int64_0_1d_32"
                                  output: "row_result_33"
                                  name: "n1"
                                  op_type: "ReduceSum"
                                  domain: ""
                                }
                                node {
                                  input: "indices"
                                  output: "tmp_34"
                                  name: "n2"
                                  op_type: "Shape"
                                  attribute {
                                    name: "end"
                                    i: 1
                                    type: INT
                                  }
                                  attribute {
                                    name: "start"
                                    i: 0
                                    type: INT
                                  }
                                  domain: ""
                                }
                                node {
                                  input: "tmp_34"
                                  input: "start"
                                  output: "denominator"
                                  name: "n3"
                                  op_type: "Sub"
                                  domain: ""
                                }
                                node {
                                  output: "int64_0_35"
                                  name: "n4"
                                  op_type: "Constant"
                                  attribute {
                                    name: "value"
                                    t {
                                      data_type: 7
                                      int64_data: 0
                                      name: "int64_0_35"
                                    }
                                    type: TENSOR
                                  }
                                  domain: ""
                                }
                                node {
                                  input: "int64_0_35"
                                  input: "denominator"
                                  output: "int64_0_35_cast"
                                  name: "n5"
                                  op_type: "CastLike"
                                  domain: ""
                                }
                                node {
                                  input: "denominator"
                                  input: "int64_0_35_cast"
                                  output: "cond_36"
                                  name: "n6"
                                  op_type: "Greater"
                                  domain: ""
                                }
                                node {
                                  input: "cond_36"
                                  output: "row_result_40"
                                  name: "n7"
                                  op_type: "If"
                                  attribute {
                                    name: "then_branch"
                                    g {
                                      node {
                                        input: "denominator"
                                        input: "new_weight_0"
                                        output: "tmp_37"
                                        name: "n0"
                                        op_type: "CastLike"
                                        domain: ""
                                      }
                                      node {
                                        input: "row_result_33"
                                        input: "tmp_37"
                                        output: "row_result_38"
                                        name: "n1"
                                        op_type: "Div"
                                        domain: ""
                                      }
                                      name: "thenGraph_56"
                                      output {
                                        name: "row_result_38"
                                        type {
                                        }
                                      }
                                    }
                                    type: GRAPH
                                  }
                                  attribute {
                                    name: "else_branch"
                                    g {
                                      node {
                                        input: "row_result_33"
                                        output: "row_result_39"
                                        name: "n0"
                                        op_type: "Identity"
                                        domain: ""
                                      }
                                      name: "elseGraph_56"
                                      output {
                                        name: "row_result_39"
                                        type {
                                        }
                                      }
                                    }
                                    type: GRAPH
                                  }
                                  domain: ""
                                }
                                name: "thenGraph_51"
                                output {
                                  name: "row_result_40"
                                  type {
                                  }
                                }
                              }
                              type: GRAPH
                            }
                            attribute {
                              name: "else_branch"
                              g {
                                node {
                                  output: "int64_0_1d_41"
                                  name: "n0"
                                  op_type: "Constant"
                                  attribute {
                                    name: "value"
                                    t {
                                      dims: 1
                                      data_type: 7
                                      int64_data: 0
                                      name: "int64_0_1d_41"
                                    }
                                    type: TENSOR
                                  }
                                  domain: ""
                                }
                                node {
                                  input: "weight_rows_28"
                                  input: "int64_0_1d_41"
                                  output: "row_result_42"
                                  name: "n1"
                                  op_type: "ReduceMean"
                                  domain: ""
                                }
                                name: "elseGraph_51"
                                output {
                                  name: "row_result_42"
                                  type {
                                  }
                                }
                              }
                              type: GRAPH
                            }
                            domain: ""
                          }
                          name: "thenGraph_49"
                          output {
                            name: "row_result_43"
                            type {
                            }
                          }
                        }
                        type: GRAPH
                      }
                      attribute {
                        name: "else_branch"
                        g {
                          node {
                            output: "int64_1_44"
                            name: "n0"
                            op_type: "Constant"
                            attribute {
                              name: "value"
                              t {
                                data_type: 7
                                int64_data: 1
                                name: "int64_1_44"
                              }
                              type: TENSOR
                            }
                            domain: ""
                          }
                          node {
                            input: "int64_1_44"
                            input: "num_bag_9"
                            output: "int64_1_44_cast"
                            name: "n1"
                            op_type: "CastLike"
                            domain: ""
                          }
                          node {
                            input: "num_bag_9"
                            input: "int64_1_44_cast"
                            output: "tmp_45"
                            name: "n2"
                            op_type: "Sub"
                            domain: ""
                          }
                          node {
                            input: "index_tensor_13"
                            input: "tmp_45"
                            output: "cond_46"
                            name: "n3"
                            op_type: "Equal"
                            domain: ""
                          }
                          node {
                            input: "cond_46"
                            output: "weight_rows_49"
                            name: "n4"
                            op_type: "If"
                            attribute {
                              name: "then_branch"
                              g {
                                node {
                                  input: "new_weight_0"
                                  input: "start"
                                  input: "indices_size"
                                  output: "weight_rows_47"
                                  name: "n0"
                                  op_type: "Slice"
                                  domain: ""
                                }
                                name: "thenGraph_61"
                                output {
                                  name: "weight_rows_47"
                                  type {
                                  }
                                }
                              }
                              type: GRAPH
                            }
                            attribute {
                              name: "else_branch"
                              g {
                                node {
                                  input: "new_weight_0"
                                  input: "start"
                                  input: "end"
                                  output: "weight_rows_48"
                                  name: "n0"
                                  op_type: "Slice"
                                  domain: ""
                                }
                                name: "elseGraph_61"
                                output {
                                  name: "weight_rows_48"
                                  type {
                                  }
                                }
                              }
                              type: GRAPH
                            }
                            domain: ""
                          }
                          node {
                            output: "int64_0_1d_50"
                            name: "n5"
                            op_type: "Constant"
                            attribute {
                              name: "value"
                              t {
                                dims: 1
                                data_type: 7
                                int64_data: 0
                                name: "int64_0_1d_50"
                              }
                              type: TENSOR
                            }
                            domain: ""
                          }
                          node {
                            input: "weight_rows_49"
                            input: "int64_0_1d_50"
                            output: "row_result_51"
                            name: "n6"
                            op_type: "ReduceMax"
                            domain: ""
                          }
                          name: "elseGraph_49"
                          output {
                            name: "row_result_51"
                            type {
                            }
                          }
                        }
                        type: GRAPH
                      }
                      domain: ""
                    }
                    name: "elseGraph_46"
                    output {
                      name: "row_result_52"
                      type {
                      }
                    }
                  }
                  type: GRAPH
                }
                domain: ""
              }
              name: "elseGraph_40"
              output {
                name: "row_result_53"
                type {
                }
              }
            }
            type: GRAPH
          }
          domain: ""
        }
        node {
          input: "result_12"
          input: "row_result_54"
          output: "result_55"
          name: "n13"
          op_type: "SequenceInsert"
          domain: ""
        }
        node {
          output: "int64_1_56"
          name: "n14"
          op_type: "Constant"
          attribute {
            name: "value"
            t {
              data_type: 7
              int64_data: 1
              name: "int64_1_56"
            }
            type: TENSOR
          }
          domain: ""
        }
        node {
          input: "int64_1_56"
          input: "index_tensor_13"
          output: "int64_1_56_cast"
          name: "n15"
          op_type: "CastLike"
          domain: ""
        }
        node {
          input: "index_tensor_13"
          input: "int64_1_56_cast"
          output: "index_tensor_57"
          name: "n16"
          op_type: "Add"
          domain: ""
        }
        node {
          input: "index_tensor_57"
          input: "num_bag_9"
          output: "cond_58"
          name: "n17"
          op_type: "Less"
          domain: ""
        }
        node {
          input: "cond_58"
          output: "cond_out"
          name: "n18"
          op_type: "Identity"
          domain: ""
        }
        name: "loop_body"
        input {
          name: "infinite_loop"
          type {
            tensor_type {
              elem_type: 7
              shape {
              }
            }
          }
        }
        input {
          name: "cond"
          type {
            tensor_type {
              elem_type: 9
              shape {
              }
            }
          }
        }
        input {
          name: "result_12"
          type {
          }
        }
        input {
          name: "index_tensor_13"
          type {
          }
        }
        output {
          name: "cond_out"
          type {
            tensor_type {
              elem_type: 9
              shape {
              }
            }
          }
        }
        output {
          name: "result_55"
          type {
          }
        }
        output {
          name: "index_tensor_57"
          type {
          }
        }
      }
      type: GRAPH
    }
    domain: ""
  }
  node {
    input: "result_59"
    output: "result_61"
    name: "n22"
    op_type: "ConcatFromSequence"
    attribute {
      name: "axis"
      i: 0
      type: INT
    }
    domain: ""
  }
  node {
    input: "result_61"
    input: "weight"
    output: "return_val"
    name: "n23"
    op_type: "CastLike"
    domain: ""
  }
  opset_import {
    domain: ""
    version: 18
  }
  domain: "pkg.onnxscript.torch_lib"
}

"""

ort_inputs = {'input_0': array([[ 2.664 , -6.777 ,  5.723 , -5.027 ,  0.3164],
       [ 1.397 ,  8.945 , -0.0879, -8.93  ,  2.04  ],
       [ 0.8613,  5.78  , -1.837 , -7.18  ,  1.143 ],
       [-5.457 , -7.67  , -8.53  ,  1.046 ,  5.906 ],
       [ 8.08  , -3.488 , -3.85  ,  5.484 ,  8.5   ],
       [ 3.684 ,  3.7   , -8.09  ,  7.48  , -6.777 ],
       [ 1.986 ,  0.9756, -4.527 ,  4.15  , -8.14  ],
       [-7.664 ,  7.56  , -0.9756,  7.867 , -1.292 ],
       [-6.04  ,  6.195 , -1.784 ,  1.424 ,  6.32  ],
       [-0.5977,  5.977 ,  8.74  , -7.594 ,  5.09  ]], dtype=float16), 'input_1': array([9, 1, 4, 9, 6]), 'offsets': array([0, 2, 3])}

# Set up the inference session
session_options = ort.SessionOptions()
session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
onnx_model = onnx.ModelProto()
google.protobuf.text_format.Parse(onnx_model_text, onnx_model)

# Uncomment this line to save the model to a file for examination
# onnx.save_model(onnx_model, "test_output_match_opinfo__ops_aten_embedding_bag_cpu_float16.onnx")

onnx.checker.check_model(onnx_model)
session = ort.InferenceSession(onnx_model.SerializeToString(), session_options, providers=("CPUExecutionProvider",))

# Run the model
for _ in range(N):
    ort_outputs = session.run(None, ort_inputs)
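To make the failing type constraint concrete, here is a minimal NumPy emulation of the loop pattern in the model above. It is not ORT or onnxscript code.

In `loop_body`, a sequence (`result_12`) and an index (`index_tensor_13`) are threaded through each iteration. Each iteration does three things:

- inserts one row per bag with `SequenceInsert`
- increments the index
- continues while `Less(index, num_bag_9)` holds

`ConcatFromSequence(axis=0)` then stacks the rows. Per the ONNX spec, `SequenceEmpty` produces a float sequence when `dtype` is not set, and `SequenceInsert` requires the tensor to have the sequence's element type. The names `TypedSequence` and `bag_loop` are hypothetical. The per-bag "row" is a plain slice standing in for the real embedding_bag reduction.

```python
import numpy as np


class TypedSequence:
    """Hypothetical stand-in for an ONNX tensor sequence: its element type is
    fixed at creation, like SequenceEmpty's dtype attribute."""

    def __init__(self, dtype=np.float32):
        # SequenceEmpty defaults to float (float32) when dtype is not given.
        self.dtype = np.dtype(dtype)
        self.items = []

    def insert(self, tensor):
        # Mirrors SequenceInsert: the tensor must match the sequence elem type.
        if tensor.dtype != self.dtype:
            raise TypeError(
                "Input Sequence and Tensor are expected to have the same elem type: "
                f"sequence={self.dtype}, tensor={tensor.dtype}"
            )
        self.items.append(tensor)

    def concat(self, axis=0):
        # Mirrors ConcatFromSequence(axis=0) without new_axis.
        return np.concatenate(self.items, axis=axis)


def bag_loop(weight, num_bag, seq):
    """Emulates the Loop: carry (seq, index), stop once index < num_bag is false."""
    index = 0
    cond = index < num_bag  # the Loop checks cond before the first iteration
    while cond:
        row = weight[index : index + 1]  # hypothetical per-bag row (keeps 2-D shape)
        seq.insert(row)
        index += 1
        cond = index < num_bag
    return seq.concat(axis=0)


if __name__ == "__main__":
    weight = np.zeros((5, 3), dtype=np.float16)
    print(bag_loop(weight, 3, TypedSequence(np.float16)).shape)  # (3, 3)
    try:
        bag_loop(weight, 3, TypedSequence())  # sequence left at the float default
    except TypeError as err:
        print(err)
```

The float16 rows only fit when the sequence itself is declared float16. This is why a `dtype` that is silently dropped on `SequenceEmpty` surfaces later as an element-type mismatch.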

Full error stack

[ONNXRuntimeError] : 1 : FAIL : Node (_0x7b6efe0_n21) Op (Loop) [TypeInferenceError] Graph attribute inferencing failed: Node (_0x7b6efe0_n12) Op (If) [TypeInferenceError] Mismatched tensor element type: source=10 target=1
  File "/home/justinchu/dev/onnx-script/onnxscript/tests/function_libs/torch_lib/ops_test_common.py", line 533, in _capture_graph_and_evaluate_torch_script_evaluator
    return _safe_ort_session_run(onnx_model.SerializeToString(), ort_inputs)
  File "/home/justinchu/dev/onnx-script/onnxscript/tests/function_libs/torch_lib/ops_test_common.py", line 349, in _safe_ort_session_run
    raise return_dict["error"]
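The numeric codes in these messages (`source=10 target=1` here, `Sequence=1 Tensor=10` in the original report) are `onnx.TensorProto.DataType` enum values. With `onnx` installed, `onnx.TensorProto.DataType.Name(10)` returns the name directly. The dependency-free sketch below hard-codes a subset of the enum from `onnx.proto` and rewrites the codes in an error string. The helper name `decode_elem_types` is hypothetical.

```python
import re

# Subset of the onnx.TensorProto.DataType enum (values as defined in onnx.proto).
ELEM_TYPES = {
    1: "FLOAT", 2: "UINT8", 3: "INT8", 4: "UINT16", 5: "INT16", 6: "INT32",
    7: "INT64", 8: "STRING", 9: "BOOL", 10: "FLOAT16", 11: "DOUBLE",
    12: "UINT32", 13: "UINT64", 14: "COMPLEX64", 15: "COMPLEX128", 16: "BFLOAT16",
}


def decode_elem_types(message):
    """Replace key=<code> pairs in an ORT type-inference error with enum names."""

    def repl(match):
        code = int(match.group(2))
        # Unknown codes are left as numbers rather than guessed.
        return f"{match.group(1)}={ELEM_TYPES.get(code, code)}"

    return re.sub(r"\b(source|target|Sequence|Tensor)=(\d+)", repl, message)


print(decode_elem_types("Mismatched tensor element type: source=10 target=1"))
# Mismatched tensor element type: source=FLOAT16 target=FLOAT
```

Read this way, both errors say the same thing: a float16 value meets a slot typed as float. The same table decodes the `data_type: 7` (INT64) and `elem_type: 9` (BOOL) fields in the model text above.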

Environment

OS: Linux-5.15.0-1042-azure-x86_64-with-glibc2.35
Python version: 3.10.9 (main, Jan 11 2023, 15:21:40) [GCC 11.2.0]
onnx==1.15.0.dev20230731
onnxruntime==1.15.1
numpy==1.25.1
torch==2.1.0.dev20230622+cpu