microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

The onnx.helper make_function command strips type information leading to inference errors #18264

Open corwinjoy opened 1 year ago

corwinjoy commented 1 year ago

Describe the issue

When creating custom operations, wrapping them inside a make_function call loses the type information. This causes a type error when attempting to load the model in ONNX Runtime. This is a problem for more complex nodes with a lot of data (such as TreeEnsembleRegressor), where we want to re-use an existing node with different parameters.

The place in the onnxruntime code that generates the error indicates that "this should not happen": https://github.com/microsoft/onnxruntime/blob/d8d79521ca2b266e631ac0ba7036a682ebb58b5b/onnxruntime/core/graph/graph.cc#L2358

To reproduce

import numpy as np
import onnx
import onnxruntime_extensions as ortx
from onnx import (
    ModelProto,
    TensorProto,
    checker,
    numpy_helper,
)
from onnx.helper import make_graph, make_model, make_node, make_opsetid, make_tensor_value_info, \
    make_function
import onnxruntime
from pathlib import Path
import pytest

@ortx.onnx_op(op_type="Adder",
              inputs=[ortx.PyCustomOpDef.dt_float, ortx.PyCustomOpDef.dt_float],
              outputs=[ortx.PyCustomOpDef.dt_float]
              )
def custom_adder(a, b):
    return a + b

def _create_adder_model(use_function: bool) -> ModelProto:
    custom_op_domain = 'ai.onnx.contrib'
    opset_imports = [make_opsetid("", onnx.defs.onnx_opset_version()),
                     make_opsetid(custom_op_domain, 1)]

    # constants
    A_value = np.array([1.0], dtype=np.float32)
    A = numpy_helper.from_array(A_value, name='A')
    B_value = np.array([2.0], dtype=np.float32)
    B = numpy_helper.from_array(B_value, name='B')

    adder_node = make_node(
        "Adder",
        domain="ai.onnx.contrib",
        inputs=['X', 'Y'],
        outputs=['Z']
    )

    # operator function wrapper
    X = make_tensor_value_info("X", TensorProto.FLOAT, ["N", 1])
    Y = make_tensor_value_info("Y", TensorProto.FLOAT, ["N", 1])
    Z = make_tensor_value_info("Z", TensorProto.FLOAT, ["N", 1])
    adder_fn = make_function(
        custom_op_domain,  # domain name
        'AdderFn',  # function name
        ['X', 'Y'],  # input names
        ['Z'],  # output names
        [adder_node],  # nodes
        opset_imports,  # opsets
        [])  # attribute names

    if use_function:
        XA_node = make_node('AdderFn', ['X', 'A'], ['XA'],
                            domain=custom_op_domain, name="XA_node")
        YB_node = make_node('AdderFn', ['Y', 'B'], ['YB'],
                            domain=custom_op_domain, name="YB_node")
    else:
        XA_node = make_node('Adder', ['X', 'A'], ['XA'],
                            domain=custom_op_domain, name="XA_node")
        YB_node = make_node('Adder', ['Y', 'B'], ['YB'],
                            domain=custom_op_domain, name="YB_node")

    sum_node = make_node('Adder', ['XA', 'YB'], ['Z'],
                         domain=custom_op_domain, name="sum_node")

    graph_def = make_graph(
        nodes=[XA_node, YB_node, sum_node],
        name="mean-leaf-model",
        inputs=[X, Y],
        outputs=[Z],
        initializer=[A, B])

    model_def = make_model(
        graph_def,
        opset_imports=opset_imports,
        functions=[adder_fn])

    checker_context = checker.C.CheckerContext()
    checker_context.ir_version = onnx.IR_VERSION
    checker_context.opset_imports = {v.domain: v.version for v in opset_imports}

    checker.check_model(model_def)
    # print(model_def)

    return model_def

def register_custom_ops_python(session_options: onnxruntime.SessionOptions):
    """
    Register the python custom operator library with an ONNX session options instance
    """
    session_options.register_custom_ops_library(ortx.get_library_path())

@pytest.mark.parametrize("use_function", [False, True])
def test_adder(tmp_path: Path, use_function: bool):
    onnx_model = _create_adder_model(use_function=use_function)

    model_path = tmp_path / "test.onnx"
    onnx.save_model(onnx_model, str(model_path), save_as_external_data=False)

    session_opts = onnxruntime.SessionOptions()
    register_custom_ops_python(session_opts)
    session = onnxruntime.InferenceSession(model_path, session_opts, providers=["CPUExecutionProvider"])

    X = np.arange(0, 5, dtype=np.float32).reshape(-1, 1)
    Y = np.arange(1, 6, dtype=np.float32).reshape(-1, 1)
    Z = X + Y + 3

    output = session.run(['Z'], {"X": X, "Y": Y})

    np.testing.assert_allclose(output[0], Z)

Urgency

No response

Platform

Linux

OS Version

22.04

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.16.1

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Default CPU

Execution Provider Library Version

No response

corwinjoy commented 1 year ago

The failing test case generates an error message like:

 onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from /tmp//test.onnx failed:Node (XA_node) output arg (XA) type inference failed

github-actions[bot] commented 11 months ago

This issue has been automatically marked as stale due to inactivity and will be closed in 7 days if no further activity occurs. If further support is needed, please provide an update and/or more details.

thiagocrepaldi commented 11 months ago

@BowenBao @gramalingam @justinchuby correct me if I am wrong, but ONNX Function does not store type information by spec.

If this issue is the same as https://github.com/onnx/onnx/issues/5487, then this is an ONNX spec limitation and ORT can't do much about it.

justinchuby commented 11 months ago

I agree this is the case.

gramalingam commented 11 months ago

ONNX functions do not store type information by design, yes. However, the issue here may have other confounding factors. Custom ops (just like regular ops) typically have a shape-and-type-inference function that is the primary source of type information. I guess that is missing in this case? Why is that? Is it because of limitations in the type inference for Python custom ops?

corwinjoy commented 11 months ago

Custom ops have to declare types (as far as I know). In this case, the example custom op is explicit in its types:

@ortx.onnx_op(op_type="Adder",
              inputs=[ortx.PyCustomOpDef.dt_float, ortx.PyCustomOpDef.dt_float],
              outputs=[ortx.PyCustomOpDef.dt_float]
              )

In the above example, if the custom op is used directly, everything is fine. However, with make_function it seems that type information is erased. I think this is because make_function is a kind of template operator, but somehow it does not preserve type information, and as a result the graph cannot be compiled. (In order to compile the graph, input and output types for the graph as a whole need to be defined.) I think that with normal operators, the graph compiler is able to deduce types, but here it fails because the types are lost. @gramalingam @justinchuby @thiagocrepaldi

gramalingam commented 11 months ago

Does the failure happen only when use_function is true, or in both cases?

There are two sources of type-information when ONNX type-inference happens. One is the type-information explicitly included in the model itself. But the second is via the type-inference methods of registered ops. I assume that ortx.onnx_op creates and registers an ONNX op-schema with the correct signature. This op-schema should be sufficient without needing to explicitly capture type-information in the function within the model.

So, the failure implies something else is going wrong as well. E.g., maybe the inference logic is not getting access to the op-schema for some reason. So, just trying to understand what could be causing this.

corwinjoy commented 11 months ago

@gramalingam The failure only happens when use_function is true. I set the example up this way to show that the problem is with make_function.