microsoft / onnxruntime-extensions

onnxruntime-extensions: A specialized pre- and post-processing library for ONNX Runtime
MIT License

"Export" Custom Ops Library After Registering Ops? #733

Closed: JTunis closed this issue 1 month ago

JTunis commented 1 month ago

I have an ONNX model composed of custom ops registered via the @onnx_op decorator like the following:

@onnx_op(op_type="Double", inputs=[PyCustomOpDef.dt_float], outputs=[PyCustomOpDef.dt_float])
def double(x: np.ndarray) -> np.ndarray:
    return x * 2

@onnx_op(op_type="Triple", inputs=[PyCustomOpDef.dt_float], outputs=[PyCustomOpDef.dt_float])
def triple(x: np.ndarray) -> np.ndarray:
    return x * 3

def make_onnx_model():
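    # Both nodes live in ai.onnx.contrib, the custom-op domain used by onnxruntime-extensions.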
    nodes = [
        helper.make_node("Double", ["input"], ["doubled"], domain="ai.onnx.contrib"),
        helper.make_node("Triple", ["input"], ["tripled"], domain="ai.onnx.contrib"),
    ]

    input_info = helper.make_tensor_value_info(name="input", elem_type=onnx_pb.TensorProto.FLOAT, shape=[1, 8])
    doubled_info = helper.make_tensor_value_info(
        name="doubled", elem_type=onnx_pb.TensorProto.FLOAT, shape=[1, 8]
    )
    tripled_info = helper.make_tensor_value_info(
        name="tripled", elem_type=onnx_pb.TensorProto.FLOAT, shape=[1, 8]
    )

    graph = helper.make_graph(
        nodes=nodes,
        name="DoublerTripler",
        inputs=[input_info],
        outputs=[doubled_info, tripled_info],
    )

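    # Declare the ai.onnx.contrib domain in opset_imports so the custom nodes above resolve.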
    model = helper.make_model(
        graph=graph, opset_imports=[helper.make_operatorsetid("ai.onnx.contrib", 1)], ir_version=9
    )

    onnx.save(model, "./test-model.onnx")

If I try to load the model into an InferenceSession outside of the Python process in which the ops were registered (for example, in a service that doesn't contain the @onnx_op-decorated functions), I get the following error:

failed:Fatal error: ai.onnx.contrib:Double(-1) is not a registered function/op
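
For reference, the model loads and runs fine in the same process where the kernels above are defined. A minimal sketch of that working path (not my exact service code, just the standard SessionOptions registration with get_library_path from onnxruntime-extensions):

import numpy as np
import onnxruntime as ort
from onnxruntime_extensions import get_library_path

# Register the onnxruntime-extensions shared library so the
# ai.onnx.contrib domain is known to the session.
so = ort.SessionOptions()
so.register_custom_ops_library(get_library_path())

sess = ort.InferenceSession("./test-model.onnx", so)
doubled, tripled = sess.run(None, {"input": np.ones((1, 8), dtype=np.float32)})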

Is there a way to export the custom ops library after registering new ops via the @onnx_op decorator, so that those custom ops can be used without their implementations being present in the inference service?

wenbingl commented 1 month ago

onnx_op defines the Python kernel that runs the ONNX node at inference time. Without that onnx_op function body, the ORT inference session cannot find the Python function that implements the custom op node. Alternatively, you can write a C++ kernel for the node and register it in the ORT-extensions DLL/shared library; that also works, but you need to build onnxruntime-extensions yourself.
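
In other words, whatever process creates the InferenceSession has to execute the @onnx_op registrations first. A minimal sketch of what the inference service would need (the module name custom_ops is hypothetical, standing in for wherever the decorated functions live):

import onnxruntime as ort
from onnxruntime_extensions import get_library_path

import custom_ops  # hypothetical module containing the @onnx_op-decorated kernels;
                   # importing it runs the decorators and registers Double/Triple

so = ort.SessionOptions()
so.register_custom_ops_library(get_library_path())
sess = ort.InferenceSession("./test-model.onnx", so)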

JTunis commented 1 month ago

Makes sense to me, thanks