Closed prabhuiitdhn closed 3 years ago
It would be much appreciated if anyone can help me understand this issue. I am okay with modifying the code. Thank you.
onnx_tf only supports the operators listed in https://github.com/onnx/onnx-tensorflow/blob/master/doc/support_status.md. _DCNv2 is a custom operator and is not supported by onnx-tf. To support a new operator, please follow https://github.com/onnx/onnx-tensorflow/blob/master/doc/IMPLEMENTING_NEW_OP.md. Hope it helps.
Hi @chudegao: Thanks for replying. Yes, I followed the same instructions that you shared, but there is still no improvement. I also followed this to add the operator in onnx and successfully convert the torch model to an onnx model.
I added a handler named "_DCNv2.py" to /onnx_tf/handlers/backend/:
```python
import tensorflow as tf

from onnx_tf.handlers.backend_handler import BackendHandler
from onnx_tf.handlers.handler import onnx_op
from onnx_tf.handlers.handler import tf_func
# from .math_mixin import BasicMathMixin


@onnx_op("_DCNv2")
# @tf.func
class _DCNv2(BackendHandler):

  @classmethod
  def version_9(cls, node, **kwargs):
    return [cls.make_tensor_from_onnx_node(node, **kwargs)]

  @classmethod
  def version_10(cls, node, **kwargs):
    return [cls.make_tensor_from_onnx_node(node, **kwargs)]
```
Then I executed the commands:

```
gen_opset.py
gen_status.py -m
```
Please correct me if I am doing anything wrong. Thank you.
First make sure the code you updated is where onnx_tf is installed:

```
(chudg) [root@haswell01 onnx_tf]# pip list | grep onnx-tf
onnx-tf    1.8.0    /chudg/git/onnx-tensorflow
(chudg) [root@haswell01 onnx_tf]# pwd
/chudg/git/onnx-tensorflow/onnx_tf
```
After running `python gen_opset.py .`, you should get a patch like the one below:
```diff
diff --git a/onnx_tf/opset_version.py b/onnx_tf/opset_version.py
index 86e9f87..1670e51 100644
--- a/onnx_tf/opset_version.py
+++ b/onnx_tf/opset_version.py
@@ -15,7 +15,6 @@ backend_opset_version = {
   'Atanh': [9],
   'AveragePool': [1, 7, 10, 11],
   'BatchNormalization': [1, 6, 7, 9],
   'Binarizer': [],
   'BitShift': [11],
   'Cast': [1, 6, 9, 13],
@@ -187,7 +186,8 @@ backend_opset_version = {
   'Upsample': [7, 9],
   'Where': [9],
   'Xor': [1, 7],
```

And the handler can be looked up:

```
>>> from onnx_tf.common.handler_helper import get_all_backend_handlers
>>> a = get_all_backend_handlers({})
>>> a['']['_DCNv2']
<class 'onnx_tf.handlers.backend._DCNv2._DCNv2'>
```
If you can provide the onnx file, I can help debug.
I guess you set a domain for your custom op. The onnx-tf handler should also set the same domain, as the default domain is ''. Please try adding an attribute to your handler: DOMAIN=xxx. @chinhuang007 it seems there's no doc for this. Correct me if I'm wrong.
Correct. There is no doc for handling custom ops. The domain is the key that differentiates a custom op from an ONNX op, since all custom ops will have specific domain names.
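The shape of that lookup can be illustrated with a small self-contained sketch (the dict layout mirrors what `get_all_backend_handlers` returns; the handler classes here are stand-ins, not the real onnx-tf ones):

```python
# Handlers are registered in a dict keyed first by domain, then by op name.
# Standard ONNX ops live under the default domain '', while custom ops live
# under the domain string the exporter wrote into the model.
class AddHandler: pass
class DCNv2Handler: pass

handlers = {
    "": {"Add": AddHandler},  # default ONNX domain
    "org.pytorch.custom_domain": {"_DCNv2": DCNv2Handler},
}

# Looking up a custom op under the default domain fails...
assert "_DCNv2" not in handlers[""]
# ...but succeeds under the custom op's own domain.
assert handlers["org.pytorch.custom_domain"]["_DCNv2"] is DCNv2Handler
```

This is why the same handler class can only be found when both the handler's DOMAIN attribute and the lookup key use the custom domain.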
@chinhuang007: I have followed the same steps but am unable to fix it. I am sharing the onnx folder; please check it for debugging. I think we are very close to fixing it, and I am looking for your support. Thank you for replying.
@chudegao Thanks for the clarification. I am not sure whether I have added DOMAIN or not. This is how I created the custom operator and added it to ONNX, and it works for the model conversion. Please find the models here for more clarity.
```python
class _DCNv2(Function):

  @staticmethod
  def symbolic(g, input, offset, mask, weight, bias, stride, padding,
               dilation, deformable_groups):
    stride = _pair(stride)
    padding = _pair(padding)
    dilation = _pair(dilation)
    return g.op("custom_domain::_DCNv2", input, offset, mask, weight, bias,
                stride_i=stride, padding_i=padding, dilation_i=dilation,
                deformable_groups_i=deformable_groups)
```
Can you please give more insight into how to add an attribute to the handler? Thank you.
I add the domain as below:

```python
class _DCNv2(BackendHandler):
  DOMAIN = 'org.pytorch.custom_domain'
```
I think there are still two issues.
First, gen_opset.py only collects handlers from the default domain, so I patched it to iterate over all domains:

```diff
diff --git a/onnx_tf/gen_opset.py b/onnx_tf/gen_opset.py
index 8358b1c..6ebb93e 100755
--- a/onnx_tf/gen_opset.py
+++ b/onnx_tf/gen_opset.py
@@ -20,7 +20,8 @@ def main():
       backend_opset_dict[op_name] = []

   backend_onnx_coverage, backend_experimental_op = get_backend_coverage()
-  backend_opset_dict.update(backend_onnx_coverage.get(defs.ONNX_DOMAIN, {}))
+  for domain in backend_onnx_coverage.keys():
+    backend_opset_dict.update(backend_onnx_coverage.get(domain, {}))

   backend_ps_dict = get_backend_partial_support_detail()
```
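The effect of that patch can be sketched in isolation with toy data (this is not the real coverage dict, just an illustration of the behavior change):

```python
# Toy stand-in for the backend coverage mapping: domain -> {op: versions}.
backend_onnx_coverage = {
    "": {"Add": [1, 7], "Relu": [1, 6]},
    "org.pytorch.custom_domain": {"_DCNv2": [1]},
}

# Old behavior: only the default ONNX domain '' is merged,
# so custom-domain ops are silently dropped from opset_version.py.
old = {}
old.update(backend_onnx_coverage.get("", {}))
assert "_DCNv2" not in old

# Patched behavior: iterate over all domains, so custom ops are included.
new = {}
for domain in backend_onnx_coverage:
    new.update(backend_onnx_coverage[domain])
assert "_DCNv2" in new
```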
Second, the opset version for your custom domain in the model is 0, so the handler also needs a `version_1` method:

```
>>> model = onnx.load('/root/model_mot.onnx')
>>> model.opset_import
[version: 10
, domain: "org.pytorch.custom_domain"
version: 0
]
```

```python
  @classmethod
  def version_1(cls, node, **kwargs):
    return [cls.make_tensor_from_onnx_node(node, **kwargs)]
```
With that, the error will change to the one below, because no TensorFlow function is defined:

```
RuntimeError: No Tensorflow function is given.
```
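Roughly, the backend dispatches to the highest `version_N` method that does not exceed the opset version imported by the model. The following is a simplified sketch of that idea, not the library's actual code; it assumes a version-0 import is served by `version_1`, matching the advice above:

```python
def pick_version_method(handler_versions, opset_version):
    """Pick the highest available version_N not exceeding the model's
    opset version (simplified sketch; version 0 is treated as 1)."""
    opset_version = max(opset_version, 1)
    candidates = [v for v in handler_versions if v <= opset_version]
    if not candidates:
        raise RuntimeError(
            "no handler version for opset %d" % opset_version)
    return max(candidates)

# A handler defining only version_9 and version_10 cannot serve the
# version-0 custom-domain import; adding version_1 fixes the lookup.
assert pick_version_method([1, 9, 10], 0) == 1
assert pick_version_method([1, 9, 10], 10) == 10
```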
@chudegao: I added the attribute DOMAIN = 'org.pytorch.custom_domain' to the handler as you mentioned, and I followed https://github.com/onnx/onnx-tensorflow/blob/master/doc/IMPLEMENTING_NEW_OP.md to implement the new op, but I am still unable to fix it.
This is the modified code:
```python
import tensorflow as tf

from onnx_tf.handlers.backend_handler import BackendHandler
from onnx_tf.handlers.handler import onnx_op
from onnx_tf.handlers.handler import tf_func
# from .math_mixin import BasicMathMixin


@onnx_op("_DCNv2")
# @tf.func
class _DCNv2(BackendHandler):
  DOMAIN = 'org.pytorch.custom_domain'

  @classmethod
  def version_1(cls, node, **kwargs):
    return [cls.make_tensor_from_onnx_node(node, **kwargs)]

  @classmethod
  def version_9(cls, node, **kwargs):
    return [cls.make_tensor_from_onnx_node(node, **kwargs)]

  @classmethod
  def version_10(cls, node, **kwargs):
    return [cls.make_tensor_from_onnx_node(node, **kwargs)]
```
After implementing the operator I tried to check the status of the '_DCNv2' handler, and it seems _DCNv2 was not added:
```
>>> import onnx_tf
2021-07-09 19:49:43.199937: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda-10.0/lib64:/usr/local/cuda-10.0/lib
2021-07-09 19:49:43.199971: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.
>>> from onnx_tf.common.handler_helper import get_all_backend_handles
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ImportError: cannot import name 'get_all_backend_handles' from 'onnx_tf.common.handler_helper' (/home/uib43225/test_onnx/lib/python3.8/site-packages/onnx_tf/common/handler_helper.py)
>>> from onnx_tf.common.handler_helper import get_all_backend_handlers
>>> a = get_all_backend_handlers({})
>>> a['']['_DCNv2']
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
KeyError: '_DCNv2'
```
Also, I am wondering if there is an issue with CUDA or a version incompatibility: onnx 1.9.0, onnx_tf 1.8.0, protoc 3.11.3, tensorflow 2.5.0.
Thank you.
Hi @chinhuang007: If you have debugged the attached model, please help me understand and fix the issue. Thank you.
Hi @chudegao, will you please help me fix it? I think we are very close. We just have to see why, even after following all the steps correctly, onnx-tensorflow is unable to implement the new operator. After everything runs successfully I still get this issue:
```
>>> from onnx_tf.common.handler_helper import get_all_backend_handlers
>>> a = get_all_backend_handlers({})
>>> a['']['_DCNv2']
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
KeyError: '_DCNv2'
```
Please use your domain as the key:

```
>>> a['org.pytorch.custom_domain']
{'_DCNv2': <class 'onnx_tf.handlers.backend._DCNv2._DCNv2'>}
```
@chudegao: I am not sure whether it still has the same issue. I followed the steps again but saw no improvement. Do you think it is because of a CUDA error?
```
>>> import onnx_tf
2021-07-09 19:49:43.199937: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda-10.0/lib64:/usr/local/cuda-10.0/lib
2021-07-09 19:49:43.199971: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.
>>> from onnx_tf.common.handler_helper import get_all_backend_handles
Traceback (most recent call last):
  ...
>>> a['org.pytorch.custom_domain']
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
KeyError: 'org.pytorch.custom_domain'
```
It's not related to CUDA. Please make sure you applied the onnx-tf patch I mentioned above. I have submitted a PR for this: https://github.com/onnx/onnx-tensorflow/pull/945
@chudegao: Thank you for the quick reply, but will you please elaborate? I am unable to understand it.
I think onnx-tf has an issue supporting custom ops with a domain name. To fix it, you can apply the patch in PR 945 (replace the file as shown in https://github.com/onnx/onnx-tensorflow/pull/945/files). Then run `python gen_opset.py .` to re-generate the opset file (_DCNv2 will be added to opset_version.py). You should then get the right result.
@chudegao: Thank you so much for the clarification. I fixed the issue and now I can see exactly the same:

```
>>> a['org.pytorch.custom_domain']
{'_DCNv2': <class 'onnx_tf.handlers.backend._DCNv2._DCNv2'>}
```
But how do I fix 'No Tensorflow function is given'?

```
/home/uib43225/new_onnx_tf/onnx/onnx-tensorflow/onnx_tf/handlers/backend_handler.py:139 make_tensor_from_onnx_node *
    raise RuntimeError("No Tensorflow function is given.")
RuntimeError: No Tensorflow function is given.
```

Do I need to add a default TensorFlow function to fix this, since no function is being called? Or can I use another function in place of `make_tensor_from_onnx_node(node, **kwargs)` to convert the .onnx model to .pb successfully?
make_tensor_from_onnx_node is an internal API; it calls the tf_func defined in the handler (the decorator you commented out) to implement the operator. If there's a TensorFlow API you can leverage for your custom op, you can use it (most operators do this; see e.g. add.py). Otherwise you can implement the op yourself (see hardswish.py for reference).
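The mechanism behind the error can be mimicked in a few lines of plain Python (a simplified sketch, not onnx-tf's actual code; the class and attribute names here are illustrative):

```python
class Handler:
    # Stand-in for what the @tf_func(...) decorator records on a handler.
    TF_FUNC = None

    @classmethod
    def make_tensor_from_onnx_node(cls, node, inputs):
        # Mirrors why the thread's error appears: with no tf_func
        # registered and no hand-written body, there is nothing to call.
        if cls.TF_FUNC is None:
            raise RuntimeError("No Tensorflow function is given.")
        return cls.TF_FUNC(*inputs)

class NoFunc(Handler):
    pass  # like the _DCNv2 handler with '@tf_func' commented out

class Add(Handler):
    TF_FUNC = staticmethod(lambda a, b: a + b)  # stand-in for tf.add

assert Add.make_tensor_from_onnx_node(None, [2, 3]) == 5
try:
    NoFunc.make_tensor_from_onnx_node(None, [])
except RuntimeError as e:
    assert "No Tensorflow function" in str(e)
```

So the two ways out are exactly the ones described above: register a TensorFlow function for the handler, or override the version method with your own implementation instead of calling make_tensor_from_onnx_node.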
@chudegao: Thank you so much for your support. Yes, we did it, hurray! Much appreciated.
You are welcome.
Hi @prabhuiitdhn, could you share the op you defined? Thanks!
Hi: I have an onnx model (custom operator _DCNv2 added; onnx 1.9.0, torch 1.2.0) and am trying to convert it to TensorFlow following the https://github.com/onnx/onnx-tensorflow installation instructions, but I am getting the following error.
onnx_tf: 1.8.0 tensorflow: 2.4.1
Error:
Full error: it could be an issue with the tensorflow installation.