onnx / tensorflow-onnx

Convert TensorFlow, Keras, TensorFlow.js and TFLite models to ONNX
Apache License 2.0

Keras LSTM not converted correctly when precedent Embedding layer specifies mask_zero=True #1871

Open q-ycong-p opened 2 years ago

q-ycong-p commented 2 years ago

Describe the bug

When a tf.keras.layers.Embedding layer with mask_zero=True precedes an LSTM layer, the LSTM is converted into loops instead of an ONNX LSTM op.

System information

To Reproduce

import onnx
import tensorflow as tf
import tf2onnx

# minimal example of problematic keras model
x = tf.keras.layers.Input(shape=(4,), dtype="int32")
# mask_zero=True makes the Embedding emit a mask (token id 0 = padding) that the LSTM consumes
e = tf.keras.layers.Embedding(5, 5, mask_zero=True)(x)
rnn = tf.keras.layers.LSTM(3, return_sequences=True)(e)[0]
model = tf.keras.Model(inputs=x, outputs=rnn)

# converted onnx will have loops instead of LSTM op
onnx_model, _ = tf2onnx.convert.from_keras(model)
onnx.save(onnx_model, "lstm_masking_zero.onnx")
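To make the masking concrete: with mask_zero=True, Keras' Embedding.compute_mask returns tf.not_equal(inputs, 0), so a timestep is valid unless its token id is 0, and the downstream LSTM consumes that boolean mask. The pure-Python sketch below reproduces that mask for a batch shaped like the (4,) input above; compute_zero_mask is a hypothetical helper, not a Keras or tf2onnx API.

```python
from typing import List, Sequence


def compute_zero_mask(token_ids: Sequence[Sequence[int]]) -> List[List[bool]]:
    """Sketch of the mask_zero=True mask: True for real tokens, False for id 0."""
    return [[tok != 0 for tok in row] for row in token_ids]


if __name__ == "__main__":
    batch = [[3, 1, 0, 0],   # post-padded row
             [2, 0, 4, 1]]   # a 0 in the middle is masked too
    for row, mask in zip(batch, compute_zero_mask(batch)):
        print(row, mask)
```

Note that the second row gets a False in the middle, not only at the end: mask_zero masks every 0, which is why inputs containing 0 are the ones that break when the mask is dropped.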

Screenshots

(images not included) Original Keras model to be converted; unsuccessful conversion of the LSTM into loops.

Additional context

I've tried modifying lstm_tf2_rewriter to accommodate the new pattern during rewriter parsing. Although I can skip the extra SelectV2 pattern and get an LSTM op in the final ONNX model, I am not able to handle the masking-zero information correctly. My attempt produces incorrect inference results if the input contains 0.

Below is my unsuccessful attempt: masking is ignored, which leads to incorrect inference results. Any suggestions on how masking should be handled? Thank you!

--- a/tf2onnx/rewriter/lstm_tf2_rewriter.py
+++ b/tf2onnx/rewriter/lstm_tf2_rewriter.py
@@ -56,21 +56,22 @@ def rewriter_lstm_tf2(g, ops):
             # extract output h_t
             ht_mul = match_result.get_op("ht")
             final_consumers = g.find_output_consumers(ht_mul.output[0])
-            select_ops = [n for n in final_consumers if n.type == "Select"]
+            select_ops = [n for n in final_consumers if n.type == "Select" or n.type == "SelectV2"]
             def has_tensor_list_consumer(n):
                 return any(c.type == "TensorListSetItem" for c in g.find_output_consumers(n.output[0]))
             select_ops = [n for n in select_ops if has_tensor_list_consumer(n)]
+
+            seq_len_idx = None
             if len(select_ops) == 1:
-                greater_eq = select_ops[0].inputs[0]
-                if greater_eq.type != "GreaterEqual":
-                    continue
-                seq_len = greater_eq.inputs[1]
-                if not seq_len.is_graph_input():
-                    continue
-                seq_len_idx = g.input_names.index(seq_len.output[0])
+                select_op_condition = select_ops[0].inputs[0]
+                if select_op_condition.type == "GreaterEqual": # has sequence length
+                    seq_len = select_op_condition.inputs[1]
+                    if not seq_len.is_graph_input():
+                        continue
+                    seq_len_idx = g.input_names.index(seq_len.output[0])
+                # if select op's condition doesn't come from GreaterEqual, we still extract
+                # output h_t from consumer, and seq_len remains empty
                 final_consumers = g.find_output_consumers(select_ops[0].output[0])
-            else:
-                seq_len_idx = None

Related links:

hwangdeyu commented 2 years ago

Related issue: Keras LSTM converted to loop instead of ONNX LSTM op https://github.com/onnx/tensorflow-onnx/issues/1851

q-ycong-p commented 2 years ago

@hwangdeyu It is converted to a loop in this case because the SelectV2 in the frozen-graph pattern is not handled by lstm_tf2_rewriter. The SelectV2 itself can be skipped to produce an LSTM op (see the attempted fix above). The key issue, however, is that the SelectV2 is there because of masking, and the ONNX LSTM op does not support masking. If it is skipped, inference will be incorrect for inputs that contain zeros. So getting tf2onnx to emit an LSTM op while also preserving correct masking behavior probably takes more work than simply turning the Loop op into an LSTM op.
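A rough illustration of why dropping the SelectV2 changes results: in the masked recurrence each step computes a candidate state, and a select of the form where(mask_t, new_state, old_state) keeps the previous state at masked steps. The sketch uses a hypothetical scalar cell (step, h_t = 0.5 * h_{t-1} + x_t) in place of a real LSTM cell and assumes masked steps repeat the carried state as their output; it only shows the effect of keeping versus ignoring the mask, not the actual tf2onnx graph.

```python
from typing import List, Sequence


def step(h: float, x: float) -> float:
    # Hypothetical toy cell standing in for an LSTM cell: h_t = 0.5 * h_{t-1} + x_t
    return 0.5 * h + x


def run_masked(xs: Sequence[float], mask: Sequence[bool]) -> List[float]:
    """Keep the mask: a masked step carries the previous state forward."""
    h = 0.0
    outputs = []
    for x, valid in zip(xs, mask):
        new_h = step(h, x)
        # Same role as the Select/SelectV2: where(mask_t, new_state, old_state)
        h = new_h if valid else h
        # Assumption: masked steps repeat the carried state as their output
        outputs.append(h)
    return outputs


def run_unmasked(xs: Sequence[float]) -> List[float]:
    """Drop the select: every step, including padding, updates the state."""
    h = 0.0
    outputs = []
    for x in xs:
        h = step(h, x)
        outputs.append(h)
    return outputs


if __name__ == "__main__":
    xs = [3.0, 1.0, 0.0, 0.0]  # toy inputs; 0.0 plays the padding token
    mask = [x != 0 for x in xs]
    print(run_masked(xs, mask))  # [3.0, 2.5, 2.5, 2.5]
    print(run_unmasked(xs))      # [3.0, 2.5, 1.25, 0.625]
```

The two runs agree until the first padded step and then diverge, which matches the symptom of getting an LSTM op but wrong results once the select is skipped.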

As far as I understand, masking is applied inside the recurrent op itself. In this post, cjermain proposed a solution for padding (not general masking), which is already implemented in keras2onnx. Could tf2onnx adopt it as well?

q-ycong-p commented 2 years ago

Hi @hwangdeyu and tf2onnx contributors,

keras2onnx has already implemented this feature. I tested it with the minimal Keras model shared above, and the keras2onnx implementation gives correct inference results on post-padded inputs. It does so by pre-processing the inputs before the LSTM.

Can the same be implemented here to support masking for post-padded inputs?
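For the post-padding case, the idea discussed here comes down to turning the mask into per-example lengths, which is what the optional sequence_lens input of the ONNX LSTM op (int32, shape [batch_size]) expects. The sketch below assumes pad id 0 and post-padding; is_post_padded and lengths_from_mask are hypothetical helpers, not keras2onnx code. Inside a graph the same count could presumably be built from Equal/Not/Cast/ReduceSum ops, which is not shown here.

```python
from typing import List, Sequence


def is_post_padded(row: Sequence[int], pad_id: int = 0) -> bool:
    """True if no real token follows a pad token, e.g. [3, 1, 0, 0]."""
    seen_pad = False
    for tok in row:
        if tok == pad_id:
            seen_pad = True
        elif seen_pad:
            # a real token after padding: interior or leading zeros
            return False
    return True


def lengths_from_mask(batch: Sequence[Sequence[int]], pad_id: int = 0) -> List[int]:
    """Per-example valid length, in the [batch_size] layout of ONNX LSTM's sequence_lens."""
    lengths = []
    for row in batch:
        if not is_post_padded(row, pad_id):
            raise ValueError(f"{list(row)} is not post-padded; a single length cannot encode its mask")
        # For post-padded rows, the count of non-pad tokens equals the valid length.
        lengths.append(sum(1 for tok in row if tok != pad_id))
    return lengths


if __name__ == "__main__":
    print(lengths_from_mask([[3, 1, 0, 0], [2, 4, 1, 5]]))  # [2, 4]
```

Rejecting rows that are not post-padded is deliberate: a length can only describe a mask whose False entries are all at the end.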

hwangdeyu commented 2 years ago

Thank you so much for the detailed issue. I am not very familiar with this part. I guess keras2onnx is a good example of how to fix it, but it needs more investigation. Is any model blocked by this issue?

q-ycong-p commented 2 years ago

Hi, thank you for following up. Yes, we have an older model trained with TF 2.4 that converts correctly with keras2onnx. With keras2onnx deprecated and TF upgraded to 2.8, we need to convert the model correctly with tf2onnx. The production model is hard to share, but the minimal example above should reproduce the issue. Will the community be able to take a look soon?

q-ycong-p commented 2 years ago

Hi @hwangdeyu and tf2onnx community, is there any plan to pick up this work soon?

buddhapuneeth commented 2 years ago

@hwangdeyu we have an urgent request for this fix (07/01). Please let us know if you are able to pick this up soon.

hwangdeyu commented 2 years ago

Yes, sorry for the late reply. We will plan a fix soon.

q-ycong-p commented 2 years ago

Hi @hwangdeyu, could you give us an update? We'll need to work around the urgent request (07/01) depending on your plan for this fix. Please let us know.

hwangdeyu commented 2 years ago

We were offsite for several days. We are working on this issue, but progress has been limited so far. I'm not sure it can be solved before July 1st.

q-ycong-p commented 2 years ago

@hwangdeyu Thanks for the update. Could you provide us with an ETA?

q-ycong-p commented 2 years ago

@hwangdeyu Given the request on our end, we need to work towards a fix too. Could you give us some pointers or direction so that we can contribute and collaborate?

hwangdeyu commented 2 years ago

Just as https://github.com/onnx/onnx/issues/2248#issuecomment-587988420 says, the first type has been implemented in keras2onnx. The second type may be better, but I don't think ONNX will change the LSTM op, because recent models are less likely to use it. I can't think of a better way so far, so my plan is to sync with my colleague and try the same approach as keras2onnx to support masking for post-padded inputs. Thanks, and I will ping you if I find anything you could contribute. By the way, it would be better to have a simple repro to verify that our change gives correct inference results. Actually, there is a Keras CI test showing that a model using mask_zero=True can be converted into ONNX and gives the same inference results as Keras.
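The limitation of the padding-based (first-type) approach can be checked with a toy model: processing only the first sequence_lens steps gives the same final state as masking when all zeros are trailing, but not when a 0 appears mid-sequence. The scalar cell below is a hypothetical stand-in for an LSTM cell, used only to compare the two strategies.

```python
from typing import Sequence


def step(h: float, x: float) -> float:
    # Hypothetical toy cell (stand-in for an LSTM cell): h_t = 0.5 * h_{t-1} + x_t
    return 0.5 * h + x


def final_state_masked(xs: Sequence[float], pad_id: float = 0) -> float:
    """Mask semantics: steps whose input equals pad_id keep the previous state."""
    h = 0.0
    for x in xs:
        if x != pad_id:
            h = step(h, x)
    return h


def final_state_with_length(xs: Sequence[float], length: int) -> float:
    """sequence_lens semantics: only the first `length` steps are processed."""
    h = 0.0
    for x in xs[:length]:
        h = step(h, x)
    return h


if __name__ == "__main__":
    post_padded = [3, 1, 0, 0]
    gap = [3, 0, 1, 0]
    print(final_state_masked(post_padded), final_state_with_length(post_padded, 2))  # 2.5 2.5
    print(final_state_masked(gap), final_state_with_length(gap, 2))                  # 2.5 1.5
```

For the row with a gap, no single length reproduces the masked result (length 3 gives 1.75), so a length-based rewrite is only correct for post-padded inputs.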

q-ycong-p commented 2 years ago

Thanks for the pointer. Regarding the reproduction, I think the minimal example at the top of the post is simple enough to verify a WIP fix. We can assert both that the inference result equals the TF result and that the LSTM is converted to an LSTM op. It seems the linked test only verifies that the inference results match, but not that an LSTM op is produced (instead of a loop). Maybe I misunderstood your request?
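One way to write the extra structural assertion is to inspect the op types of the converted graph's top-level nodes. This sketch assumes an onnx ModelProto, whose graph.node entries carry an op_type field; op_type_counts and assert_lstm_not_loop are hypothetical helpers, and comparing numerical outputs against Keras (for example with onnxruntime) would be a separate check.

```python
from collections import Counter


def op_type_counts(model) -> Counter:
    """Count op types among the top-level nodes of an ONNX ModelProto
    (works on any object exposing model.graph.node[i].op_type)."""
    return Counter(node.op_type for node in model.graph.node)


def assert_lstm_not_loop(model) -> None:
    counts = op_type_counts(model)
    # Counter returns 0 for missing keys, so absent ops need no special case.
    assert counts["LSTM"] >= 1, f"no LSTM op found: {dict(counts)}"
    assert counts["Loop"] == 0, f"LSTM was lowered to a Loop: {dict(counts)}"
```

With the repro at the top of the issue it would be called as assert_lstm_not_loop(onnx_model) after tf2onnx.convert.from_keras(model); given the behavior reported here, that call is expected to fail until the rewriter handles the masked pattern.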

hwangdeyu commented 2 years ago

Hi @q-ycong-p, I have synced with my colleague and done more tests. As the CI tests and your example show, the inference results of the originally converted model (with the Loop op) match the Keras inference results with mask_zero=True, regardless of whether the input contains 0. It is also hard to support masking with the ONNX LSTM op, so I'm sorry, but we won't spend more effort on it. I hope you understand.

q-ycong-p commented 2 years ago

Hi @hwangdeyu, thanks for letting me know. Even if the inference result is correct, converting the LSTM into a Loop op means important optimizations cannot be taken advantage of at inference time. I'm trying to work towards a fix for the post-padded masking scenario, using the example that @cjermain from keras2onnx provided. I will keep this thread posted and might ask for help if I get stuck.

AndreyOrb commented 8 months ago

Hello, is there any update or ETA on this issue?