spcl / daceml

A Data-Centric Compiler for Machine Learning
https://daceml.readthedocs.io

Custom ONNX node implementation not registering #129

Open vselhakim1337 opened 1 year ago

vselhakim1337 commented 1 year ago

I'm trying to add and register a custom implementation for an ONNX op from outside the library. For context, the op in question is ONNXMul. Following the code snippet in the documentation, I came up with this:

@op_implementation(op="Mul", name="myimpl")
class FPGAMul(ONNXForward):
    @staticmethod
    def forward_can_be_applied(node: ONNXOp, state: SDFGState, sdfg: SDFG) -> bool:
        ...

    @staticmethod
    def forward(node: ONNXOp, state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]:
        ...

daceml.onnx.default_implementation = 'myimpl'

This code lives in the same Python script that loads an ONNX model and expands the library nodes. Later in that script I have:

    model = onnx.load(model_path)
    dace_model = daceml.onnx.ONNXModel(name, model)
    print('ONNX model loaded...')
    dace_model.sdfg.expand_library_nodes()

However, this results in an error:

Traceback (most recent call last):
  File "load.py", line 114, in predict_daceml
    dace_model.sdfg.expand_library_nodes()
  File ".../dace/sdfg/sdfg.py", line 2559, in expand_library_nodes
    impl_name = node.expand(self, state)
  File ".../dace/sdfg/nodes.py", line 1269, in expand
    raise KeyError("Unknown implementation for node {}: {}".format(type(self).__name__, implementation))
KeyError: 'Unknown implementation for node ONNXMul: myimpl'
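
My reading of the traceback, sketched as a toy model: the expansion appears to look up the selected implementation name in a table attached to the node and raises when the name is absent. Everything below is a hypothetical stand-in (`ToyLibraryNode`, `ToyMul`, the `implementations` table, `try_expand`), not dace or daceml code, and I am only assuming the real lookup behaves similarly.

```python
# Toy model of the lookup that fails in the traceback; not dace or daceml code.
class ToyLibraryNode:
    # Hypothetical per-class table mapping implementation names to expansions.
    implementations = {"pure": lambda node: "pure expansion"}

    def __init__(self, implementation):
        self.implementation = implementation

    def expand(self):
        name = self.implementation
        if name not in self.implementations:
            # Same message format as the KeyError in the traceback.
            raise KeyError("Unknown implementation for node {}: {}".format(
                type(self).__name__, name))
        return self.implementations[name](self)


class ToyMul(ToyLibraryNode):
    pass


def try_expand(node):
    """Return the expansion result, or the KeyError message if lookup fails."""
    try:
        return node.expand()
    except KeyError as err:
        return err.args[0]


# The name was selected but never entered into the node's table.
before = try_expand(ToyMul("myimpl"))

# Once the name is present in the table, the same lookup succeeds.
ToyMul.implementations = dict(ToyLibraryNode.implementations,
                              myimpl=lambda node: "my expansion")
after = try_expand(ToyMul("myimpl"))
```

If the real code works like this sketch, selecting 'myimpl' as the default would not be enough on its own; the name would also have to reach whatever table ONNXMul consults during expansion.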

Can you tell me why my implementation is not registering, and how to fix this?