Open Taha-Bahadori opened 5 years ago
handlers/backend/gather.py
import tensorflow as tf
import numpy as np
from onnx_tf.handlers.backend_handler import BackendHandler
from onnx_tf.handlers.handler import onnx_op
from onnx_tf.handlers.handler import tf_func
@onnx_op("Gather")
@tf_func(tf.gather)
class Gather(BackendHandler):
  """Backend handler for the ONNX ``Gather`` operator.

  ONNX allows negative indices (counted from the end of ``axis``); this
  handler rewrites them to their non-negative equivalents (``i + dim``)
  before delegating to ``tf.gather``. The rewrite is only attempted when
  the indices input is a known constant recorded in ``const_dict``.
  """

  @classmethod
  def _normalize_negative_indices(cls, indices, dim):
    """Return ``indices`` with every negative entry shifted up by ``dim``.

    Args:
      indices: Integer index array of any rank (0-d scalars included).
      dim: Static size of the gathered axis.

    Returns:
      A numpy array with the same shape and dtype as ``indices``.
    """
    indices = np.asarray(indices)
    # np.where is vectorised, so this works for any index rank, unlike a
    # Python-level loop over the first dimension.
    return np.where(indices < 0, indices + dim, indices).astype(indices.dtype)

  @classmethod
  def version_1(cls, node, **kwargs):
    """Convert a Gather node, normalising constant negative indices.

    Args:
      node: Onnx node object; ``node.inputs`` is ``[data, indices]``.
      **kwargs: Must contain ``tensor_dict``; ``const_dict`` is optional.

    Returns:
      A one-element list holding the converted TensorFlow tensor.
    """
    # const_dict may be missing or None when the handler is invoked outside
    # the graph-conversion loop; treat that as "no known constants".
    const_dict = kwargs.get("const_dict") or {}
    indices_name = node.inputs[1]
    if indices_name not in const_dict:
      return [cls.make_tensor_from_onnx_node(node, **kwargs)]

    indices = np.asarray(const_dict[indices_name])
    # np.any (not builtin any) handles index arrays of every rank.
    if not np.any(indices < 0):
      return [cls.make_tensor_from_onnx_node(node, **kwargs)]

    data = kwargs["tensor_dict"][node.inputs[0]]
    # ONNX Gather's axis attribute is optional and defaults to 0.
    axis = node.attrs.get("axis", 0)
    # Works with both TF1 Dimension objects and TF2 ints; None if unknown.
    dim = tf.compat.dimension_value(data.shape[axis])
    if dim is not None:
      new_indices = tf.constant(
          cls._normalize_negative_indices(indices, dim), dtype=indices.dtype)
    else:
      # Static axis size unknown: resolve the size at graph run time.
      idx = tf.constant(indices)
      size = tf.cast(tf.shape(data)[axis], idx.dtype)
      new_indices = tf.where(idx < 0, idx + size, idx)
    return [
        cls.make_tensor_from_onnx_node(
            node, inputs=[data, new_indices], **kwargs)
    ]
handlers/backend/constant.py
from onnx import numpy_helper
import tensorflow as tf
from onnx_tf.handlers.backend_handler import BackendHandler
from onnx_tf.handlers.handler import onnx_op
from onnx_tf.handlers.handler import tf_func
from onnx_tf.common import data_type
@onnx_op("Constant")
@tf_func(tf.constant)
class Constant(BackendHandler):
  """Backend handler for the ONNX ``Constant`` operator.

  Besides emitting a ``tf.constant``, it records the numpy value under the
  node's output name in ``const_dict`` (when one is supplied). Later
  handlers (e.g. ``Gather``) can then inspect constant inputs at
  conversion time.
  """

  @classmethod
  def _common(cls, node, **kwargs):
    """Shared conversion logic for all supported Constant versions.

    Args:
      node: Onnx node object carrying a ``value`` TensorProto attribute.
      **kwargs: May contain ``const_dict``; it is updated in place if given.

    Returns:
      A one-element list holding the TensorFlow constant tensor.
    """
    attr_value = node.attrs["value"]
    dtype = data_type.onnx2tf(attr_value.data_type)
    value = numpy_helper.to_array(attr_value)
    # const_dict is optional: callers that do not track constants (or pass
    # None) must not crash here.
    const_dict = kwargs.get("const_dict")
    if const_dict is not None:
      const_dict[node.outputs[0]] = value
    return [
        cls.make_tensor_from_onnx_node(
            node, inputs=[value], attrs={"dtype": dtype})
    ]

  @classmethod
  def version_1(cls, node, **kwargs):
    return cls._common(node, **kwargs)

  @classmethod
  def version_9(cls, node, **kwargs):
    return cls._common(node, **kwargs)
handlers/backend/slice.py
@tf_func(tf.slice)
handlers/backend.py
L126
const_dict = dict()
for node in graph_def.node:
onnx_node = OnnxNode(node)
output_ops = cls._onnx_node_to_tensorflow_op(
onnx_node, tensor_dict, handlers, opset=opset, strict=strict, const_dict=const_dict)
curr_node_output_map = dict(zip(onnx_node.outputs, output_ops))
tensor_dict.update(curr_node_output_map)
L203
@classmethod
def _onnx_node_to_tensorflow_op(cls,
                                node,
                                tensor_dict,
                                handlers=None,
                                opset=None,
                                strict=True,
                                const_dict=None):
  """
  Convert onnx node to tensorflow op.

  Args:
    node: Onnx node object.
    tensor_dict: Tensor dict of graph.
    handlers: Mapping of domain -> {op_type: handler}. When falsy, it is
      obtained from cls._get_handlers(opset).
    opset: Opset version of the operator set, passed to cls._get_handlers
      when handlers is not given (presumably None/0 selects the latest
      version -- see _get_handlers).
    strict: whether to enforce semantic equivalence between the original model
      and the converted tensorflow model, defaults to True (yes, enforce semantic equivalence).
      Changing to False is strongly discouraged.
    const_dict: Mutable dict mapping output names to constant numpy values,
      read and written by handlers such as Constant and Gather. When None, a
      fresh empty dict is used so those handlers still work; constants are
      then not shared with other calls.
  Returns:
    Tensorflow op
  """
  handlers = handlers or cls._get_handlers(opset)
  # Handlers index into / assign to const_dict, so never forward None.
  if const_dict is None:
    const_dict = {}
  handler = handlers[node.domain].get(node.op_type, None)
  if handler:
    return handler.handle(
        node, tensor_dict=tensor_dict, strict=strict, const_dict=const_dict)
  else:
    exception.OP_UNIMPLEMENTED_EXCEPT(node.op_type)
Describe the bug
When we try to import an LSTM model from ONNX into TensorFlow, we get a dimension-mismatch error.
To Reproduce
Here is the error (the error output itself is not reproduced in this report):
ONNX model file
https://www.dropbox.com/s/yuc6mleiuw7tlp8/test.onnx.gz?dl=0
Python, ONNX, ONNX-TF, and TensorFlow versions