Describe the bug
When using a tf.keras.layers.LSTM layer with masking enabled and a post-padded input, TF inference skips the padded (zero) time-steps. With return_sequences=True, the hidden outputs at the skipped time-steps of the first LSTM output carry over the value of the previous time-step; see the example in "To Reproduce". This is the expected behavior.
The corresponding ONNX model handles such a mask-enabled, post-padded LSTM input via the seq_len field, which specifies which time-steps to skip per batch entry. This was done in keras2onnx and is being worked on in tf2onnx. However, ORT inference skips time-steps by writing zeros to the corresponding time-steps in Y, instead of having them carry over the value of the previous step.
This inconsistency in inference results causes confusion and wrong results when converting a TF-trained model to ONNX and running inference with ORT. We are aware there may be a good reason for ORT's behavior, and we're curious to learn why. Please shed some light on this issue.
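For concreteness, below is a minimal, self-contained sketch (our own tensor names and random weights, not the attached model) that builds a bare ONNX LSTM with a seq_lens input and runs it under ORT; the steps at or beyond each sequence length come back as zeros in Y:

import numpy as np
import onnx
from onnx import TensorProto, helper
import onnxruntime as ort

seq_length, batch, input_size, hidden = 4, 3, 5, 3

# Single LSTM node; the empty string skips the optional bias input B.
lstm = helper.make_node(
    "LSTM", ["X", "W", "R", "", "seq_lens"], ["Y"], hidden_size=hidden)
graph = helper.make_graph(
    [lstm], "masked_lstm",
    inputs=[
        helper.make_tensor_value_info("X", TensorProto.FLOAT, [seq_length, batch, input_size]),
        helper.make_tensor_value_info("W", TensorProto.FLOAT, [1, 4 * hidden, input_size]),
        helper.make_tensor_value_info("R", TensorProto.FLOAT, [1, 4 * hidden, hidden]),
        helper.make_tensor_value_info("seq_lens", TensorProto.INT32, [batch]),
    ],
    outputs=[helper.make_tensor_value_info(
        "Y", TensorProto.FLOAT, [seq_length, 1, batch, hidden])])
# Pin an opset in case the installed onnx defaults to one newer than ORT supports.
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 14)])

sess = ort.InferenceSession(model.SerializeToString())
rng = np.random.RandomState(0)
y = sess.run(None, {
    "X": rng.randn(seq_length, batch, input_size).astype(np.float32),
    "W": rng.randn(1, 4 * hidden, input_size).astype(np.float32),
    "R": rng.randn(1, 4 * hidden, hidden).astype(np.float32),
    "seq_lens": np.array([3, 4, 4], dtype=np.int32),
})[0]
print(y[3, 0, 0])  # 4th step of the first (length-3) sequence: zeros under ORT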
System information
OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux
ONNX Runtime installed from (source or binary): source
ONNX Runtime version: 1.11.1
Python version: 3.7.12
To Reproduce
Below is a sample TF model to reproduce the issue.
import tensorflow as tf

x = tf.keras.layers.Input(shape=(4,), dtype="float32")
initializer = tf.keras.initializers.RandomNormal(mean=0., stddev=1., seed=0)
# mask_zero=True: zero inputs are treated as padding and masked downstream
processed_x = tf.keras.layers.Embedding(5, 5, mask_zero=True, embeddings_initializer=initializer)(x)
outputs = tf.keras.layers.LSTM(
    3,
    kernel_initializer=initializer,
    recurrent_initializer=initializer,
    return_sequences=True,  # to output full intermediate hidden outputs, concatenated across time-steps
    return_state=True,
)(processed_x)
model = tf.keras.Model(inputs=x, outputs=outputs)
model.save("embedding_masked_lstm.h5")
After conversion with keras2onnx, the resulting ONNX model is shown in the screenshot below. keras2onnx is deprecated, but masked RNN support is not yet available in tf2onnx; we're using keras2onnx here only to create a comparable ONNX model that demonstrates the LSTM inference discrepancy. (also attached: embedding_lstm_onnx.tar.gz)
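A rough sketch of the conversion step, assuming the standard keras2onnx API (the output filename is our choice):

import keras2onnx
import onnx

# Convert the Keras model built above and save it next to the .h5 file.
onnx_model = keras2onnx.convert_keras(model, model.name)
onnx.save_model(onnx_model, "embedding_masked_lstm.onnx")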
With the above models, take a model input of shape (batch_size, timestep), i.e. (3, 4) in this example, which implies seq_len=[3, 4, 4] for the ONNX LSTM:
[[1. 1. 1. 0.]
[1. 1. 1. 1.]
[1. 1. 1. 1.]]
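The seq_len values follow directly from the padding convention (zeros are the masked entries, matching mask_zero=True); a quick numpy sketch:

import numpy as np

inp = np.array([[1., 1., 1., 0.],
                [1., 1., 1., 1.],
                [1., 1., 1., 1.]], dtype=np.float32)
# Count the non-zero (unmasked) steps per batch entry.
seq_len = (inp != 0).sum(axis=1).astype(np.int32)  # -> [3, 4, 4]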
TF inference results for the first model output (i.e. LSTM's Y output) carry the previous step's hidden value through the masked step, whereas the ORT inference results for the corresponding output lstm (the transposed LSTM output Y) contain zeros there.
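A sketch of how the two results are obtained side by side (the .onnx filename is our assumption about the attached archive's contents):

import numpy as np
import tensorflow as tf
import onnxruntime as ort

inp = np.array([[1., 1., 1., 0.],
                [1., 1., 1., 1.],
                [1., 1., 1., 1.]], dtype=np.float32)

# TF: Y repeats the previous hidden value through the masked step.
tf_model = tf.keras.models.load_model("embedding_masked_lstm.h5")
tf_y = tf_model.predict(inp)[0]  # (batch, time, hidden)

# ORT: Y holds zeros at the steps beyond each sequence length.
sess = ort.InferenceSession("embedding_masked_lstm.onnx")
ort_y = sess.run(None, {sess.get_inputs()[0].name: inp})[0]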
Expected behavior
We expect the same behavior as TensorFlow inference, where the intermediate hidden outputs of skipped time-steps carry over the previous step's value instead of being zeros. We'd like to know the reason for ORT's behavior.
If there is no input for some steps due to the sequence length, the output from those steps should be ignored. Given that, can you explain why it matters whether a zero or the previous hidden value is used?
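To illustrate that point: if consumers only read outputs within each sequence's valid length, the two conventions are indistinguishable. A small numpy sketch (our own helper, for illustration):

import numpy as np

def last_valid_step(y, seq_len):
    # y: (batch, time, hidden); seq_len: valid lengths per batch entry.
    idx = np.asarray(seq_len) - 1
    return y[np.arange(y.shape[0]), idx]  # (batch, hidden)

# last_valid_step(tf_y, [3, 4, 4]) == last_valid_step(ort_y, [3, 4, 4])
# whether the padded steps hold zeros or carried-over values.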