onnx / keras-onnx

Convert tf.keras/Keras models to ONNX
Apache License 2.0

Behaviour of stateful RNN/LSTM/GRU in onnx #620

Open srikanthderebail opened 3 years ago

srikanthderebail commented 3 years ago

Hello everyone,

I am trying to export a Keras model that has a SimpleRNN layer with stateful=True. With this setting, the final state of each sample in a batch is used as the initial state of the corresponding sample in the next batch. However, when I export such a model with keras2onnx and evaluate it with onnxruntime, the model behaves differently: the state appears to be reset between subsequent calls.

Is there a way to get stateful behaviour in the converted ONNX model?
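The following minimal numpy sketch makes the difference concrete. It is not the Keras implementation. It assumes a SimpleRNN-style tanh cell, `h_t = tanh(x_t W + h_{t-1} U + b)`, with random placeholder weights. It compares a run that carries the final state into the next call against one that restarts from zeros, which matches the reset behaviour described above.

```python
import numpy as np

# SimpleRNN-style cell (tanh): h_t = tanh(x_t @ W + h_{t-1} @ U + b).
# Weights are random placeholders, not taken from any real model.
rng = np.random.default_rng(0)
input_dim, units = 3, 4
W = rng.standard_normal((input_dim, units))
U = rng.standard_normal((units, units))
b = np.zeros(units)

def run_rnn(xs, h0):
    """Run the cell over xs of shape (seq_len, input_dim); return outputs and final state."""
    h = h0
    outputs = []
    for x in xs:
        h = np.tanh(x @ W + h @ U + b)
        outputs.append(h)
    return np.stack(outputs), h

seq = rng.standard_normal((6, input_dim))
first, second = seq[:3], seq[3:]
zeros = np.zeros(units)

# Reference: the whole sequence in a single call.
full_out, _ = run_rnn(seq, zeros)

# stateful=True behaviour: the final state of one call seeds the next call.
_, h_carried = run_rnn(first, zeros)
stateful_out, _ = run_rnn(second, h_carried)

# Behaviour reported in this issue: every call starts again from zeros.
reset_out, _ = run_rnn(second, zeros)

print(np.allclose(stateful_out, full_out[3:]))  # True
print(np.allclose(reset_out, full_out[3:]))     # False
```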

jiafatom commented 3 years ago

The ONNX GRU and LSTM ops do not have a stateful attribute, so the Keras layer cannot be mapped directly to an ONNX op, and it is relatively hard for the converter to emulate this behaviour. I feel the best way is to open a proposal in the ONNX repo asking for a stateful attribute to be added.
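The ONNX recurrent ops expose the state as graph inputs and outputs rather than through an attribute. Each op has an optional `initial_h` input and a `Y_h` output; LSTM also has `initial_c`/`Y_c`. The numpy sketch below implements a forward-direction ONNX `RNN` op with tanh activation, which is the op a SimpleRNN most naturally corresponds to. The function name `onnx_rnn_forward` is made up for illustration. The sketch shows that feeding one call's `Y_h` into the next call's `initial_h` reproduces a single run over the full sequence.

```python
import numpy as np

def onnx_rnn_forward(X, W, R, B=None, initial_h=None):
    """Reference sketch of a forward ONNX RNN op (tanh), num_directions = 1.

    X: [seq_length, batch_size, input_size]
    W: [1, hidden_size, input_size]
    R: [1, hidden_size, hidden_size]
    B: [1, 2 * hidden_size]  (Wb and Rb concatenated)
    initial_h: [1, batch_size, hidden_size], zeros if omitted
    Returns Y: [seq_length, 1, batch_size, hidden_size], Y_h: [1, batch_size, hidden_size]
    """
    seq_length, batch_size, _ = X.shape
    hidden_size = W.shape[1]
    Wi, Ri = W[0], R[0]
    if B is None:
        B = np.zeros((1, 2 * hidden_size), dtype=X.dtype)
    Wb, Rb = B[0, :hidden_size], B[0, hidden_size:]
    H = np.zeros((batch_size, hidden_size), dtype=X.dtype) if initial_h is None else initial_h[0]
    Y = []
    for t in range(seq_length):
        # Ht = tanh(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
        H = np.tanh(X[t] @ Wi.T + H @ Ri.T + Wb + Rb)
        Y.append(H)
    Y = np.stack(Y)[:, np.newaxis]  # insert the num_directions axis
    return Y, H[np.newaxis]         # Y_h keeps the num_directions axis

rng = np.random.default_rng(1)
seq_length, batch_size, input_size, hidden_size = 4, 2, 3, 5
X = rng.standard_normal((seq_length, batch_size, input_size))
W = rng.standard_normal((1, hidden_size, input_size))
R = rng.standard_normal((1, hidden_size, hidden_size))
B = rng.standard_normal((1, 2 * hidden_size))

# One call over the full sequence vs. two chained calls.
Y_full, Yh_full = onnx_rnn_forward(X, W, R, B)
Y1, Yh1 = onnx_rnn_forward(X[:2], W, R, B)
Y2, Yh2 = onnx_rnn_forward(X[2:], W, R, B, initial_h=Yh1)  # carry Y_h over
print(np.allclose(np.concatenate([Y1, Y2]), Y_full), np.allclose(Yh2, Yh_full))  # True True
```

Because the op has no memory of its own between runs, any carrying of state has to happen outside the graph. This is why a missing `initial_h` input looks like a reset.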

dkloving commented 3 years ago

One solution may be to add an input to the model through which you pass the state, and an output through which the model returns the updated state. You can then manage statefulness in your own code.
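This approach can be sketched as a small host-side wrapper that stores the state between calls. The sketch rests on some assumptions. `StatefulRunner` and `toy_step` are hypothetical names, and `toy_step` stands in for the exported model. With onnxruntime, `run_fn` could plausibly be `lambda x, h: tuple(sess.run(None, {"Input": x, "Input_RNN": h}))`. That binding assumes the input names and the `[output, state]` output order used in the reproduction later in this thread.

```python
import numpy as np

class StatefulRunner:
    """Keeps RNN state between calls to a stateless step function.

    run_fn(x, h) must return (y, h_new). Wrapping an onnxruntime session this
    way is an assumption for illustration, not something keras2onnx provides.
    """

    def __init__(self, run_fn, state_shape, dtype=np.float32):
        self.run_fn = run_fn
        self.state_shape = state_shape
        self.dtype = dtype
        self.reset()

    def reset(self):
        # Counterpart of Keras model.reset_states(): start again from zeros.
        self.h = np.zeros(self.state_shape, dtype=self.dtype)

    def __call__(self, x):
        y, self.h = self.run_fn(x, self.h)
        return y

# Hypothetical stand-in for the exported model: one tanh RNN step.
rng = np.random.default_rng(2)
W = rng.standard_normal((3, 4)).astype(np.float32)
U = rng.standard_normal((4, 4)).astype(np.float32)

def toy_step(x, h):
    h_new = np.tanh(x @ W + h @ U)
    return h_new, h_new

runner = StatefulRunner(toy_step, state_shape=(1, 4))
x = np.ones((1, 3), dtype=np.float32)
y1 = runner(x)
y2 = runner(x)   # same input, different output: the state was carried
runner.reset()
y3 = runner(x)   # after reset, identical to the first call
```

Keeping the state on the caller's side also makes the reset point explicit. With a stateful Keras layer, the reset point is hidden inside the model.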

helmutbressler commented 3 years ago

Hello,

I have converted the Sequential model from the TensorFlow/Keras text generation tutorial into a Keras functional model. This lets me pass the RNN state (in my case from a GRU layer) in as an input and retrieve the updated state as an output. It works in the end, but I ran into two issues that caused keras2onnx to fail:

if output_state:
    output_h = operator.outputs[0].full_name
    apply_squeeze(scope, rnn_h, output_h, container)

Otherwise I get an "index out of range" error.
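The `apply_squeeze` call relates to a shape mismatch between the two frameworks. The ONNX GRU op returns `Y_h` with shape `[num_directions, batch_size, hidden_size]`, while a Keras GRU state has shape `(batch_size, units)`. Assuming `rnn_h` in the snippet holds that `Y_h`, the squeeze removes the leading axis. Here is a small numpy illustration with the sizes from the reproduction below (batch size 1, 48 units):

```python
import numpy as np

# Y_h as produced by a unidirectional ONNX GRU: [num_directions, batch_size, hidden_size]
num_directions, batch_size, hidden_size = 1, 1, 48
Y_h = np.zeros((num_directions, batch_size, hidden_size), dtype=np.float32)

# Dropping only the num_directions axis gives the Keras state layout.
keras_state = np.squeeze(Y_h, axis=0)
print(keras_state.shape)  # (1, 48)

# An unqualified squeeze would also drop the batch axis when batch_size == 1.
print(np.squeeze(Y_h).shape)  # (48,)
```

Restricting the squeeze to axis 0 matters with `batch_size=1`. Without the restriction, the returned state no longer matches the `(batch_size, units)` shape expected by the state input.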

I'm not an experienced machine learning or Python developer, so I may have done something silly that causes these problems. I'm using Keras 2.4 from TensorFlow 2.3.1, keras2onnx 1.7 and Python 3.6.8.

Source code to reproduce the issue:

import tensorflow as tf
from tensorflow.keras import layers  # public API instead of tensorflow.python.keras
import keras2onnx

def createModel(input_dim, batch_size, seq_len, embedding_dim=32, rnn_state_size=48, include_RNN_state=False):
    # Token IDs, shape (batch_size, seq_len)
    inputs = tf.keras.Input(shape=(seq_len,), dtype=tf.int32, name="Input", batch_size=batch_size)

    if include_RNN_state:
        # Extra input carrying the GRU state, shape (batch_size, rnn_state_size)
        rnn_inputs = tf.keras.Input(shape=(rnn_state_size,), dtype=tf.float32, name="Input_RNN", batch_size=batch_size)

    embedding = layers.Embedding(input_dim=input_dim, output_dim=embedding_dim, batch_size=batch_size)
    x0 = embedding(inputs)

    if include_RNN_state:
        # return_state=True makes the GRU also return its final state
        rnn = layers.GRU(rnn_state_size, batch_size=batch_size, stateful=True, return_state=True, return_sequences=True, unroll=True, recurrent_initializer='glorot_uniform')
        x1, gru_state = rnn(x0, initial_state=rnn_inputs)
    else:
        rnn = layers.GRU(rnn_state_size, batch_size=batch_size, stateful=True, return_sequences=True, recurrent_initializer='glorot_uniform')
        x1 = rnn(x0)

    # Per-step logits over the vocabulary
    dense = layers.Dense(input_dim)
    outputs = dense(x1)

    if include_RNN_state:
        # Expose the final GRU state as a second model output
        return tf.keras.Model(inputs=[inputs, rnn_inputs], outputs=[outputs, gru_state], name="char_predict_rnn")
    return tf.keras.Model(inputs=inputs, outputs=outputs, name="char_predict_rnn")

# Variant without an explicit state input/output
model_without_rnnstate = createModel(10, 1, 1)
onnx_model = keras2onnx.convert_keras(model_without_rnnstate, model_without_rnnstate.name, debug_mode=True)

# Variant with the GRU state exposed as an extra input and output
model_with_rnnstate = createModel(10, 1, 1, include_RNN_state=True)
onnx_model_stateful = keras2onnx.convert_keras(model_with_rnnstate, model_with_rnnstate.name, debug_mode=True)
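Once the stateful variant converts, the caller has to pass the state from one call to the next. The sketch below is hypothetical and rests on several assumptions:

- the ONNX graph keeps the Keras input names `Input` and `Input_RNN` (check with `sess.get_inputs()`);
- it returns `[logits, state]` in the Keras output order;
- the sizes match the repro (`createModel(10, 1, 1)` with 48 state units).

`generate` and `fake_run` are made-up names. `fake_run` is a stub that lets the loop run without onnxruntime.

```python
import numpy as np

BATCH_SIZE, SEQ_LEN, VOCAB, STATE = 1, 1, 10, 48  # sizes used in the repro

def generate(run, start_id, steps):
    """Greedy generation loop that threads the GRU state through `run`.

    `run(feed)` is assumed to return [logits, state], e.g.
    lambda feed: sess.run(None, feed) for an onnxruntime session (hypothetical binding).
    """
    token = np.full((BATCH_SIZE, SEQ_LEN), start_id, dtype=np.int32)
    state = np.zeros((BATCH_SIZE, STATE), dtype=np.float32)  # fresh state
    generated = []
    for _ in range(steps):
        logits, state = run({"Input": token, "Input_RNN": state})  # state fed back in
        next_id = int(np.argmax(logits[0, -1]))
        generated.append(next_id)
        token = np.full((BATCH_SIZE, SEQ_LEN), next_id, dtype=np.int32)
    return generated, state

# Stub standing in for the converted model: predicts (token + 1) mod VOCAB
# and increments every state entry, so the state threading can be checked.
def fake_run(feed):
    token, state = feed["Input"], feed["Input_RNN"]
    logits = np.zeros((BATCH_SIZE, SEQ_LEN, VOCAB), dtype=np.float32)
    logits[0, -1, (int(token[0, -1]) + 1) % VOCAB] = 1.0
    return [logits, state + 1.0]

ids, final_state = generate(fake_run, start_id=7, steps=5)
print(ids)  # [8, 9, 0, 1, 2]
```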