onnx / onnx-tensorflow

Tensorflow Backend for ONNX
Other
1.28k stars 296 forks source link

ValueError: Dimensions must be equal for LSTMs #436

Status: Open — opened by Taha-Bahadori 5 years ago

Taha-Bahadori commented 5 years ago

Describe the bug

When importing an LSTM model from ONNX into TensorFlow, the conversion fails with a dimension-mismatch error.

To Reproduce

import onnx
import torch
import torch.nn as nn
from onnx_tf.backend import prepare

class RNN(nn.Module):
    """Embedding -> LSTM -> Linear model used to reproduce the issue.

    ``setting`` keys read here: ``n_rnn_layer`` (number of LSTM layers),
    ``rnn_dim`` (embedding size and LSTM hidden size), ``in_dim``
    (embedding vocabulary size) and ``out_dim`` (output features of the
    final linear layer).
    """

    def __init__(self, setting):
        super(RNN, self).__init__()
        self.n_rnn = setting['n_rnn_layer']
        self.rnn_size = setting['rnn_dim']

        self.emb = nn.Embedding(setting['in_dim'], self.rnn_size, padding_idx=0)
        # Zero initial (h0, c0) templates with a batch size of 1. They are kept
        # device-agnostic here and moved to the input's device in forward().
        # Calling .cuda() at construction time made the model impossible to
        # build without a GPU, and left these tensors behind on model.to().
        self.hidden0 = torch.zeros(self.n_rnn, 1, self.rnn_size)
        self.hidden1 = torch.zeros(self.n_rnn, 1, self.rnn_size)
        self.lstm = nn.LSTM(self.rnn_size, self.rnn_size, num_layers=self.n_rnn)
        self.fc = nn.Linear(self.rnn_size, setting['out_dim'])

    def forward(self, matrix):
        """Map a (batch, seq_len) tensor of token ids to (batch, out_dim).

        The embedded input is transposed to (seq_len, batch, features), the
        layout nn.LSTM uses without ``batch_first``. Only the last time step's
        LSTM output is passed to the linear layer.
        """
        tensor = self.emb(matrix)
        tensor = tensor.transpose(dim0=0, dim1=1)
        batch = tensor.size(1)
        device = tensor.device
        # Tile the single-sample templates across the batch dimension.
        hid0s = (self.hidden0.to(device).repeat(1, batch, 1),
                 self.hidden1.to(device).repeat(1, batch, 1))
        output = self.lstm(tensor, hid0s)[0][-1, :, :]
        return self.fc(output)

# Define a model in PyTorch
# NOTE(review): the model and the input batch are moved to CUDA below, so this
# script as written needs a GPU.
setting = {'n_rnn_layer': 1, 'out_dim': 5, 'in_dim': 29, 'rnn_dim': 64}
model = RNN(setting).cuda()
# A batch of `bsize` random token-id sequences, each `seqlen` tokens long.
bsize, seqlen = 30, 17
data = torch.randint(0, setting['in_dim'], (bsize, seqlen)).long().cuda()
print('PyTorch output:')
print(model(data))

# Export it to ONNX
# `data` serves as the example input for the export; the file is written to
# the current working directory.
torch.onnx.export(model, data, "lstm.onnx")

# Import the model
# prepare() is the call that raises the ValueError shown in the traceback
# below. run() is then fed the input batch as a host-side numpy array.
tfmodel = onnx.load('lstm.onnx')
tf_rep = prepare(tfmodel)
tf_out = tf_rep.run(data.cpu().data.numpy())

Here is the error:

Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py", line 1659, in _create_c_op
    c_op = c_api.TF_FinishOperation(op_desc)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Dimensions must be equal, but are 32 and 64 for 'LSTM_c2ecd878/rnn/while/rnn/multi_rnn_cell/cell_0/lstm_cell/mul' (op: 'Mul') with input shapes: [30,32], [30,64].

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "exp_onnx.py", line 44, in <module>
    tf_rep = prepare(tfmodel)
  File "/home/ubuntu/onnx-tensorflow/onnx_tf/backend.py", line 55, in prepare
    return cls.onnx_model_to_tensorflow_rep(model, strict)
  File "/home/ubuntu/onnx-tensorflow/onnx_tf/backend.py", line 75, in onnx_model_to_tensorflow_rep
    return cls._onnx_graph_to_tensorflow_rep(model.graph, opset_import, strict)
  File "/home/ubuntu/onnx-tensorflow/onnx_tf/backend.py", line 129, in _onnx_graph_to_tensorflow_rep
    onnx_node, tensor_dict, handlers, opset=opset, strict=strict)
  File "/home/ubuntu/onnx-tensorflow/onnx_tf/backend.py", line 224, in _onnx_node_to_tensorflow_op
    return handler.handle(node, tensor_dict=tensor_dict, strict=strict)
  File "/home/ubuntu/onnx-tensorflow/onnx_tf/handlers/handler.py", line 60, in handle
    return ver_handle(node, **kwargs)
  File "/home/ubuntu/onnx-tensorflow/onnx_tf/handlers/backend/lstm.py", line 212, in version_7
    return cls._common(node, **kwargs)
  File "/home/ubuntu/onnx-tensorflow/onnx_tf/handlers/backend/lstm.py", line 182, in _common
    rnn_kwargs, tf_activations, direction)
  File "/home/ubuntu/onnx-tensorflow/onnx_tf/handlers/backend/rnn_mixin.py", line 38, in rnn
    outputs, states = tf.nn.dynamic_rnn(cell_fw, x, **rnn_kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/deprecation.py", line 324, in new_func
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/rnn.py", line 671, in dynamic_rnn
    dtype=dtype)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/rnn.py", line 879, in _dynamic_rnn_loop
    swap_memory=swap_memory)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 3556, in while_loop
    return_same_structure)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 3087, in BuildLoop
    pred, body, original_loop_vars, loop_vars, shape_invariants)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 3022, in _BuildLoop
    body_result = body(*packed_vars_for_body)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 3525, in <lambda>
    body = lambda i, lv: (i + 1, orig_body(*lv))
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/rnn.py", line 847, in _time_step
    (output, new_state) = call_cell()
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/rnn.py", line 833, in <lambda>
    call_cell = lambda: cell(input_t, state)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/rnn_cell_impl.py", line 234, in __call__
    return super(RNNCell, self).__call__(inputs, state)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/layers/base.py", line 530, in __call__
    outputs = super(Layer, self).__call__(inputs, *args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py", line 554, in __call__
    outputs = self.call(inputs, *args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/rnn_cell_impl.py", line 1515, in call
    cur_inp, new_state = cell(cur_inp, cur_state)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/rnn_cell_impl.py", line 371, in __call__
    *args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/layers/base.py", line 530, in __call__
    outputs = super(Layer, self).__call__(inputs, *args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py", line 554, in __call__
    outputs = self.call(inputs, *args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/rnn_cell_impl.py", line 1007, in call
    c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) *
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/math_ops.py", line 812, in binary_op_wrapper
    return func(x, y, name=name)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/math_ops.py", line 1078, in _mul_dispatch
    return gen_math_ops.mul(x, y, name=name)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_math_ops.py", line 5860, in mul
    "Mul", x=x, y=y, name=name)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py", line 788, in _apply_op_helper
    op_def=op_def)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py", line 3300, in create_op
    op_def=op_def)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py", line 1823, in __init__
    control_input_ops)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py", line 1662, in _create_c_op
    raise ValueError(str(e))
ValueError: Dimensions must be equal, but are 32 and 64 for 'LSTM_f0f3b67c/rnn/while/rnn/multi_rnn_cell/cell_0/lstm_cell/mul' (op: 'Mul') with input shapes: [30,32], [30,64].

ONNX model file

https://www.dropbox.com/s/yuc6mleiuw7tlp8/test.onnx.gz?dl=0

Python, ONNX, ONNX-TF, Tensorflow version

Python 3.6.7
onnx 1.5.0
onnx_tf: installed from source on June 5, 2019
pytorch 1.1.0
TF 1.13.1
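fumihwh commented 5 years ago, suggesting the changes below. Constant handlers record their values in a shared `const_dict`, so the Gather handler can rewrite negative constant indices before calling `tf.gather`.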

In `handlers/backend/gather.py`:

import tensorflow as tf
import numpy as np

from onnx_tf.handlers.backend_handler import BackendHandler
from onnx_tf.handlers.handler import onnx_op
from onnx_tf.handlers.handler import tf_func

@onnx_op("Gather")
@tf_func(tf.gather)
class Gather(BackendHandler):
  """Handler for the ONNX ``Gather`` op, converted via ``tf.gather``.

  When the indices input is a constant recorded in ``const_dict``, negative
  indices are wrapped into the non-negative range of the gathered axis before
  the op is built. Otherwise the node is converted unchanged.
  """

  @classmethod
  def version_1(cls, node, **kwargs):
    """Convert a ``Gather`` node.

    Args:
      node: The ONNX node being converted.
      **kwargs: Must contain ``tensor_dict``. May contain ``const_dict``, which
        maps constant output names to numpy arrays; a missing or None
        ``const_dict`` is treated as empty.

    Returns:
      A one-element list with the converted tensor.
    """
    const_dict = kwargs.get("const_dict") or {}
    if node.inputs[1] not in const_dict:
      return [cls.make_tensor_from_onnx_node(node, **kwargs)]

    params = kwargs["tensor_dict"][node.inputs[0]]
    indices = np.asarray(const_dict[node.inputs[1]])
    # np.any handles indices of any rank (including 0-d). The builtin any()
    # raises "truth value is ambiguous" once indices have rank >= 2.
    if not np.any(indices < 0):
      return [cls.make_tensor_from_onnx_node(node, **kwargs)]

    # The ONNX spec makes `axis` optional with a default of 0.
    axis = node.attrs.get("axis", 0)
    dims = params.shape.dims
    dim = dims[axis].value if dims is not None else None
    if dim is None:
      # Static size of the axis is unknown: wrap using the runtime shape.
      idx = tf.constant(indices)
      size = tf.cast(tf.shape(params)[axis], idx.dtype)
      wrapped = tf.where(idx < 0, idx + size, idx)
    else:
      # Element-wise wrap that preserves the indices' rank and dtype.
      wrapped = tf.constant(
          np.where(indices < 0, indices + dim, indices), dtype=indices.dtype)
    return [
        cls.make_tensor_from_onnx_node(
            node, inputs=[params, wrapped], **kwargs)
    ]

In `handlers/backend/constant.py`:

from onnx import numpy_helper
import tensorflow as tf

from onnx_tf.handlers.backend_handler import BackendHandler
from onnx_tf.handlers.handler import onnx_op
from onnx_tf.handlers.handler import tf_func
from onnx_tf.common import data_type

@onnx_op("Constant")
@tf_func(tf.constant)
class Constant(BackendHandler):
  """Handler for the ONNX ``Constant`` op, converted via ``tf.constant``.

  Besides building the TF constant, it records the numpy value under the
  node's output name in ``const_dict`` when one is supplied. Later handlers
  can then inspect the concrete value.
  """

  @classmethod
  def _common(cls, node, **kwargs):
    """Build the TF constant for ``node`` (shared by all versions).

    Args:
      node: The ONNX node; its ``value`` attribute holds the tensor.
      **kwargs: May contain ``const_dict``. If it is present and not None, the
        numpy value is stored there under ``node.outputs[0]``.

    Returns:
      A one-element list with the converted tensor.
    """
    attr_value = node.attrs["value"]
    dtype = data_type.onnx2tf(attr_value.data_type)
    value = numpy_helper.to_array(attr_value)
    # Callers that do not track constants may omit const_dict or pass None.
    # Subscripting it unconditionally raised in that case.
    const_dict = kwargs.get("const_dict")
    if const_dict is not None:
      const_dict[node.outputs[0]] = value
    return [
        cls.make_tensor_from_onnx_node(
            node, inputs=[value], attrs={"dtype": dtype})
    ]

  @classmethod
  def version_1(cls, node, **kwargs):
    return cls._common(node, **kwargs)

  @classmethod
  def version_9(cls, node, **kwargs):
    return cls._common(node, **kwargs)

In `handlers/backend/slice.py`, the handler's decorator becomes:

@tf_func(tf.slice)

In `onnx_tf/backend.py`:

Around line 126 (inside `_onnx_graph_to_tensorflow_rep`):
      # Shared by every node converted in this loop. Constant handlers store
      # their numpy values here so that later handlers can read them.
      const_dict = dict()

      # Convert nodes in the order they appear in the graph. Each node's
      # outputs are added to tensor_dict by name for downstream nodes.
      for node in graph_def.node:
        onnx_node = OnnxNode(node)
        output_ops = cls._onnx_node_to_tensorflow_op(
            onnx_node, tensor_dict, handlers, opset=opset, strict=strict, const_dict=const_dict)
        curr_node_output_map = dict(zip(onnx_node.outputs, output_ops))
        tensor_dict.update(curr_node_output_map)

Around line 203 (`_onnx_node_to_tensorflow_op`):
  @classmethod
  def _onnx_node_to_tensorflow_op(cls,
                                  node,
                                  tensor_dict,
                                  handlers=None,
                                  opset=None,
                                  strict=True,
                                  const_dict=None):
    """
    Convert onnx node to tensorflow op.

    Args:
      node: Onnx node object.
      tensor_dict: Tensor dict of graph.
      handlers: Handlers indexed by node domain, then op type. When falsy,
        they are obtained from cls._get_handlers(opset).
      opset: Opset version of the operator set. Default 0 means using latest version.
      strict: whether to enforce semantic equivalence between the original model
        and the converted tensorflow model, defaults to True (yes, enforce semantic equivalence).
        Changing to False is strongly discouraged.
      const_dict: Dict mapping constant output names to numpy values, shared
        across the nodes of one graph. When None, a fresh empty dict is used,
        so handlers can always subscript kwargs["const_dict"].
    Returns:
      Tensorflow op
    """
    handlers = handlers or cls._get_handlers(opset)
    # Handlers subscript const_dict directly. Forwarding the None default made
    # any caller that omitted it crash inside the handler.
    if const_dict is None:
      const_dict = {}
    handler = handlers[node.domain].get(node.op_type, None)
    if handler:
      return handler.handle(
          node, tensor_dict=tensor_dict, strict=strict, const_dict=const_dict)
    else:
      # NOTE(review): the effect of OP_UNIMPLEMENTED_EXCEPT (raise vs. warn)
      # is defined elsewhere; if it returns, this method returns None.
      exception.OP_UNIMPLEMENTED_EXCEPT(node.op_type)