PINTO0309 / onnx2tf

Self-Created Tools to convert ONNX files (NCHW) to TensorFlow/TFLite/Keras format (NHWC). The purpose of this tool is to solve the massive Transpose extrapolation problem in onnx-tensorflow (onnx-tf). I don't need a Star, but give me a pull request.
MIT License
706 stars 73 forks

Error by broadcasting in tf.math.Multiply Operation #698

Closed JihwanEom closed 1 month ago

JihwanEom commented 1 month ago

Issue Type

Others

OS

Linux

onnx2tf version number

1.24.1

onnx version number

1.16.1

onnxruntime version number

1.18.1

onnxsim (onnx_simplifier) version number

0.4.33

tensorflow version number

2.16.1

Download URL for ONNX

https://drive.google.com/file/d/1uEzWl53YZVdkWntVouc8Nyv9iWvNVGMw/view?usp=sharing

Parameter Replacement JSON

N/A

Description

Hello,

I'm experiencing an error while trying to convert a model to TFLite using the command below:

Command: onnx2tf -i 1005_s0_nonar_text_decoder.onnx -o 1005_s0_nonar_text_decoder -odrqt --disable_group_convolution --replace_to_pseudo_operators Erf --replace_argmax_to_reducemax_and_indices_is_int64

INFO: 16 / 223
INFO: onnx_op_type: Add onnx_op_name: /transformer.0/Add
INFO:  input_name.1: /transformer.0/token_mixer/reparam_conv/Conv_output_0 shape: [1, 512, 1, 81] dtype: float32
INFO:  input_name.2: /transformer.0/Mul_output_0 shape: [1, 512, 1, 81] dtype: float32
INFO:  output_name.1: /transformer.0/Add_output_0 shape: [1, 512, 1, 81] dtype: float32
INFO: tf_op_type: add
INFO:  input.1.x: name: tf.math.add_1/Add:0 shape: (1, 1, 81, 512) dtype: <dtype: 'float32'> 
INFO:  input.2.y: name: tf.math.multiply_16/Mul:0 shape: (1, 512, 81, 512) dtype: <dtype: 'float32'> 
INFO:  output.1.output: name: tf.math.add_7/Add:0 shape: (1, 512, 81, 512) dtype: <dtype: 'float32'> 

INFO: 17 / 223
INFO: onnx_op_type: Squeeze onnx_op_name: /transformer.0/Squeeze
INFO:  input_name.1: /transformer.0/Add_output_0 shape: [1, 512, 1, 81] dtype: float32
INFO:  input_name.2: /transformer.0/Constant_1_output_0 shape: (1,) dtype: int64
INFO:  output_name.1: /transformer.0/Squeeze_output_0 shape: [1, 512, 81] dtype: float32
ERROR: The trace log is below.
Traceback (most recent call last):
  File "/root/workspace/sharing/ws-jihwan/mobileclip_ws/onnx2tf/onnx2tf/utils/common_functions.py", line 312, in print_wrapper_func
    result = func(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^
  File "/root/workspace/sharing/ws-jihwan/mobileclip_ws/onnx2tf/onnx2tf/utils/common_functions.py", line 385, in inverted_operation_enable_disable_wrapper_func
    result = func(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^
  File "/root/workspace/sharing/ws-jihwan/mobileclip_ws/onnx2tf/onnx2tf/utils/common_functions.py", line 55, in get_replacement_parameter_wrapper_func
    func(*args, **kwargs)
  File "/root/workspace/sharing/ws-jihwan/mobileclip_ws/onnx2tf/onnx2tf/ops/Squeeze.py", line 192, in make_node
    tf.squeeze(
  File "/root/workspace/sharing/ws-jihwan/miniforge3/envs/onnx2tf/lib/python3.12/site-packages/tensorflow/python/ops/weak_tensor_ops.py", line 88, in wrapper
    return op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/root/workspace/sharing/ws-jihwan/miniforge3/envs/onnx2tf/lib/python3.12/site-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/root/workspace/sharing/ws-jihwan/miniforge3/envs/onnx2tf/lib/python3.12/site-packages/tf_keras/src/layers/core/tf_op_layer.py", line 119, in handle
    return TFOpLambda(op)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/workspace/sharing/ws-jihwan/miniforge3/envs/onnx2tf/lib/python3.12/site-packages/tf_keras/src/utils/traceback_utils.py", line 72, in error_handler
    del filtered_tb
        ^^^^^^^^^^^
ValueError: Exception encountered when calling layer "tf.squeeze_1" (type TFOpLambda).

Can not squeeze dim[1], expected a dimension of 1, got 512 for '{{node tf.squeeze_1/Squeeze}} = Squeeze[T=DT_FLOAT, squeeze_dims=[1]](Placeholder)' with input shapes: [1,512,81,512].

Call arguments received by layer "tf.squeeze_1" (type TFOpLambda):
  • input=tf.Tensor(shape=(1, 512, 81, 512), dtype=float32)
  • axis=['1']
  • name='/transformer.0/Squeeze'

ERROR: input_onnx_file_path: 1005_s0_nonar_text_decoder.onnx
ERROR: onnx_op_name: /transformer.0/Squeeze
ERROR: Read this and deal with it. https://github.com/PINTO0309/onnx2tf#parameter-replacement
ERROR: Alternatively, if the input OP has a dynamic dimension, use the -b or -ois option to rewrite it to a static shape and try again.
ERROR: If the input OP of ONNX before conversion is NHWC or an irregular channel arrangement other than NCHW, use the -kt or -kat option.
ERROR: Also, for models that include NonMaxSuppression in the post-processing, try the -onwdt option.

The input shape of my model is [1, 1, 512] (Batch, 1, hidden_dim), and the output shape is [Batch, 1+N, vocab_size].

Upon reviewing the following logs:

INFO: 16 / 223
INFO: onnx_op_type: Add onnx_op_name: /transformer.0/Add
INFO:  input_name.1: /transformer.0/token_mixer/reparam_conv/Conv_output_0 shape: [1, 512, 1, 81] dtype: float32
INFO:  input_name.2: /transformer.0/Mul_output_0 shape: [1, 512, 1, 81] dtype: float32
INFO:  output_name.1: /transformer.0/Add_output_0 shape: [1, 512, 1, 81] dtype: float32
INFO: tf_op_type: add
INFO:  input.1.x: name: tf.math.add_1/Add:0 shape: (1, 1, 81, 512) dtype: <dtype: 'float32'> 
INFO:  input.2.y: name: tf.math.multiply_16/Mul:0 shape: (1, 512, 81, 512) dtype: <dtype: 'float32'> 
INFO:  output.1.output: name: tf.math.add_7/Add:0 shape: (1, 512, 81, 512) dtype: <dtype: 'float32'> 

It seems that, although ONNX inputs 1 and 2 both correctly have shape [1, 512, 1, 81], input 2 of the TF op is unexpectedly broadcast to (1, 512, 81, 512) instead of the expected (1, 1, 81, 512).
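The shape arithmetic behind this observation can be reproduced with NumPy alone, which follows the same broadcasting rules as TensorFlow. The following is only a sketch: the helper `to_nhwc_shape` and the "untransposed per-channel constant" scenario are illustrative assumptions, not something the log confirms about this model or about onnx2tf internals.

```python
import numpy as np

# Axis order that turns an NCHW shape into NHWC.
NCHW_TO_NHWC = (0, 2, 3, 1)

def to_nhwc_shape(nchw_shape):
    """Hypothetical helper: reorder an NCHW shape tuple into NHWC."""
    return tuple(nchw_shape[i] for i in NCHW_TO_NHWC)

# Both ONNX inputs of /transformer.0/Add are [1, 512, 1, 81].
onnx_shape = (1, 512, 1, 81)
expected_tf_shape = to_nhwc_shape(onnx_shape)

# Shapes the log reports for the two inputs of the TF add.
x_shape = (1, 1, 81, 512)
y_shape = (1, 512, 81, 512)
add_shape = np.broadcast_shapes(x_shape, y_shape)

# Assumption: one way y could end up this large is a per-channel
# constant left in NCHW form (512, 1, 1) and multiplied with an NHWC tensor.
suspect_shape = np.broadcast_shapes((512, 1, 1), expected_tf_shape)

# Zero-stride views, so no large buffers are allocated.
good = np.broadcast_to(np.float32(0), expected_tf_shape)
bad = np.broadcast_to(np.float32(0), add_shape)

# ONNX squeezes axis 2 (H) in NCHW, which is axis 1 in NHWC.
squeezed_ok_shape = np.squeeze(good, axis=1).shape
try:
    np.squeeze(bad, axis=1)
    squeeze_failed = False
except ValueError:
    # Axis 1 has size 512, mirroring the tf.squeeze error in the trace.
    squeeze_failed = True

print(expected_tf_shape, add_shape, suspect_shape)
print(squeezed_ok_shape, squeeze_failed)
```

The add itself does not fail, because (1, 1, 81, 512) broadcasts silently against (1, 512, 81, 512). The error only surfaces at the following Squeeze, which is why the traceback points at the Squeeze op rather than at the Mul or Add that produced the oversized tensor.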

Could you please provide some advice on how to resolve this? Thank you!

PINTO0309 commented 1 month ago
JihwanEom commented 1 month ago

Thank you so much for the incredibly quick review and fix! I've confirmed that the model now works perfectly with version 1.25.15 (I manually applied the changes from the PR). I really appreciate you taking the time to look into this, especially late on a weekend night. I'm so happy it's working well now. Wishing you a wonderful weekend!

JihwanEom commented 1 month ago

Dear @PINTO0309, sorry to bother you.

Image encoder ONNX checkpoint: link

When I tried converting the ONNX file after manually updating the codebase, I encountered the following error.

Command: onnx2tf -i 1006_s0_nonar_image_encoder_latest.onnx

INFO: 287 / 346
INFO: onnx_op_type: Reshape onnx_op_name: /model/network.7/network.7.0/token_mixer/Reshape_3
INFO:  input_name.1: /model/network.7/network.7.0/token_mixer/Transpose_4_output_0 shape: [1, 512, 64] dtype: float32
INFO:  input_name.2: /model/network.7/network.7.0/token_mixer/Constant_11_output_0 shape: [4] dtype: int64
INFO:  output_name.1: /model/network.7/network.7.0/token_mixer/Reshape_3_output_0 shape: [1, 512, 8, 8] dtype: float32
INFO: tf_op_type: reshape
INFO:  input.1.tensor: name: tf.compat.v1.transpose_44/transpose:0 shape: (1, 512, 64) dtype: <dtype: 'float32'> 
INFO:  input.2.shape: val: [1, 512, 8, 8] 
INFO:  output.1.output: name: tf.reshape_11/Reshape:0 shape: (1, 512, 8, 8) dtype: <dtype: 'float32'> 

INFO: 288 / 346
INFO: onnx_op_type: Mul onnx_op_name: /model/network.7/network.7.0/Mul
INFO:  input_name.1: model.network.7.0.layer_scale_1 shape: [512, 1, 1] dtype: float32
INFO:  input_name.2: /model/network.7/network.7.0/token_mixer/Reshape_3_output_0 shape: [1, 512, 8, 8] dtype: float32
INFO:  output_name.1: /model/network.7/network.7.0/Mul_output_0 shape: [1, 512, 8, 8] dtype: float32
ERROR: The trace log is below.
Traceback (most recent call last):
  File "/root/workspace/sharing/ws-jihwan/mobileclip_ws/onnx2tf/onnx2tf/utils/common_functions.py", line 312, in print_wrapper_func
    result = func(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^
  File "/root/workspace/sharing/ws-jihwan/mobileclip_ws/onnx2tf/onnx2tf/utils/common_functions.py", line 385, in inverted_operation_enable_disable_wrapper_func
    result = func(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^
  File "/root/workspace/sharing/ws-jihwan/mobileclip_ws/onnx2tf/onnx2tf/utils/common_functions.py", line 55, in get_replacement_parameter_wrapper_func
    func(*args, **kwargs)
  File "/root/workspace/sharing/ws-jihwan/mobileclip_ws/onnx2tf/onnx2tf/ops/Mul.py", line 244, in make_node
    correction_process_for_accuracy_errors(
  File "/root/workspace/sharing/ws-jihwan/mobileclip_ws/onnx2tf/onnx2tf/utils/common_functions.py", line 5853, in correction_process_for_accuracy_errors
    dummy_op = tf_func(
               ^^^^^^^^
  File "/root/workspace/sharing/ws-jihwan/miniforge3/envs/onnx2tf/lib/python3.12/site-packages/tensorflow/python/ops/weak_tensor_ops.py", line 142, in wrapper
    return op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/root/workspace/sharing/ws-jihwan/miniforge3/envs/onnx2tf/lib/python3.12/site-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/root/workspace/sharing/ws-jihwan/miniforge3/envs/onnx2tf/lib/python3.12/site-packages/tf_keras/src/layers/core/tf_op_layer.py", line 119, in handle
    return TFOpLambda(op)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/workspace/sharing/ws-jihwan/miniforge3/envs/onnx2tf/lib/python3.12/site-packages/tf_keras/src/utils/traceback_utils.py", line 72, in error_handler
    del filtered_tb
        ^^^^^^^^^^^
ValueError: Exception encountered when calling layer "tf.math.multiply_518" (type TFOpLambda).

Dimensions must be equal, but are 8 and 512 for '{{node tf.math.multiply_518/Mul}} = Mul[T=DT_FLOAT](Placeholder, tf.math.multiply_518/Mul/y)' with input shapes: [1,512,8,8], [1,1,1,512].

Call arguments received by layer "tf.math.multiply_518" (type TFOpLambda):
  • x=tf.Tensor(shape=(1, 512, 8, 8), dtype=float32)
  • y=array([[[[ 0.05679564,  0.12148159,  0.11233755,  0.04193629,
          -0.07566667, -0.10907646,  0.04781451,  0.07481429,
          -0.05444366,  0.07370226, -0.07062352, -0.09698167,
          -0.03944191, -0.07290409, -0.03765032, -0.08496998,
          -0.06623586, -0.03894256, -0.05981188,  0.07432929,
          -0.09311683, -0.09000562,  0.04840877,  0.08937823,
          -0.08537138,  0.04235391,  0.10113057,  0.0898895 ,
          -0.0335672 , -0.06119232,  0.04753408,  0.10236507,
          -0.08871685, -0.08038578, -0.08869054, -0.03617603,
           0.02431735, -0.0490169 ,  0.04268952,  0.06960726,
          -0.04321573, -0.06252784,  0.08141872, -0.06381831,
          -0.07017564, -0.04484068, -0.03707905, -0.05490652,
           0.06499621, -0.05915944,  0.11528529,  0.15309212,
           0.08457021, -0.05511831, -0.12379785,  0.04732216,
          -0.07085859, -0.11543021,  0.04016964,  0.05205216,
           0.07390954, -0.05206259, -0.06444984,  0.06366117,
           0.03512792, -0.06255513,  0.04096276,  0.06271195,
          -0.04709042, -0.07012043, -0.04555522, -0.0655094 ,
          -0.17915884, -0.0417478 ,  0.06166015,  0.04927208,
           0.04682605, -0.07510673, -0.05002141,  0.08131753,
          -0.07468081,  0.04881779, -0.17128612,  0.07790399,
           0.08647819, -0.10937636,  0.04686357,  0.06522253,
           0.07999425, -0.06203929,  0.10498064,  0.09189045,
           0.1565074 , -0.08703004,  0.04915918,  0.03403083,
          -0.11606346,  0.07693201,  0.04746893, -0.06833348,
          -0.03602915, -0.05741571, -0.05581505,  0.04001234,
           0.05873691, -0.03468559,  0.16453297,  0.04585794,
           0.05992495, -0.05190395,  0.06238254,  0.09447399,
           0.09370488, -0.06572337,  0.06037894,  0.10510889,
          -0.10933234,  0.05334195, -0.06172893,  0.07030273,
           0.06085391,  0.07938613, -0.03824306,  0.08946471,
           0.03965388,  0.06468146, -0.04696504,  0.05607357,
          -0.03315815,  0.1041488 ,  0.06035807, -0.07704061,
          -0.05392484,  0.09373773,  0.05237664, -0.06281945,
           0.07994758, -0.09561491, -0.04988258,  0.06231315,
          -0.07395851,  0.05785489, -0.05890709,  0.04903151,
           0.08003257, -0.09149866, -0.06473091,  0.14546002,
           0.11884373,  0.0626166 , -0.04222129,  0.10674977,
           0.06318466,  0.05161892,  0.05956442,  0.05220222,
           0.05465273, -0.02554615, -0.07085277,  0.08357469,
           0.05540628,  0.08505432,  0.04465118,  0.07918932,
          -0.12436352, -0.07674525,  0.08206783, -0.04604501,
           0.06561015, -0.10502517,  0.0595285 , -0.05393353,
           0.09417555, -0.07605029, -0.06903961, -0.0723069 ,
           0.0879845 ,  0.09977774, -0.10501694, -0.10962442,
           0.13180113, -0.0409829 , -0.03295548, -0.10995737,
          -0.14692496, -0.0827745 , -0.05973983,  0.09112202,
           0.07569475,  0.06859953, -0.05756547, -0.05540706,
          -0.06927562,  0.1317709 ,  0.04968894, -0.05015145,
           0.0513989 , -0.08164629,  0.0650223 ,  0.04716168,
           0.04792733,  0.05683443,  0.12517905,  0.09454681,
           0.0997954 ,  0.06717522, -0.05871446, -0.13058093,
           0.06684358,  0.06383131, -0.05130139, -0.07120853,
          -0.06592735,  0.05907459, -0.10004366, -0.11561153,
          -0.04857404,  0.15112498,  0.06738169, -0.11682976,
          -0.04895256, -0.05914875,  0.05559698,  0.22043861,
           0.07407195, -0.05302287,  0.10294298, -0.05959857,
           0.05060734,  0.06141803,  0.08949327, -0.05537126,
           0.04752927,  0.07410918, -0.12302282,  0.04733361,
           0.09135745,  0.10776824, -0.06696495, -0.08819423,
           0.05139291,  0.06776755,  0.06723671,  0.04167761,
           0.06528471, -0.06077628,  0.04481075, -0.06125391,
           0.03678311, -0.14232036,  0.06599596,  0.11149105,
          -0.10552928, -0.03707049, -0.04169705, -0.06365078,
           0.05001835, -0.10002147,  0.08229017, -0.07867017,
           0.1923649 , -0.09215724, -0.07901167, -0.13053909,
          -0.05080564,  0.05029891,  0.03528924,  0.06039648,
           0.09389147,  0.05717449, -0.05598213, -0.04926844,
           0.06236379,  0.04442395,  0.03006951,  0.05625453,
           0.05458543,  0.05971921, -0.06168054,  0.06774168,
          -0.04845234,  0.09556087, -0.0449145 ,  0.05125907,
           0.05363031,  0.09990506, -0.07019629, -0.05089007,
          -0.0426661 ,  0.03862907, -0.08243524, -0.06333328,
           0.06084452, -0.04175306,  0.05748864, -0.051601  ,
          -0.05983461, -0.0728577 ,  0.05551457, -0.03997757,
          -0.04979417,  0.03399378, -0.03364699,  0.04361123,
          -0.10430909,  0.07269166,  0.05340623,  0.05952214,
          -0.05177347, -0.05376231,  0.04586364, -0.07611496,
           0.02836897,  0.06754087, -0.04268967, -0.08099882,
           0.07238659, -0.03608365, -0.07900336, -0.04369306,
          -0.04884762,  0.05886111,  0.06094831, -0.05937698,
           0.06305402,  0.11477335,  0.10087421, -0.04940838,
          -0.05472771, -0.07860502, -0.06711181,  0.16069432,
           0.07015684, -0.03852618,  0.07886147, -0.05660112,
          -0.07531713,  0.05990539, -0.04459869,  0.0355062 ,
           0.05661491,  0.04969598,  0.04069342, -0.10449423,
          -0.0916004 ,  0.06401003, -0.05788952,  0.07201333,
           0.04136726, -0.07827135,  0.09534591,  0.06967531,
          -0.06156003, -0.06037319,  0.03470572,  0.07161678,
          -0.05398342,  0.06195046,  0.06441947,  0.06872239,
           0.04744927,  0.0549283 ,  0.07251038, -0.07547082,
           0.08767079,  0.03477118,  0.03224018,  0.032039  ,
          -0.05785936,  0.06199856, -0.07691963,  0.04357189,
           0.10109425, -0.08756435, -0.03981568,  0.07535966,
           0.06242423, -0.05595406, -0.07278702, -0.06807476,
          -0.06567898, -0.09367836, -0.09123997, -0.06118471,
           0.04156755,  0.05917402, -0.04296285,  0.04239249,
          -0.0617152 , -0.07752201,  0.06640301,  0.06742044,
           0.04558919,  0.06062859,  0.03683025, -0.03288598,
          -0.08338501,  0.07760349,  0.08199436,  0.05868538,
          -0.05750139, -0.10123982,  0.05516977,  0.06702373,
          -0.09002191,  0.09233421,  0.04301672,  0.07232062,
           0.06421718,  0.05352986,  0.08339141,  0.1152419 ,
           0.06182766, -0.07089094, -0.05114551, -0.12091893,
           0.15590471,  0.105652  ,  0.07612755, -0.11623123,
           0.0777195 ,  0.04730468,  0.0705837 ,  0.06228766,
          -0.03787732, -0.04884365, -0.06356364, -0.06206393,
           0.04411897, -0.06202469,  0.09414049,  0.13610202,
           0.14054692,  0.11828613, -0.07636502, -0.0499926 ,
          -0.09782722, -0.07463919, -0.0605814 ,  0.06501891,
           0.13765937,  0.06051449, -0.0687281 , -0.07879251,
           0.06436344,  0.19246407, -0.05184782, -0.11339917,
          -0.14125231, -0.0582628 ,  0.08216344, -0.10715538,
          -0.06659676,  0.04777445, -0.05756154, -0.03885875,
           0.12389274, -0.04496153, -0.05596035, -0.04643614,
           0.05494179,  0.05963946,  0.05108174,  0.06413967,
           0.06681975,  0.04706677,  0.05723342, -0.05675968,
          -0.04154629,  0.07581214,  0.0605702 , -0.08025458,
           0.05749731,  0.04922277,  0.05881116, -0.07751459,
          -0.05883431, -0.06126503, -0.04327364,  0.06899264,
           0.07009817, -0.10546891,  0.07354705, -0.04207575,
          -0.08582591,  0.05451256,  0.07208341,  0.14498542,
          -0.07629512,  0.10985703, -0.06572801,  0.07201846,
           0.05537425,  0.07231497,  0.04635528,  0.08662535,
           0.11863436, -0.16265734,  0.0492879 ,  0.07581893,
          -0.09244329,  0.09061304, -0.07065758,  0.02859462,
          -0.04254417, -0.06865087, -0.14534754, -0.03896706,
          -0.03401543, -0.15456069,  0.05631099,  0.14285778]]]],
      dtype=float32)
  • name=None

ERROR: input_onnx_file_path: 1006_s0_nonar_image_encoder.onnx
ERROR: onnx_op_name: /model/network.7/network.7.0/Mul
ERROR: Read this and deal with it. https://github.com/PINTO0309/onnx2tf#parameter-replacement
ERROR: Alternatively, if the input OP has a dynamic dimension, use the -b or -ois option to rewrite it to a static shape and try again.
ERROR: If the input OP of ONNX before conversion is NHWC or an irregular channel arrangement other than NCHW, use the -kt or -kat option.
ERROR: Also, for models that include NonMaxSuppression in the post-processing, try the -onwdt option.
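
The shape mismatch in the error above can be checked with NumPy's broadcasting rules, which match TensorFlow's for this case. This is a minimal sketch: `broadcast_or_none` is a hypothetical helper, and the "consistent layout" shapes at the end are illustrative assumptions about what a working conversion would need, not a description of the actual fix.

```python
import numpy as np

def broadcast_or_none(*shapes):
    """Hypothetical helper: return the broadcast shape, or None if incompatible."""
    try:
        return np.broadcast_shapes(*shapes)
    except ValueError:
        return None

# ONNX side (NCHW): layer_scale [512, 1, 1] times a [1, 512, 8, 8] tensor.
onnx_out = broadcast_or_none((512, 1, 1), (1, 512, 8, 8))

# TF side, as reported in the error: the tensor is still laid out as
# (1, 512, 8, 8) while the constant has been moved to channels-last.
tf_out = broadcast_or_none((1, 512, 8, 8), (1, 1, 1, 512))

# Assumption: either consistent layout would broadcast cleanly.
nhwc_out = broadcast_or_none((1, 8, 8, 512), (1, 1, 1, 512))
nchw_out = broadcast_or_none((1, 512, 8, 8), (1, 512, 1, 1))

print(onnx_out, tf_out, nhwc_out, nchw_out)
```

Here the last axes are 8 and 512, neither of which is 1, so broadcasting is impossible. This matches the "Dimensions must be equal, but are 8 and 512" message in the trace.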

The conversion worked successfully when I reverted the changes. Could you please take a look at this?

Thank you so much in advance!

PINTO0309 commented 1 month ago

I'm sleepy today so I'll check it tomorrow.

PINTO0309 commented 1 month ago
PINTO0309 commented 1 month ago

Fix: https://github.com/PINTO0309/onnx2tf/releases/tag/1.26.0