onnx / onnx-mlir

Representation and Reference Lowering of ONNX Models in MLIR Compiler Infrastructure
Apache License 2.0

Adding support for ONNX ops from custom domains? #1638

Open kernhanda opened 1 year ago

kernhanda commented 1 year ago

Hi folks,

Would it be acceptable to add support for ONNX operators from custom domains (such as com.microsoft)? Is there such an extensibility point in the code already, or would infrastructure need to be added? If there's any precedent for this, please let me know and I can draw inspiration from it.

Thanks!

caoimhinuibrian commented 1 year ago

There is some code dealing with custom ops (maybe not complete) in PR #430. Is that the kind of thing you are thinking of?

kernhanda commented 1 year ago

If I understand correctly, #430 creates function calls for unknown operations. I was asking more about adding support for custom-domain operators in the ONNX dialect (or another dialect, even) such that they get lowered like anything else in the compilation pipeline.

caoimhinuibrian commented 1 year ago

@kernhanda Yes, that's correct. However, one could imagine lowering some set of unknown operators to regular code if desired. The way the ONNX dialect is generated is (almost entirely) by parsing the ONNX spec. Where would your domain operators come from? Would you have a special version of, say, TensorFlow that generates them?

kernhanda commented 1 year ago

The domain operators I was thinking of are the ones defined/supported by ONNXRuntime.
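
For concreteness, a node from such a domain just carries a non-standard domain string, and the model lists that domain in its opset imports. A minimal sketch with the onnx Python helpers (the FusedGemm attribute shown here is purely illustrative):

```python
import onnx
from onnx import TensorProto, helper

# A node from the com.microsoft domain; the attribute is illustrative only.
node = helper.make_node(
    "FusedGemm",
    inputs=["A", "B", "C"],
    outputs=["Y"],
    domain="com.microsoft",
    activation="Relu",
)

graph = helper.make_graph(
    [node],
    "custom_domain_example",
    inputs=[
        helper.make_tensor_value_info("A", TensorProto.FLOAT, [4, 8]),
        helper.make_tensor_value_info("B", TensorProto.FLOAT, [8, 16]),
        helper.make_tensor_value_info("C", TensorProto.FLOAT, [4, 16]),
    ],
    outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, [4, 16])],
)

# The custom domain must be imported alongside the standard opset.
model = helper.make_model(
    graph,
    opset_imports=[
        helper.make_opsetid("", 17),
        helper.make_opsetid("com.microsoft", 1),
    ],
)
onnx.save(model, "fused_gemm_custom.onnx")
```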

caoimhinuibrian commented 1 year ago

This appears to be a fairly large feature and I'm not sure how it fits with a compilation (as opposed to a runtime) framework. Perhaps you can take some time to flesh out your thoughts on it? @AlexandreEichenberger @chentong319 any comments from either of you?

AlexandreEichenberger commented 1 year ago

I would say that our overall goal is to implement the ONNX standard, and that when there are missing operations, one should strive to add them to the standard. I know that MS and its ORT have a big impact on the standard, but the goal of a standard is to add to it what is important to the community. So if some ops are important to your project, I would suggest approaching the Operator SIG and pushing for these ops to be added.

That said, we should definitely look into your suggestion.

There are other operators extending the MS runtime: https://github.com/microsoft/onnxruntime-extensions/blob/main/docs/custom_ops.md

Looking briefly at the extensions you mentioned:

> The domain operators I was thinking of are the ones defined/supported by ONNXRuntime.

I wonder if some of these custom extensions are not early implementations of operations that are later added to the spec. If that is the case, what is your overall goal: to support the MS extensions and provide compatibility with ORT, including the MS and contributors' operations, or to deliver to users additional functionality that is missing from the ONNX standard?

Second, do you intend to provide implementations for them? If so, is your goal functional compatibility (e.g., a FusedGemm operation can easily be decomposed into a Gemm and an Activation op) or performance (in which case we really want to see the new op and its lowering)? Note that the two are not incompatible: start with functional compatibility and then look into optimizing the important ones.
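
To make the functional side concrete, here is a rough numpy sketch of the assumed FusedGemm semantics (per ORT's contrib-op documentation, a Gemm followed by an activation, Relu in this case):

```python
import numpy as np

# Assumed FusedGemm semantics (per ORT contrib-op docs): a standard ONNX
# Gemm followed by an activation; only Relu is shown here.
def fused_gemm_reference(A, B, C, alpha=1.0, beta=1.0):
    gemm = alpha * (A @ B) + beta * C   # ONNX Gemm
    return np.maximum(gemm, 0.0)        # ONNX Relu

A = np.random.randn(4, 8).astype(np.float32)
B = np.random.randn(8, 16).astype(np.float32)
C = np.random.randn(4, 16).astype(np.float32)
print(fused_gemm_reference(A, B, C).shape)  # (4, 16)
```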

Thanks for starting this interesting discussion.

chentong319 commented 1 year ago

The definitions of ONNX ops in onnx-mlir are imported from the onnx package through the script onnx-mlir/utils/gen_onnx_mlir.py. Are the custom ops you want to add defined with the same data structure but in a different domain (other than onnx and onnx_ml), or in a different package (e.g. the onnxruntime package)? If so, you can start from that script to generate the op definitions.

Otherwise, you may start from src/Dialect/ONNX/AdditionalONNXOps.td and add the definitions of the new ops manually. There is no class structure for ops in the ONNX dialect; we do not define subclasses such as unary, binary, etc. There are special interfaces for shape inference and type inference.

After the op definitions, support for shape inference, lowering, and rewriting (if any) needs to be added.
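
As a rough illustration of the first route: gen_onnx_mlir.py works off the schemas registered with the onnx package, so custom-domain ops are only visible to it if their schemas are registered with onnx as well (the com.microsoft schemas normally live in onnxruntime's C++ code, not in the onnx package). A hypothetical filter could look like:

```python
import onnx.defs

CUSTOM_DOMAIN = "com.microsoft"

# Keep only schemas from the custom domain. For com.microsoft this list is
# typically empty unless the schemas have been registered with the onnx
# package by some other means.
custom_schemas = [
    s for s in onnx.defs.get_all_schemas_with_history()
    if s.domain == CUSTOM_DOMAIN
]

for schema in custom_schemas:
    print(schema.domain, schema.name, schema.since_version)
```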

kernhanda commented 1 year ago

> Thanks for starting this interesting discussion.

Sure thing! Sorry it's taken me a bit to get back to this issue.

> I wonder if some of these custom extensions are not early implementations of operations that are later added to the spec. If that is the case, what is your overall goal: to support the MS extensions and provide compatibility with ORT, including the MS and contributors' operations, or to deliver to users additional functionality that is missing from the ONNX standard?

The motivation is to provide an alternative way to compile models that have MS/Contributor's operations. Whether those operators end up getting standardized is beyond my control, but taking a look at the operators, it does seem clear that some of them may not be intended to be standardized any time soon since they provide very specific functionality.

> Second, do you intend to provide implementations for them? If so, is your goal functional compatibility (e.g., a FusedGemm operation can easily be decomposed into a Gemm and an Activation op) or performance (in which case we really want to see the new op and its lowering)? Note that the two are not incompatible: start with functional compatibility and then look into optimizing the important ones.

I was thinking of taking a staged approach -- first decompose the non-standard operators into existing ONNX/krnl ops to leverage the existing pipeline, and then optimize the operations that are computationally heavy.

AlexandreEichenberger commented 1 year ago

> it does seem clear that some of them may not be intended to be standardized any time soon since they provide very specific functionality.

Another way to look at this: once an operation (be it an extension or external contribution) starts to appear in several independent packages, it starts to become a de facto standard, regardless of whether it is one officially. So I would still recommend approaching the operator SIG and perhaps deciding which ops should be adopted and which deprecated as not useful.

> I was thinking of taking a staged approach -- first decompose the non-standard operators into existing ONNX/krnl ops to leverage the existing pipeline, and then optimize the operations that are computationally heavy.

That sounds like the right approach. Note that ideally, especially for the fused ops, we would want the compiler to pick up these patterns itself. To me, fused ops seem more important for runtimes, which have less of an overview of the whole graph.

kernhanda commented 1 year ago

> Another way to look at this: once an operation (be it an extension or external contribution) starts to appear in several independent packages, it starts to become a de facto standard, regardless of whether it is one officially. So I would still recommend approaching the operator SIG and perhaps deciding which ops should be adopted and which deprecated as not useful.

Yeah, I agree. Unfortunately, I have no input on this at the moment 😄

> That sounds like the right approach. Note that ideally, especially for the fused ops, we would want the compiler to pick up these patterns itself. To me, fused ops seem more important for runtimes, which have less of an overview of the whole graph.

Yes, I think the correct approach would be to create a generic pattern for fused ops so that even decomposed ops would match against it and then have the fused op hook into that pattern.

As far as the implementation of such a change goes, I think the biggest piece would be adding the domain to operators and acting on it accordingly. IIRC, the current operator matcher only looks at the name of the operator. A quicker way would be to prefix the domain to the operator name for non-standard domains and use that for matching.
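
As a hypothetical sketch of that prefixing idea at the model level (the importer itself is in C++; this only shows the naming scheme):

```python
import onnx

STANDARD_DOMAINS = ("", "ai.onnx", "ai.onnx.ml")

def qualified_op_name(node: onnx.NodeProto) -> str:
    # Prefix the domain for non-standard ops so that, e.g., FusedGemm from
    # com.microsoft matches a different entry than any standard op would.
    if node.domain in STANDARD_DOMAINS:
        return node.op_type
    return f"{node.domain}.{node.op_type}"

model = onnx.load("model_with_contrib_ops.onnx")  # hypothetical input file
for node in model.graph.node:
    print(qualified_op_name(node))
```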

AlexandreEichenberger commented 1 year ago

Can you make a proposal for what work would be needed to support this? When I look at it, it feels like we would need to:

  1. modify the front end to generate these ops
  2. add a new dialect
  3. add a pass that converts them to generic ONNX ops, hopefully before shape inference so that we don't need to do shape inference for them as well.

Can you think of a better/lighter way to do this? Would a Python pre-processing step be too much of a hack?
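
For what it's worth, a Python pre-processing step along those lines could stay fairly small. A minimal, hedged sketch that rewrites com.microsoft FusedGemm nodes into standard Gemm + Relu before the model reaches onnx-mlir (attribute handling is simplified and only the Relu activation is covered):

```python
import onnx
from onnx import helper

def decompose_fused_gemm(model: onnx.ModelProto) -> onnx.ModelProto:
    # Rewrite com.microsoft FusedGemm into standard Gemm + Relu. Simplified:
    # only alpha/beta/transA/transB are forwarded and Relu is assumed.
    new_nodes = []
    for node in model.graph.node:
        if node.domain == "com.microsoft" and node.op_type == "FusedGemm":
            attrs = {a.name: helper.get_attribute_value(a) for a in node.attribute}
            gemm_out = node.output[0] + "_gemm"
            new_nodes.append(helper.make_node(
                "Gemm", list(node.input), [gemm_out],
                alpha=attrs.get("alpha", 1.0), beta=attrs.get("beta", 1.0),
                transA=attrs.get("transA", 0), transB=attrs.get("transB", 0)))
            new_nodes.append(helper.make_node("Relu", [gemm_out], list(node.output)))
        else:
            kept = onnx.NodeProto()
            kept.CopyFrom(node)  # copy so clearing the graph below is safe
            new_nodes.append(kept)
    del model.graph.node[:]
    model.graph.node.extend(new_nodes)
    return model

model = decompose_fused_gemm(onnx.load("model_with_contrib_ops.onnx"))
onnx.save(model, "decomposed.onnx")  # then compile with onnx-mlir as usual
```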