microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

LSTM Y output is inconsistent with TF inference result when seq_len is effective #12492

Open · q-ycong-p opened this issue 2 years ago

q-ycong-p commented 2 years ago

Describe the bug When a tf.keras.layers.LSTM layer has masking enabled and receives a post-padded input, TF inference skips the zero time-steps. With return_sequences=True, the first LSTM output (Y) has the hidden outputs of the skipped time-steps carry over the value of the last valid time-step; see the example in "To Reproduce". This is expected behavior.

The corresponding ONNX model handles such a mask-enabled, post-padded LSTM input via the seq_len input, which specifies per batch entry which time-steps to skip. This was done in keras2onnx and is being worked on in tf2onnx. However, ORT inference skips time-steps by writing zeros into the corresponding time-steps of Y, instead of having them repeat the value of the previous step.

This inconsistency causes confusion and incorrect results when a TF-trained model is converted to ONNX and run with ORT. We are aware there may be a good reason for ORT's behavior, and we're curious to learn what it is. Please shed some light on this issue.

System information

To Reproduce Below is a sample TF model to reproduce the issue.

import tensorflow as tf

x = tf.keras.layers.Input(shape=(4, ), dtype="float32")
initializer = tf.keras.initializers.RandomNormal(mean=0., stddev=1., seed=0)
processed_x = tf.keras.layers.Embedding(5, 5, mask_zero=True, embeddings_initializer=initializer)(x)
outputs = tf.keras.layers.LSTM(
    3,
    kernel_initializer=initializer,
    recurrent_initializer=initializer,
    return_sequences=True,   # output the full sequence of hidden states, concatenated across time-steps
    return_state=True,       # also return the final hidden and cell states
)(processed_x)

model = tf.keras.Model(inputs=x, outputs=outputs)
model.save("embedding_masked_lstm.h5")
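
For completeness, a hedged sketch of how the conversion described below could be done with keras2onnx; the output file name is our choice here, not necessarily the one in the attached archive.

import keras2onnx

# Convert the Keras model above; keras2onnx maps the masked LSTM to an ONNX LSTM
# driven by a per-batch sequence-length input.
onnx_model = keras2onnx.convert_keras(model, model.name)
keras2onnx.save_model(onnx_model, "embedding_masked_lstm.onnx")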

After conversion with keras2onnx, the resulting ONNX graph is shown in the screenshot below. keras2onnx is deprecated, but masked RNN support is not yet available in tf2onnx; we use keras2onnx here only to create a comparable ONNX model that demonstrates the LSTM inference inconsistency. (also attached: embedding_lstm_onnx.tar.gz)

[Screenshot: ONNX graph of the converted model, taken 2022-08-05]

Both models take an input of shape (batch_size, timestep), i.e. (3, 4) in this example. The padded input below implies seq_len=[3, 4, 4] for the ONNX LSTM:

[[1. 1. 1. 0.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]]
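
As a reference, a minimal sketch (assuming the Keras model defined above) of how the per-batch seq_len follows from this input and how the TF result below was produced:

import numpy as np

x_input = np.array([[1., 1., 1., 0.],
                    [1., 1., 1., 1.],
                    [1., 1., 1., 1.]], dtype="float32")

seq_len = (x_input != 0).sum(axis=1)   # -> [3, 4, 4], valid time-steps per batch entry

y, h, c = model(x_input)               # y is the full hidden-state sequence, shape (3, 4, 3)
print(y.numpy())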

The TF inference result for the first model output (i.e. the LSTM's Y output) is:

[[[ 0.11850349, -0.09868674, -0.23665339],
   [ 0.1639704 , -0.08429984, -0.3235214 ],
   [ 0.19384949, -0.0426457 , -0.33991447],
   [ 0.19384949, -0.0426457 , -0.33991447]], # note: the skipped time-step carries over the value of the previous step

  [[ 0.11850349, -0.09868674, -0.23665339],
   [ 0.1639704 , -0.08429984, -0.3235214 ],
   [ 0.19384949, -0.0426457 , -0.33991447],
   [ 0.21535467,  0.00041166, -0.34244764]],

  [[ 0.11850349, -0.09868674, -0.23665339],
   [ 0.1639704 , -0.08429984, -0.3235214 ],
   [ 0.19384949, -0.0426457 , -0.33991447],
   [ 0.21535467,  0.00041166, -0.34244764]]]

Here is the ORT inference result for the first model output (the LSTM output Y, transposed to match the TF layout):

[[[ 0.11850349 -0.09868674 -0.23665339]
  [ 0.1639704  -0.08429984 -0.3235214 ]
  [ 0.19384949 -0.0426457  -0.33991447]
  [ 0.          0.          0.        ]]   # note the skipped time-step has zeros for hidden outputs

 [[ 0.11850349 -0.09868674 -0.23665339]
  [ 0.1639704  -0.08429984 -0.3235214 ]
  [ 0.19384949 -0.0426457  -0.33991447]
  [ 0.21535467  0.00041166 -0.34244764]]

 [[ 0.11850349 -0.09868674 -0.23665339]
  [ 0.1639704  -0.08429984 -0.3235214 ]
  [ 0.19384949 -0.0426457  -0.33991447]
  [ 0.21535467  0.00041166 -0.34244764]]]
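
For reference, a hedged sketch of how the ORT result above can be reproduced; the ONNX file name is assumed (the actual model is in the attached archive), and the input name is read from the session rather than hard-coded. x_input is the padded array from the sketch above.

import onnxruntime as ort

sess = ort.InferenceSession("embedding_masked_lstm.onnx")
input_name = sess.get_inputs()[0].name
ort_outputs = sess.run(None, {input_name: x_input})
print(ort_outputs[0])   # Y output; zeros appear at the skipped time-step of the first batch entry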

Expected behavior We expect the same behavior as TensorFlow inference, where the intermediate hidden outputs of skipped time-steps carry over the previous step's value instead of being zero. We'd also like to understand the reason for ORT's current behavior.
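
To make the difference concrete, an illustrative numpy sketch (not a proposed fix) that converts the zero-filled ORT Y output into the carry-over form TF produces, given seq_len:

import numpy as np

def carry_over(y, seq_len):
    # y: (batch, time, hidden) with zeros after the last valid step (ORT convention)
    y = y.copy()
    for b, n in enumerate(seq_len):
        if 0 < n < y.shape[1]:
            y[b, n:, :] = y[b, n - 1, :]   # repeat the last valid hidden state (TF convention)
    return y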

skottmckay commented 2 years ago

If there is no input for some steps due to the sequence length, the output from those steps should be ignored. Given that, can you explain why it matters whether a zero or the previous hidden value is used?