onnx / tensorflow-onnx

Convert TensorFlow, Keras, Tensorflow.js and Tflite models to ONNX
Apache License 2.0

Understanding Einsum optimizations: Why are only the equations aecd,abcd→acbe and acbe,aecd→abcd decomposed? #1736

Closed. buddhapuneeth closed this issue 3 years ago.

buddhapuneeth commented 3 years ago

Describe the bug

I am trying to convert the TinyBert model to ONNX. With the latest tf2onnx, the converter decomposes the einsum operators with equations aecd,abcd→acbe and acbe,aecd→abcd into low-level ops such as reduce, concat, gather, and a few other op combinations.
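One way to see why aecd,abcd→acbe and acbe,aecd→abcd can be decomposed is that each is a batched matrix product over the shared letters a and c. The NumPy sketch below is only an illustration of that observation, not the op sequence tf2onnx emits (which, per this report, uses reduce, concat, gather and similar ops); the function names and shapes are made up for the example.

```python
import numpy as np


def einsum_aecd_abcd_acbe(x, y):
    # x: (a, e, c, d), y: (a, b, c, d) -> (a, c, b, e); contracts over d.
    y_t = y.transpose(0, 2, 1, 3)  # (a, c, b, d)
    x_t = x.transpose(0, 2, 3, 1)  # (a, c, d, e)
    return np.matmul(y_t, x_t)     # batched over (a, c) -> (a, c, b, e)


def einsum_acbe_aecd_abcd(p, v):
    # p: (a, c, b, e), v: (a, e, c, d) -> (a, b, c, d); contracts over e.
    v_t = v.transpose(0, 2, 1, 3)     # (a, c, e, d)
    out = np.matmul(p, v_t)           # batched over (a, c) -> (a, c, b, d)
    return out.transpose(0, 2, 1, 3)  # (a, b, c, d)


rng = np.random.default_rng(0)
a, b, c, d, e = 2, 3, 4, 5, 6
x = rng.standard_normal((a, e, c, d))
y = rng.standard_normal((a, b, c, d))
p = rng.standard_normal((a, c, b, e))
v = rng.standard_normal((a, e, c, d))

# Both rewrites should agree with np.einsum up to float rounding.
print(np.allclose(einsum_aecd_abcd_acbe(x, y), np.einsum("aecd,abcd->acbe", x, y)))
print(np.allclose(einsum_acbe_aecd_abcd(p, v), np.einsum("acbe,aecd->abcd", p, v)))
```

In a BERT-style self-attention layer these two equations most likely correspond to the query/key score product and the probabilities/value product.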

However, the einsum operators with equations abc,cde→abde and abcd,cde→abe are left as Einsum operators.
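The two equations left as Einsum are also plain matrix products once the contracted letters are flattened: abc,cde→abde contracts c, and abcd,cde→abe contracts c and d. The sketch below illustrates that with reshapes and a single 2-D matmul; it is not what tf2onnx generates, and the function names are hypothetical. In a BERT-style model the cde operand is most likely a trained weight, so it would be a constant in the ONNX graph.

```python
import numpy as np


def einsum_abc_cde_abde(x, w):
    # x: (a, b, c), w: (c, d, e) -> (a, b, d, e); contracts over c.
    a, b, c = x.shape
    _, d, e = w.shape
    out = x.reshape(a * b, c) @ w.reshape(c, d * e)  # (a*b, d*e)
    return out.reshape(a, b, d, e)


def einsum_abcd_cde_abe(x, w):
    # x: (a, b, c, d), w: (c, d, e) -> (a, b, e); contracts over c and d.
    a, b, c, d = x.shape
    e = w.shape[2]
    out = x.reshape(a * b, c * d) @ w.reshape(c * d, e)  # (a*b, e)
    return out.reshape(a, b, e)


rng = np.random.default_rng(0)
x3 = rng.standard_normal((2, 10, 8))    # e.g. (batch, seq, hidden)
w3 = rng.standard_normal((8, 4, 2))     # e.g. (hidden, heads, head_dim)
x4 = rng.standard_normal((2, 10, 4, 2)) # e.g. (batch, seq, heads, head_dim)
w4 = rng.standard_normal((4, 2, 8))     # e.g. (heads, head_dim, hidden)

print(np.allclose(einsum_abc_cde_abde(x3, w3), np.einsum("abc,cde->abde", x3, w3)))
print(np.allclose(einsum_abcd_cde_abe(x4, w4), np.einsum("abcd,cde->abe", x4, w4)))
```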

I tried to understand the einsum_optimizer code, but it was a bit overwhelming.

Is there a way I can decompose all the einsum operators, including abc,cde→abde and abcd,cde→abe?

More generally, which patterns are decomposed and which are not?

TomWildenhain-Microsoft commented 3 years ago

@xadupre

xadupre commented 3 years ago

These equations should work. This is what I got: decompose_einsum.zip. Did you see any error message in the log? If the einsum operator is not replaced because of a failure, the failure should appear in the log.

There is one case the optimizer does not handle: the diagonal sum ii->i (or any equation that repeats the same letter within one term). Currently, there is no way to replace that operation more efficiently than with einsum.
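The exclusion described above can be checked directly on the equation string. The helper below is hypothetical (it is not part of tf2onnx) and only mirrors the rule as stated in this comment: it flags any input term that repeats a letter, ignoring the ... ellipsis. NumPy is used only to show what ii->i computes.

```python
import numpy as np


def has_repeated_index(equation: str) -> bool:
    """Return True if any input term of an einsum equation repeats a letter,
    e.g. "ii->i". Hypothetical helper, not tf2onnx code."""
    lhs = equation.replace(" ", "").split("->")[0]
    for term in lhs.split(","):
        letters = term.replace("...", "")  # ellipsis is not a repeated index
        if len(set(letters)) != len(letters):
            return True
    return False


print(has_repeated_index("ii->i"))          # True
print(has_repeated_index("abc,cde->abde"))  # False

m = np.arange(9).reshape(3, 3)
print(np.einsum("ii->i", m))  # [0 4 8], the same values as np.diagonal(m)
```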

I used the following notebook to produce those ONNX models. I'll check with tf2onnx tomorrow.

buddhapuneeth commented 3 years ago

@xadupre Thanks for the response. So, ideally, all einsum ops except the diagonal sum should be replaced?

In my logs I didn't see any error, only this message: 2021-10-05 07:49:07,633 - INFO - replacing einsum node 'StatefulPartitionedCall/model/self_attention/einsum/Einsum' by its decomposed version, name of the last node 'Identity__206'.

xadupre commented 3 years ago

I used tf2onnx and it worked in both cases. You can run the unit test I made to create the TF model, and the command line to convert it (the result is attached as einsum.zip). The optimizer needs opset >= 13 to work. It does not work when some inputs are defined as attributes (such as Squeeze before opset 13). When it fails, you should see WARNING - Failed to apply einsum_optimizer in the log, followed by an exception such as RuntimeError: Opset (9) must be >= 11 for operator 'expand_dims'.
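To tell which case applies, the conversion log can be scanned for the two messages quoted in this thread: the INFO line reporting a replaced einsum node and the WARNING reporting an optimizer failure. The sketch below assumes the log has been saved as text; the regex is built only from the wording quoted here and may not match other tf2onnx versions. The helper name and sample text are illustrative.

```python
import re

# Message texts as quoted in this thread; wording may differ in other tf2onnx versions.
REPLACED_RE = re.compile(r"replacing einsum node '([^']+)' by its decomposed version")
FAILED_MSG = "Failed to apply einsum_optimizer"


def summarize_einsum_log(log_text: str):
    """Return (names of einsum nodes reported as decomposed, number of optimizer failures)."""
    replaced = REPLACED_RE.findall(log_text)
    failures = log_text.count(FAILED_MSG)
    return replaced, failures


# Sample text assembled from the two messages quoted in this thread.
sample_log = (
    "2021-10-05 07:49:07,633 - INFO - replacing einsum node "
    "'StatefulPartitionedCall/model/self_attention/einsum/Einsum' by its decomposed version, "
    "name of the last node 'Identity__206'.\n"
    "WARNING - Failed to apply einsum_optimizer\n"
)
replaced, failures = summarize_einsum_log(sample_log)
print(replaced, failures)
```

An Einsum node that still appears in the converted model but shows up in neither list was skipped silently rather than failing, which points to a precondition in the optimizer instead of an exception.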

buddhapuneeth commented 3 years ago

@xadupre I am not seeing the exception you mentioned in the logs. You can see the full logs here. I am trying to convert the open-source TinyBert model; you can download it here. The converted model still has 8 einsum ops with the patterns abc,cde→abde and abcd,cde→abe.

Command used: python -m tf2onnx.convert --saved-model tb --output onnx_conv/model.onnx --opset 13 --v

buddhapuneeth commented 3 years ago

@xadupre After a bit of debugging, I found that the remaining einsums are not decomposed because the node.inputs[1].is_const() check is True. What is the reason for this check?

buddhapuneeth commented 3 years ago

@xadupre @TomWildenhain-Microsoft I removed the node.inputs[1].is_const() check from einsum_optimizer.py locally and compared the results with and without the fix; the results match in my case. I tried to understand the reason for this check in the code and don't see any risk in removing it, but I might be wrong, so please correct me. I used the code below to compare:

# pip install onnxruntime
import numpy as np
import onnxruntime

np.random.seed(0)
# Random int32 inputs of shape (batch=2, seq_len=10), shared by both models.
feeds = {
    "input_mask": np.random.randint(10, size=(2, 10)).astype(np.int32),
    "input_type_ids": np.random.randint(10, size=(2, 10)).astype(np.int32),
    "input_word_ids": np.random.randint(10, size=(2, 10)).astype(np.int32),
}
output_names = ["default", "encoder_outputs", "pooled_output", "sequence_output", "Identity_4:0"]

# Model converted as-is (still contains einsum ops).
sess = onnxruntime.InferenceSession("with_out_fix.onnx", providers=["CPUExecutionProvider"])
result = sess.run(output_names, feeds)
# Model converted after removing the is_const() check.
sess1 = onnxruntime.InferenceSession("with_fix.onnx", providers=["CPUExecutionProvider"])
result1 = sess1.run(output_names, feeds)

print("################# comparisons #########")
for name, a, b in zip(output_names, result, result1):
    print(name, np.array_equal(np.asarray(a), np.asarray(b)))
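Exact equality is a strict test here: replacing Einsum with a chain of other ops can change the order of floating-point additions, so a correct conversion may still differ in the last bits. A tolerance-based comparison that also reports the largest difference is a safer check. The sketch below is one way to do it; the function name and the default rtol/atol values are assumptions, not values from this thread.

```python
import numpy as np


def compare_outputs(names, ref, new, rtol=1e-5, atol=1e-6):
    """Compare two lists of outputs (e.g. from InferenceSession.run), reporting
    exact equality, closeness within (rtol, atol), and the max absolute difference."""
    report = {}
    for name, a, b in zip(names, ref, new):
        a, b = np.asarray(a), np.asarray(b)
        if a.shape != b.shape:
            report[name] = {"equal": False, "close": False, "max_abs_diff": float("inf")}
        else:
            diff = np.abs(a.astype(np.float64) - b.astype(np.float64))
            report[name] = {
                "equal": bool(np.array_equal(a, b)),
                "close": bool(np.allclose(a, b, rtol=rtol, atol=atol)),
                "max_abs_diff": float(diff.max()) if diff.size else 0.0,
            }
        print(name, report[name])
    return report
```

Usage with the two runs above: compare_outputs(["default", "encoder_outputs", "pooled_output", "sequence_output", "Identity_4:0"], result, result1).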

I need this change as soon as possible, and any input on this would be really helpful. If you suggest an alternative fix, I am willing to make the change and contribute it back.

xadupre commented 3 years ago

That should not be an issue. I removed this line in #1739.