The price of onnx2tf's considerably higher model-optimization efficiency compared to onnx-tensorflow is that the optimization may fail when the input tensor has two or more undefined dimensions. If a series of tensors has axes of size None, the correct axis positions can be lost during model transformation.
Conversion of models with multiple undefined dimensions is supported in principle, but the probability of conversion failure is higher, and the user must compensate for conversion errors. JSON files can be used to correct onnx2tf's axis-transposition behavior.
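To see why multiple None axes are ambiguous, here is a contrived numpy sketch of mine (it only illustrates the idea; it is not onnx2tf's internal logic):
import numpy as np

# With concrete dims the NCHW -> NHWC perm is unambiguous:
x = np.zeros((1, 64, 28, 52), dtype=np.float32)
print(np.transpose(x, (0, 2, 3, 1)).shape)  # (1, 28, 52, 64) -- NHWC

# Symbolically, with H and W both None, the target shape is (1, None, None, 64).
# The wrong perm (0, 3, 2, 1) yields (1, W, H, 64), which is *also*
# (1, None, None, 64): shape inference alone cannot tell the two apart,
# so axis tracking can lock onto the wrong order.
print(np.transpose(x, (0, 3, 2, 1)).shape)  # (1, 52, 28, 64)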
In the case of RetinaFace, there was an error in the axis correction of Gather and Concat. When the correction was specified in JSON, the conversion succeeded, and inference ran without any problems, with variable axes.
If this does not make sense to you, I would not recommend tackling a model with such high conversion difficulty involving undefined dimensions.
When running the above, the following is the first shape issue I have:
INFO: onnx_output_name: wa/fpn/Shape_3_output_0 tf_output_name: tf.compat.v1.shape/wa/fpn/Shape_3:0 shape: (4,) dtype: int64 validate_result: Skipped (Deleted or Shape Unmatched)
Skipped (Deleted or Shape Unmatched) appears for all 1D OP outputs, so most of these can safely be ignored. If it appears for OP outputs with two or more dimensions, it suggests that an OP transposition has failed somewhere before that OP. In the case of your RetinaFace, I was getting many of these warning messages for all outputs above 2 dimensions immediately after Gather and Concat. The Gather and Concat operations are used to derive the tensor size for the Resize immediately following them. Thus, if onnx2tf misreads the Gather axis and the Concat axis, as in what you posted this time, an inconsistent tensor is generated by onnx2tf, resulting in an infeasible model. Wrong: [1,64,28,64]
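To make the failure mode concrete, here is a minimal numpy sketch (the dims are illustrative, not taken from the actual RetinaFace graph) of how Shape -> Gather -> Concat feeding a Resize goes wrong when NCHW indices are applied to an NHWC shape tensor:
import numpy as np

# ONNX (NCHW) feature map: N=1, C=64, H=28, W=28
onnx_shape = np.array([1, 64, 28, 28])
# After onnx2tf transposes the tensor to NHWC, Shape returns:
tf_shape = np.array([1, 28, 28, 64])

# ONNX gathers H and W at NCHW indices 2 and 3 to build the Resize size:
h, w = onnx_shape[2], onnx_shape[3]        # 28, 28 -> correct
# Keeping the NCHW indices on the NHWC shape gathers W and C instead:
bad_h, bad_w = tf_shape[2], tf_shape[3]    # 28, 64 -> wrong
print([1, 64, h, w])                       # [1, 64, 28, 28]
print([1, 64, bad_h, bad_w])               # [1, 64, 28, 64] -- the inconsistent size above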
https://github.com/PINTO0309/onnx2tf/releases/tag/1.18.2
pip install onnx2tf -U
wget https://github.com/PINTO0309/onnx2tf/releases/download/1.16.31/flatc.tar.gz \
&& tar -zxvf flatc.tar.gz \
&& sudo chmod +x flatc \
&& sudo mv flatc /usr/bin/
replace_retinaface_dynamic.json
{
  "format_version": 1,
  "operations": [
    {
      "op_name": "/fpn/Gather",
      "param_target": "inputs",
      "param_name": "/fpn/Constant_output_0",
      "values": 1
    },
    {
      "op_name": "/fpn/Gather_1",
      "param_target": "inputs",
      "param_name": "/fpn/Constant_1_output_0",
      "values": 2
    },
    {
      "op_name": "/fpn/Concat_1",
      "param_target": "outputs",
      "param_name": "/fpn/Concat_1_output_0",
      "post_process_transpose_perm": [0, 2, 3, 1]
    },
    {
      "op_name": "/fpn/Gather_2",
      "param_target": "inputs",
      "param_name": "/fpn/Constant_output_0",
      "values": 1
    },
    {
      "op_name": "/fpn/Gather_3",
      "param_target": "inputs",
      "param_name": "/fpn/Constant_1_output_0",
      "values": 2
    },
    {
      "op_name": "/fpn/Concat_3",
      "param_target": "outputs",
      "param_name": "/fpn/Concat_3_output_0",
      "post_process_transpose_perm": [0, 2, 3, 1]
    }
  ]
}
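As a quick sanity check before converting (a small helper of mine, not part of onnx2tf), you can verify that every op_name in the JSON actually exists in the ONNX graph:
import json
import onnx

model = onnx.load("retinaface_onnx_dynamic.onnx")
node_names = {n.name for n in model.graph.node}

with open("replace_retinaface_dynamic.json") as f:
    replacements = json.load(f)

for op in replacements["operations"]:
    assert op["op_name"] in node_names, f"unknown op: {op['op_name']}"
print("all replacement targets found")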
convert
onnx2tf \
-i retinaface_onnx_dynamic.onnx \
-prf replace_retinaface_dynamic.json \
-osd \
-coion \
-cotof
tensorflowjs_converter \
--input_format tf_saved_model \
--output_format tfjs_graph_model \
saved_model \
tfjs_model
converted models
saved_model_cli show --dir saved_model/ --all
MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
signature_def['__saved_model_init_op']:
The given SavedModel SignatureDef contains the following input(s):
The given SavedModel SignatureDef contains the following output(s):
outputs['__saved_model_init_op'] tensor_info:
dtype: DT_INVALID
shape: unknown_rank
name: NoOp
Method name is:
signature_def['serving_default']:
The given SavedModel SignatureDef contains the following input(s):
inputs['input_1'] tensor_info:
dtype: DT_FLOAT
shape: (1, -1, -1, 3)
name: serving_default_input_1:0
The given SavedModel SignatureDef contains the following output(s):
outputs['525'] tensor_info:
dtype: DT_FLOAT
shape: (1, -1, 4)
name: PartitionedCall:0
outputs['606'] tensor_info:
dtype: DT_FLOAT
shape: (1, -1, 10)
name: PartitionedCall:1
outputs['607'] tensor_info:
dtype: DT_FLOAT
shape: (1, -1, 2)
name: PartitionedCall:2
Method name is: tensorflow/serving/predict
tflite tests (assuming tflite and tfjs behave the same unless there is a bug in the tfjs converter)
import numpy as np
import tensorflow as tf
from pprint import pprint
interpreter = tf.lite.Interpreter(model_path="retinaface_onnx_dynamic_float32.tflite")
tf_lite_model = interpreter.get_signature_runner()
inputs = {
'input_1': np.ones([1,480,640,3], dtype=np.float32),
}
tf_lite_output = tf_lite_model(**inputs)
print(f"[TFLite] Model Predictions shape: {tf_lite_output['525'].shape}")
print(f"[TFLite] Model Predictions shape: {tf_lite_output['606'].shape}")
print(f"[TFLite] Model Predictions shape: {tf_lite_output['607'].shape}")
print(f"[TFLite] Model Predictions:")
pprint(tf_lite_output)
[TFLite] Model Predictions shape: (1, 12600, 4)
[TFLite] Model Predictions shape: (1, 12600, 10)
[TFLite] Model Predictions shape: (1, 12600, 2)
[TFLite] Model Predictions:
{'525': array([[[ 0.5037499 , -0.53452146, 0.28746593, -0.1060854 ],
[ 0.0795919 , 0.11518952, 0.2050548 , 0.18294893],
[ 0.14049551, 0.08239099, 0.09347238, 0.07261422],
...,
[ 0.1204751 , 0.4000509 , -0.9219683 , -1.0032947 ],
[ 0.6076698 , 1.0467663 , -1.3310252 , -0.77822804],
[ 0.93253464, 0.9066995 , -1.2800018 , -0.8087649 ]]],
dtype=float32),
'606': array([[[-0.779792 , -1.3634459 , -0.9822485 , ..., -1.5434346 ,
-1.5022398 , -1.5166504 ],
[-1.4359287 , -1.416552 , -1.53112 , ..., -1.5689636 ,
-1.5718539 , -1.5501782 ],
[-1.5408219 , -1.5390366 , -1.5441248 , ..., -1.5436357 ,
-1.5436357 , -1.5436357 ],
...,
[-0.22368906, -0.33528268, 0.43932757, ..., 0.62453854,
0.64479786, 0.5371737 ],
[ 0.25273618, -0.30643156, 1.4769835 , ..., 1.2033877 ,
0.99543476, 1.2796272 ],
[ 0.49980396, -0.35802573, 1.6034117 , ..., 1.1993165 ,
1.3116333 , 1.3412124 ]]], dtype=float32),
'607': array([[[4.9240157e-01, 5.0759840e-01],
[4.8854157e-01, 5.1145840e-01],
[5.3600717e-01, 4.6399277e-01],
...,
[9.9783522e-01, 2.1648051e-03],
[9.9984205e-01, 1.5789027e-04],
[9.9987757e-01, 1.2243164e-04]]], dtype=float32)}
import numpy as np
import tensorflow as tf
from pprint import pprint
interpreter = tf.lite.Interpreter(model_path="retinaface_onnx_dynamic_float32.tflite")
tf_lite_model = interpreter.get_signature_runner()
inputs = {
'input_1': np.ones([1,192,320,3], dtype=np.float32),
}
tf_lite_output = tf_lite_model(**inputs)
print(f"[TFLite] Model Predictions shape: {tf_lite_output['525'].shape}")
print(f"[TFLite] Model Predictions shape: {tf_lite_output['606'].shape}")
print(f"[TFLite] Model Predictions shape: {tf_lite_output['607'].shape}")
print(f"[TFLite] Model Predictions:")
pprint(tf_lite_output)
[TFLite] Model Predictions shape: (1, 2520, 4)
[TFLite] Model Predictions shape: (1, 2520, 10)
[TFLite] Model Predictions shape: (1, 2520, 2)
[TFLite] Model Predictions:
{'525': array([[[ 0.5037502 , -0.534521 , 0.287466 , -0.1060854 ],
[ 0.07959143, 0.11518921, 0.20505464, 0.18294904],
[ 0.14049558, 0.08239092, 0.09347239, 0.07261468],
...,
[ 0.06551935, 0.4618321 , -0.94103575, -1.0230165 ],
[ 0.49285072, 0.9942304 , -1.325173 , -0.7776669 ],
[ 0.8382109 , 0.94311976, -1.3364108 , -0.869154 ]]],
dtype=float32),
'606': array([[[-0.7797919 , -1.3634459 , -0.9822487 , ..., -1.5434344 ,
-1.5022393 , -1.5166501 ],
[-1.4359276 , -1.4165508 , -1.5311191 , ..., -1.5689638 ,
-1.5718228 , -1.5446154 ],
[-1.5317626 , -1.5467222 , -1.5688723 , ..., -1.330126 ,
-1.4217492 , -1.3949665 ],
...,
[-0.254573 , -0.24012746, 0.38229224, ..., 0.6902082 ,
0.59906906, 0.5935865 ],
[ 0.20296118, -0.2921868 , 1.4164804 , ..., 1.1968634 ,
0.98056495, 1.2392025 ],
[ 0.4711579 , -0.26603216, 1.5369333 , ..., 1.2316816 ,
1.2790531 , 1.3474693 ]]], dtype=float32),
'607': array([[[4.9240148e-01, 5.0759852e-01],
[4.8854169e-01, 5.1145828e-01],
[5.3600734e-01, 4.6399271e-01],
...,
[9.9757117e-01, 2.4287710e-03],
[9.9981421e-01, 1.8576751e-04],
[9.9986351e-01, 1.3642981e-04]]], dtype=float32)}
The redundant ONNX output from PyTorch is now handled by a dedicated optimization, eliminating the need to create the JSON.
https://github.com/PINTO0309/onnx2tf/releases/tag/1.18.3
Resize: it is now possible to correctly convert Resize with undefined dimensions, using just the -coion option.
onnx2tf \
-i retinaface_onnx_dynamic.onnx \
-osd \
-coion \
-cotof
| onnx | Before tflite | After tflite |
|---|---|---|
Good luck.
Thanks for all your help @PINTO0309. Everything here makes sense, but I wasn't able to get to this work today. I'll review in detail on Monday and let you know if I run into any additional issues with the updates.
Alright @PINTO0309, we're making great progress here. I can now run the dynamic input model without any shape issues on the tfjs side. However, I am having issues with the "correctness" of the output that I'm hoping you can help me with.
Circling back to the pytorch -> onnx conversion code above, I've changed the process to be as follows:
import numpy as np
import torch

# torch_model: the RetinaFace model loaded as before
img = np.float32(np.ones([375, 448, 3]))
img -= (104, 117, 123)        # mean subtraction, as in detect.py
img = img.transpose(2, 0, 1)  # HWC -> CHW
img = np.expand_dims(img, 0)  # add batch dim -> NCHW
example_input = torch.from_numpy(img)
torch_out = torch_model(example_input)
This matches the general functionality in detect.py from here. I have updated the onnx_forward function to return all the output tensors instead of the first. Then, as before, I check the model after conversion with:
import onnx
import numpy as np

onnx_model = onnx.load(output_file)
onnx.checker.check_model(onnx_model, full_check=True)
onnx_out = onnx_forward(output_file, example_input)  # returns all outputs
np.testing.assert_almost_equal(torch_out[1].data.numpy(), onnx_out[1], decimal=3)
This passes successfully. Printing the confidence outputs of these models (torch_out[1] and onnx_out[1]) at this point, we get a (1, 6944, 2) tensor from both, where the values in the final dim of the tensor are [~1, ~0] across all 6944 elements. This makes sense, as this model is trained and so it shouldn't be detecting anything in an input of ones.
Now, setting aside tfjs, if we run the tflite code you provided using the same input as above (with an NHWC input instead of NCHW, because of onnx2tf):
import numpy as np
import tensorflow as tf
from pprint import pprint
interpreter = tf.lite.Interpreter(model_path="retinaface_onnx_dynamic_float32.tflite")
tf_lite_model = interpreter.get_signature_runner()
img = np.float32(np.ones([375,448,3]))
img -= (104, 117, 123)
img = np.expand_dims(img, 0)
inputs = {
'input_1': img,
}
tf_lite_output = tf_lite_model(**inputs)
print(f"[TFLite] Model Predictions shape: {tf_lite_output['525'].shape}")
print(f"[TFLite] Model Predictions shape: {tf_lite_output['606'].shape}")
print(f"[TFLite] Model Predictions shape: {tf_lite_output['607'].shape}")
print(f"[TFLite] Model Predictions:")
pprint(tf_lite_output)
The tf_lite_output for the confidences is:
'607': array([[[4.6728298e-01, 5.3271705e-01],
[4.7566253e-01, 5.2433747e-01],
[5.1309991e-01, 4.8690012e-01],
...,
[9.9740821e-01, 2.5917361e-03],
[9.9948138e-01, 5.1866425e-04],
[9.9974173e-01, 2.5822222e-04]]], dtype=float32)
Note that multiple elements are [~.5, ~.5]. This is basically the same as the output you were showing in your comment above (and it aligns with what I'm getting on the tfjs side).
So, am I doing something wrong with the input shapes or transpositions here? Why isn't the -cotof option catching this? I assume something is going wrong with the expectation of transposition in the model, but it's not clear to me at the moment.
I was late checking the issue because I was training other models.
Models containing more than one None dimension have a non-zero chance of a transposition error. The checks performed by -cotof force a comparison between the NCHW tensor and the NHWC tensor.
Forced meaning it compares, for example, NCHW: [1,9,128,9] with NHWC: [1,128,9,9]. Since the tensor shapes of ONNX and PyTorch are completely different from those of TensorFlow, the values have to be compared on that assumption, so a brute-force check permutes every combination of axes to find the arrangement with the smallest error before calculating the error.
Thus, if you are unlucky enough to have a tensor shaped like [1,9,9,9,9] in the middle of the model, the check itself will succeed, but the axes of the model transformation itself may still be wrong.
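Roughly, the check behaves like the following sketch of mine (not onnx2tf's actual validation code). When several axes share the same size, more than one permutation can produce a matching shape, so a mis-transposed tensor can still find a near-zero-error arrangement:
import itertools
import numpy as np

def min_error_perm(onnx_t, tf_t):
    # Brute-force search for the tf_t axis permutation that best matches onnx_t
    best_perm, best_err = None, np.inf
    for perm in itertools.permutations(range(tf_t.ndim)):
        if np.transpose(tf_t, perm).shape != onnx_t.shape:
            continue
        err = np.abs(np.transpose(tf_t, perm) - onnx_t).max()
        if err < best_err:
            best_perm, best_err = perm, err
    return best_perm, best_err

onnx_t = np.random.rand(1, 9, 128, 9).astype(np.float32)
tf_t = np.transpose(onnx_t, (0, 2, 3, 1))  # NCHW -> NHWC
print(min_error_perm(onnx_t, tf_t))        # recovers a perm with error 0.0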
Structural checking of a model with multiple None dimensions is quite difficult even for the human eye, but for when such a situation arises, there is a function to check where the model transformation went wrong.
The -onimc option stops the conversion partway through the model and outputs the partially converted model. It is a bit tedious, but you need to try several conversions up to different midpoints of the model and run an inference test each time to find which part of the model is being mis-transposed.
For example,
# Split the model at the middle position for debugging
# Specify the output name of the OP
$ wget https://github.com/PINTO0309/onnx2tf/releases/download/0.0.2/resnet18-v1-7.onnx
$ onnx2tf -i resnet18-v1-7.onnx -onimc resnetv15_stage2_conv1_fwd resnetv15_stage2_conv2_fwd
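One rough way to run that inference test (a sketch of mine; the intermediate name comes from the -onimc example above, and the tflite file name inside saved_model is an assumption, so adjust it to what onnx2tf actually writes). Exposing an intermediate tensor as an extra graph output usually works with onnxruntime:
import numpy as np
import onnx
import onnxruntime as ort
import tensorflow as tf

INTERMEDIATE = "resnetv15_stage2_conv1_fwd"

# Expose the intermediate tensor as an additional ONNX graph output
model = onnx.load("resnet18-v1-7.onnx")
model.graph.output.extend([onnx.ValueInfoProto(name=INTERMEDIATE)])
sess = ort.InferenceSession(model.SerializeToString())

x_nchw = np.ones([1, 3, 224, 224], dtype=np.float32)
onnx_mid = sess.run([INTERMEDIATE], {sess.get_inputs()[0].name: x_nchw})[0]

# Run the halfway-converted model
interp = tf.lite.Interpreter(model_path="saved_model/resnet18-v1-7_float32.tflite")
runner = interp.get_signature_runner()
in_name = list(runner.get_input_details().keys())[0]
tf_mid = list(runner(**{in_name: np.transpose(x_nchw, (0, 2, 3, 1))}).values())[0]

# Undo the NHWC transpose before comparing (assumes a 4D intermediate);
# a large error means the mis-transposition is at or before this point.
print(np.abs(onnx_mid - np.transpose(tf_mid, (0, 3, 1, 2))).max())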
Once you know where you have made a transposition error on an axis, you can use JSON to correct the transposition error.
For example,
https://github.com/PINTO0309/onnx2tf#parameter-replacement
https://github.com/PINTO0309/onnx2tf/blob/main/replace.json
# Parameter replacement (Resize,Transpose,Softmax)
$ rm replace.json
$ wget https://github.com/PINTO0309/onnx2tf/releases/download/1.1.27/human_segmentation_pphumanseg_2021oct.onnx
$ wget https://github.com/PINTO0309/onnx2tf/releases/download/1.1.27/replace.json
$ onnx2tf -i human_segmentation_pphumanseg_2021oct.onnx -prf replace.json
The reason the -cotof check is likely to report Matches even though an axis is in the wrong position is OPs that do not rewrite the values themselves, such as Gather. (This is a possibility, not a definitive list of problem areas for RetinaFace.)
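A small numpy illustration of mine (not taken from RetinaFace) of how a value-preserving OP slips past the check: at the fixed validation resolution H == W, so a Gather on the shape tensor with the wrong index still returns the same value:
import numpy as np

# Validation resolution with H == W == 28
nchw = np.array([1, 64, 28, 28])   # ONNX side
nhwc = np.array([1, 28, 28, 64])   # TF side after transposition

# ONNX gathers H at NCHW index 2. Keeping index 2 on the NHWC shape
# actually gathers W -- but H == W here, so -cotof sees Matches:
print(nchw[2], nhwc[2])   # 28 28

# At an asymmetric resolution the wrong index is exposed:
nchw = np.array([1, 64, 28, 52])
nhwc = np.array([1, 28, 52, 64])
print(nchw[2], nhwc[2])   # 28 52 -- the mismatch the check never saw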
It would be hard to blindly examine the wrong areas, so if I were you, I would generate a fixed-resolution RetinaFace model and compare its structure with the model with None. This makes it easier to see, to some extent, where a Transpose is missing, or conversely, where a superfluous Transpose has been inserted.
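A sketch of how the two exports could be produced for comparison (assuming the RetinaFace module from the repository linked in the issue; the output names follow the signatures shown above, and the opset version is my assumption):
import torch

# torch_model: the RetinaFace module loaded as in the issue description
dummy = torch.randn(1, 3, 480, 640)

# Fixed-resolution reference export
torch.onnx.export(torch_model, dummy, "retinaface_static.onnx",
                  input_names=["input_1"], output_names=["525", "606", "607"],
                  opset_version=12)

# Dynamic-resolution export (H and W undefined)
torch.onnx.export(torch_model, dummy, "retinaface_dynamic.onnx",
                  input_names=["input_1"], output_names=["525", "606", "607"],
                  dynamic_axes={"input_1": {2: "height", 3: "width"}},
                  opset_version=12)
# Convert both with onnx2tf and diff the graphs (e.g. in Netron) to spot
# missing or superfluous Transpose OPs.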
If I get enough time during the holidays I will check out the model too.
Issue Type
Others
OS
Linux
onnx2tf version number
1.18.1
onnx version number
1.14.1
onnxruntime version number
1.16.0
onnxsim (onnx_simplifier) version number
0.4.31
tensorflow version number
2.14.0
Download URL for ONNX
https://drive.google.com/file/d/1XIRHjWYzWHwsZXOgcT4RLOJ6kxXE1BVT/view?usp=share_link
Parameter Replacement JSON
Description
Hi @PINTO0309, thanks for the great tool and all of your hard work.
I'm trying to convert a custom model from onnx to tensorflow to tfjs with a dynamic input shape and am having problems.
As an example, take the mobilenet0.25_Final.pth model from https://github.com/biubug6/Pytorch_Retinaface.
I'm converting from pytorch to onnx using the following (where the RetinaFace definition comes from here):
I check that the onnx model works on the python side with the following:
I then convert the onnx model to tf using:
onnx2tf -i retinaface.onnx -osd -o retinaface_tf -cotof
When running the above, the following is the first shape issue I have:
INFO: onnx_output_name: wa/fpn/Shape_3_output_0 tf_output_name: tf.compat.v1.shape/wa/fpn/Shape_3:0 shape: (4,) dtype: int64 validate_result: Skipped (Deleted or Shape Unmatched)
Beyond a number of errors / warnings like these, the model converts successfully, but when using it in my tfjs-based system (after converting with tensorflowjs_converter) I get the following shape mismatch at inference time:
If I remove the dynamic_axes, things work fine at the fixed input size. There are a number of layers with the [1,28,28,64] shape, so it's been challenging to track down which is the problematic layer.
FWIW, I've also tried with the -kat and -nuo options.
This error doesn't happen during this workflow, but does the
Alternatively, if the input OP has a dynamic dimension, use the -b or -ois option to rewrite it to a static shape and try again.
error message, which appears in other places, mean that dynamic shapes are not supported? Based on your recent commits I assume that they are indeed supported in some way.
Is there simply a need for a parameter replacement in my case, or am I hitting an edge case in dynamic inputs somehow? Any guidance would be appreciated. Please let me know if more information / models would be helpful - I will provide what I can.