onnx / onnx-tensorflow

Tensorflow Backend for ONNX

Bidir LSTM `activations` attribute is optional #327

Open samgd opened 5 years ago

samgd commented 5 years ago

Describe the bug

The ONNX LSTM operator states that the `activations` attribute is optional. The LSTM backend implementation, however, relies on this attribute being present when the LSTM is bidirectional.

The implementation initialises `tf_activations = [tf.nn.tanh]` and only appends a second activation function when both `"activations" in node.attrs` and `num_directions == 2` hold:

    tf_activations = [tf.nn.tanh]
    if "activations" in node.attrs:
      activations = list(map(lambda x: x.lower(), node.attrs["activations"]))
      activation_alpha = node.attrs.get("activation_alpha", [None] * 6)
      activation_beta = node.attrs.get("activation_beta", [None] * 6)
      tf_activations = [
          cls.rnn_get_activation(activations[1], activation_alpha[1],
                                 activation_beta[1])
      ]
      if num_directions == 2:
        tf_activations.append(
            cls.rnn_get_activation(activations[4], activation_alpha[4],
                                   activation_beta[4]))

`tf_activations` is then passed to the `rnn` method:

      outputs, states = cls.rnn(x, tf.nn.rnn_cell.LSTMCell, cell_kwargs,
                                rnn_kwargs, tf_activations, direction)

The `rnn` method assumes that a second activation is present whenever `direction == "bidirectional"`, where `activations` is the `tf_activations` list built above:

    if direction == "bidirectional":
      cell_kwargs["activation"] = activations[1]
      rnn_cell_bw = [cell_class(**cell_kwargs)]
      cell_bw = tf.nn.rnn_cell.MultiRNNCell([rnn_cell_bw])

If the optional attribute is absent, `tf_activations` contains only one element, so the access `activations[1]` fails with:

IndexError: list index out of range
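To make the failure concrete without TensorFlow, here is a minimal pure-Python model of the two code paths quoted above. It is a sketch, not the backend's code: strings stand in for the TensorFlow activation functions, and `select_activations` and `backward_activation` are hypothetical names used only for illustration.

```python
# Pure-Python model of the activation handling described above.
# Strings stand in for TF activation functions; function names are hypothetical.

def select_activations(attrs, num_directions):
    """Mirror the backend: start from a single default 'tanh' entry."""
    tf_activations = ["tanh"]
    if "activations" in attrs:
        activations = [a.lower() for a in attrs["activations"]]
        tf_activations = [activations[1]]
        if num_directions == 2:
            tf_activations.append(activations[4])
    return tf_activations


def backward_activation(activations, direction):
    """Mirror the rnn method: the backward cell reads index 1."""
    if direction == "bidirectional":
        return activations[1]
    return None


# Attribute present: two entries, so index 1 exists.
with_attr = select_activations(
    {"activations": ["Sigmoid", "Tanh", "Tanh", "Sigmoid", "Tanh", "Tanh"]}, 2)
print(backward_activation(with_attr, "bidirectional"))  # tanh

# Attribute omitted (as in the PyTorch export): only one entry.
without_attr = select_activations({}, 2)
try:
    backward_activation(without_attr, "bidirectional")
except IndexError as exc:
    print("IndexError:", exc)  # IndexError: list index out of range
```

The default branch never depends on `num_directions`, which is why only the bidirectional case breaks.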

To Reproduce

My PyTorch ONNX export seems to include a bunch of strings relating to my environment that I don't want to disclose. Below is a very simple PyTorch script to create an ONNX model that causes this issue:

    import torch

    # A plain tensor works as the example input; torch.autograd.Variable has
    # been deprecated since PyTorch 0.4 and is no longer needed.
    dummy_input = torch.randn(30, 1, 10)  # (seq_len, batch, input_size)

    model = torch.nn.LSTM(input_size=10,
                          hidden_size=10,
                          bidirectional=True)
    model.eval()

    torch.onnx.export(model, dummy_input, '/tmp/model.onnx', verbose=False)

fumihwh commented 5 years ago

@samgd Could you create a PR to fix this?

chatzikon commented 5 years ago

I think that the problem could be resolved by replacing:

    tf_activations = [tf.nn.tanh]

with:

    if num_directions == 2:
      tf_activations = [tf.nn.tanh, tf.nn.tanh]
    else:
      tf_activations = [tf.nn.tanh]
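A slightly more general variant of this suggestion, sketched here as an assumption rather than the backend's actual fix, is to default the raw attribute itself to the ONNX-spec defaults (Sigmoid, Tanh, Tanh for f, g, h, repeated once per direction). Then the existing indexing (`[1]` for the forward g, `[4]` for the reverse g) works unchanged whether or not the attribute was exported. Strings again stand in for TF functions, and `select_activations_fixed` is a hypothetical name.

```python
# Sketch: fall back to ONNX LSTM default activations when the optional
# `activations` attribute is missing. Strings stand in for TF functions.

ONNX_LSTM_DEFAULTS = ["Sigmoid", "Tanh", "Tanh"]  # f, g, h for one direction


def select_activations_fixed(attrs, num_directions):
    # One [f, g, h] triple per direction: forward first, then reverse.
    activations = attrs.get("activations", ONNX_LSTM_DEFAULTS * num_directions)
    activations = [a.lower() for a in activations]
    # Index 1 is g for the forward direction; index 4 is g for the reverse one.
    tf_activations = [activations[1]]
    if num_directions == 2:
        tf_activations.append(activations[4])
    return tf_activations


print(select_activations_fixed({}, 2))  # ['tanh', 'tanh']
print(select_activations_fixed({}, 1))  # ['tanh']
print(select_activations_fixed(
    {"activations": ["Sigmoid", "Relu", "Tanh", "Sigmoid", "Tanh", "Tanh"]}, 2))
# ['relu', 'tanh']
```

Defaulting the input rather than special-casing the output keeps a single code path, so the missing-attribute case is handled by the same logic as an explicit attribute.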