PINTO0309 / openvino2tensorflow

This script converts the ONNX/OpenVINO IR model to Tensorflow's saved_model, tflite, h5, tfjs, tftrt(TensorRT), CoreML, EdgeTPU, ONNX and pb. PyTorch (NCHW) -> ONNX (NCHW) -> OpenVINO (NCHW) -> openvino2tensorflow -> Tensorflow/Keras (NHWC/NCHW) -> TFLite (NHWC/NCHW). And the conversion from .pb to saved_model and from saved_model to .pb and from .pb to .tflite and saved_model to .tflite and saved_model to onnx. Support for building environments with Docker. It is possible to directly access the host PC GUI and the camera to verify the operation. NVIDIA GPU (dGPU) support. Intel iHD GPU (iGPU) support.
MIT License

Different output of inference results between (ONNX and TFLITE 32bit). #86

Closed beekeeper23 closed 2 years ago

beekeeper23 commented 2 years ago

Issue Type

Support

OS

Windows

OS architecture

x86_64

Programming Language

Python

Framework

OpenVINO, PyTorch, ONNX, TensorFlow, TensorFlowLite

Download URL for ONNX / OpenVINO IR

https://drive.google.com/file/d/1gxdHSdOOo4oJWmrFkkBpXUyhVAeqkW2b/view?usp=sharing

Description

I tried to convert the YOLOv5 model with ShuffleNet from https://github.com/ppogg/YOLOv5-Lite to tflite using openvino2tensorflow, but the resulting TensorFlow model produced different inference results from the ONNX model.

Also, note that I used only 4D tensors in the post-processing part of the model.

The JSON file and the output of the inference test are provided below.

Relevant Log Output

{
    "format_version": 2,
    "layers": [
        {
            "layer_id": "253",
            "type": "Const",
            "replace_mode": "direct",
            "values": [
                0,
                1,
                2,
                4,
                3
            ]
        },
        {
            "layer_id": "255",
            "type": "Reshape", 
            "replace_mode": "insert_after",
            "values": [
                16,
                16,
                6,
                3
            ]
        },
        {
            "layer_id": "256",
            "type": "Const",
            "replace_mode": "direct",
            "values": [
                3
            ]
        },
        {
            "layer_id": "263",
            "type": "Reshape",
            "replace_mode": "insert_after",
            "values": [
                16,
                16,
                2,
                1
            ]
        },
        {
            "layer_id": "269",
            "type": "Const",
            "replace_mode": "direct",
            "values": [
                0,2,0,0

            ]
        },
        {
            "layer_id": "270",
            "type": "Const",
            "replace_mode": "direct",
            "values": [
                0,4,0,0

            ]
        },
        {
            "layer_id": "271",
            "type": "Const",
            "replace_mode": "direct",
            "values": [
                1,1,1,1

            ]
        },
        {
            "layer_id": "276",
            "type": "Transpose",
            "replace_mode": "insert_after",
            "values": [
                1,
                2,
                3,
                0
            ]
        },
        {
            "layer_id": "278",
            "type": "Concat",
            "replace_mode": "change_axis",
            "values": 2
        },
        {
            "layer_id": "280",
            "type": "Transpose",
            "replace_mode": "insert_before",
            "values": [
              0,
              3,
              1,
              2
            ]
        },

        {
            "layer_id": "319",
            "type": "Const",
            "replace_mode": "direct",
            "values": [
                0,
                1,
                2,
                4,
                3
            ]
        },
        {
            "layer_id": "321",
            "type": "Reshape", 
            "replace_mode": "insert_after",
            "values": [
                8,
                8,
                6,
                3
            ]
        },
        {
            "layer_id": "322",
            "type": "Const",
            "replace_mode": "direct",
            "values": [
                3
            ]
        },
        {
            "layer_id": "329",
            "type": "Reshape",
            "replace_mode": "insert_after",
            "values": [
                8,
                8,
                2,
                1
            ]
        },
        {
            "layer_id": "330",
            "type": "Reshape", 
            "replace_mode": "insert_after",
            "values": [
                8,
                8,
                2,
                3
            ]
        },
        {
            "layer_id": "335",
            "type": "Const",
            "replace_mode": "direct",
            "values": [
                0,
                2,
                0,
                0
            ]
        },
        {
            "layer_id": "336",
            "type": "Const",
            "replace_mode": "direct",
            "values": [
                0,
                4,
                0,
                0
            ]
        },
        {
            "layer_id": "337",
            "type": "Const",
            "replace_mode": "direct",
            "values": [
                1,
                1,
                1,
                1
            ]
        },
        {
            "layer_id": "342",
            "type": "Transpose",
            "replace_mode": "insert_after",
            "values": [
                1,
                2,
                3,
                0
            ]
        },

        {
            "layer_id": "344",
            "type": "Concat",
            "replace_mode": "change_axis",
            "values": 2
        },
        {
            "layer_id": "346",
            "type": "Transpose",
            "replace_mode": "insert_before",
            "values": [
              0,
              3,
              1,
              2
            ]
        },

        {
            "layer_id": "385",
            "type": "Const",
            "replace_mode": "direct",
            "values": [
                0,
                1,
                2,
                4,
                3
            ]
        },
        {
            "layer_id": "387",
            "type": "Reshape", 
            "replace_mode": "insert_after",
            "values": [
                4,
                4,
                6,
                3
            ]
        },
        {
            "layer_id": "388",
            "type": "Const",
            "replace_mode": "direct",
            "values": [
                3
            ]
        },
        {
            "layer_id": "395",
            "type": "Reshape",
            "replace_mode": "insert_after",
            "values": [
                4,
                4,
                2,
                1
            ]
        },        
        {
            "layer_id": "401",
            "type": "Const",
            "replace_mode": "direct",
            "values": [
                0,
                2,
                0,
                0
            ]
        },
        {
            "layer_id": "402",
            "type": "Const",
            "replace_mode": "direct",
            "values": [
                0,
                4,
                0,
                0
            ]
        },
        {
            "layer_id": "403",
            "type": "Const",
            "replace_mode": "direct",
            "values": [
                1,
                1,
                1,
                1
            ]
        },
        {
            "layer_id": "408",
            "type": "Transpose",
            "replace_mode": "insert_after",
            "values": [
                1,
                2,
                3,
                0
            ]
        },        
        {
            "layer_id": "410",
            "type": "Concat",
            "replace_mode": "change_axis",
            "values": 2
        },
        {
            "layer_id": "412",
            "type": "Transpose",
            "replace_mode": "insert_before",
            "values": [
              0,
              3,
              1,
              2
            ]
        }

    ]
}

Source code for simple inference testing code

ONNX output @@@@@@@@@@@@@@@@@@@@@@@
elapsed time: 0.7963180541992188ms
shape: (1, 1008, 6)
[array([[[5.26182938e+00, 4.99601841e+00, 1.22561550e+01, 1.36645145e+01,
         9.45538282e-04, 9.85830545e-01],
        [1.34348326e+01, 5.08123970e+00, 1.37564259e+01, 1.36815252e+01,
         4.67479229e-04, 9.85731304e-01],
        [2.13319321e+01, 5.01298046e+00, 1.44537868e+01, 1.41183548e+01,
         4.15980816e-04, 9.85241652e-01],
        ...,
        [5.63437881e+01, 1.15099594e+02, 1.21474190e+02, 1.02216240e+02,
         6.14345074e-04, 9.69812930e-01],
        [8.74896545e+01, 1.16726608e+02, 1.10064934e+02, 8.91103439e+01,
         5.11735678e-04, 9.71362650e-01],
        [1.19493690e+02, 1.12301132e+02, 1.62882797e+02, 1.17599602e+02,
         5.25712967e-04, 9.74491775e-01]]], dtype=float32)]

tflite output @@@@@@@@@@@@@@@@@@@@@@@
elapsed time: 12.678861618041992ms
shape: (1, 1008, 6)
array([[[  5.280884,   5.280884,  10.561798,  10.561798,   0.      ,
           2.640442],
        [ 13.20224 ,   5.280884,  10.561798,  10.561798,   0.      ,
           2.640442],
        [ 21.123566,   5.280884,  10.561798,  10.561798,   0.      ,
           2.640442],
        ...,
        [ 47.528046, 113.5392  , 264.04465 , 232.35928 ,   0.      ,
           0.      ],
        [ 79.21338 , 113.5392  , 216.5166  , 208.59526 ,   0.      ,
           0.      ],
        [110.89874 , 113.5392  , 216.5166  , 216.5166  ,   0.      ,
           0.      ]]], dtype=float32)
PINTO0309 commented 2 years ago

Thank you for posting this thoughtful issue. You have made my investigation go very smoothly.

Some of the conversions had bugs, which have been fixed. https://github.com/PINTO0309/openvino2tensorflow/releases/tag/v1.26.1

There seems to be a slightly larger error, but even after looking at the entire structure of the model, I could not identify which part of the model had the problem.
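To put a number on that error, the first detection row of the ONNX and tflite outputs printed in this comment can be compared directly. This is only a minimal sketch: `error_report` is a hypothetical helper, not part of openvino2tensorflow, and it compares a single row rather than the full (1, 1008, 6) tensor.

```python
import numpy as np

# First detection row, copied from the ONNX and tflite outputs in this comment.
onnx_row = np.array([5.26182938e+00, 4.99601841e+00, 1.22561550e+01,
                     1.36645145e+01, 9.45538282e-04, 9.85830545e-01])
tflite_row = np.array([5.28040695e+00, 5.00483131e+00, 1.22718182e+01,
                       1.36867762e+01, 9.46503831e-04, 9.85802889e-01])


def error_report(reference, candidate):
    """Hypothetical helper: return (max absolute error, flat index of the worst element)."""
    diff = np.abs(np.asarray(reference) - np.asarray(candidate))
    worst = int(np.argmax(diff))
    return float(diff.flat[worst]), worst


max_err, worst_index = error_report(onnx_row, tflite_row)
print(f'max abs error: {max_err:.6f} at index {worst_index}')
```

Running the same helper on the outputs of intermediate layers (obtained with the model cutting function) would be one way to narrow down where the error starts to grow.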

If possible, try using the model cutting function in the link below to see which layers have larger errors. https://github.com/PINTO0309/openvino2tensorflow#6-10-ability-to-specify-an-output-layer-for-debugging-the-output-values-of-the-model

One area to pay special attention to during the transformation is whether the shape of the input tensor in Reshape is consistent with the shape of the output tensor after Reshape.
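A Reshape can only be valid if the number of elements is unchanged, so a quick sanity check on the values in `replace.json` is possible before converting. The sketch below assumes a hypothetical input shape of (1, 18, 16, 16) for the Reshape at layer 255; the real input shape is not stated in this thread and should be read from the model. `reshape_is_consistent` is an illustrative helper, not an openvino2tensorflow function.

```python
import math


def reshape_is_consistent(input_shape, target_shape):
    """Hypothetical check: a Reshape must preserve the total element count."""
    return math.prod(input_shape) == math.prod(target_shape)


# Assumption: illustrative NCHW input shape, not taken from the actual model.
assumed_input_shape = (1, 18, 16, 16)

# Target shape from the "Reshape" entry for layer_id 255 in the posted JSON.
target_shape = (16, 16, 6, 3)

print(reshape_is_consistent(assumed_input_shape, target_shape))  # element counts match
print(reshape_is_consistent(assumed_input_shape, (8, 8, 6, 3)))  # element counts differ
```

Note that matching element counts is necessary but not sufficient: the axis order feeding the Reshape (NCHW vs. NHWC) also has to match, which is what the Transpose entries in the JSON are meant to control.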

$INTEL_OPENVINO_DIR/deployment_tools/model_optimizer/mo.py \
  --input_model shuffleNet_4d.onnx \
  --data_type FP32 \
  --output_dir saved_model/openvino/FP32

openvino2tensorflow \
  --model_path saved_model/openvino/FP32/shuffleNet_4d.xml \
  --output_saved_model \
  --output_pb \
  --output_no_quant_float32_tflite \
  --weight_replacement_config replace.json

```python
import onnxruntime
import tensorflow as tf
import time
import numpy as np
from pprint import pprint

H=128
W=128

############################################################

onnx_session = onnxruntime.InferenceSession(f'shuffleNet_4d.onnx')
input_name = onnx_session.get_inputs()[0].name
output_name = onnx_session.get_outputs()[0].name

roop = 1
e = 0.0
result = None
inp = np.ones((1,1,H,W), dtype=np.float32)
for _ in range(roop):
    s = time.time()
    result = onnx_session.run(
        [output_name],
        {input_name: inp}
    )
    e += (time.time() - s)
print('ONNX output @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@')
print(f'elapsed time: {e/roop*1000}ms')
print(f'shape: {result[0].shape}')
pprint(result)

############################################################

interpreter = tf.lite.Interpreter(model_path=f'saved_model/model_float32.tflite', num_threads=4)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

roop = 1
e = 0.0
result = None
inp = np.ones((1,H,W,1), dtype=np.float32)
for _ in range(roop):
    s = time.time()
    interpreter.set_tensor(input_details[0]['index'], inp)
    interpreter.invoke()
    result = interpreter.get_tensor(output_details[1]['index'])
    e += (time.time() - s)
print('tflite output @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@')
print(f'elapsed time: {e/roop*1000}ms')
print(f'shape: {result.shape}')
pprint(result)
```

ONNX output @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
elapsed time: 0.5400180816650391ms
shape: (1, 1008, 6)
[array([[[5.26182938e+00, 4.99601841e+00, 1.22561550e+01, 1.36645145e+01,
         9.45538282e-04, 9.85830545e-01],
        [1.34348326e+01, 5.08123970e+00, 1.37564259e+01, 1.36815252e+01,
         4.67479229e-04, 9.85731304e-01],
        [2.13319321e+01, 5.01298046e+00, 1.44537868e+01, 1.41183548e+01,
         4.15980816e-04, 9.85241652e-01],
        ...,
        [5.63437881e+01, 1.15099594e+02, 1.21474190e+02, 1.02216240e+02,
         6.14345074e-04, 9.69812930e-01],
        [8.74896545e+01, 1.16726608e+02, 1.10064934e+02, 8.91103439e+01,
         5.11735678e-04, 9.71362650e-01],
        [1.19493690e+02, 1.12301132e+02, 1.62882797e+02, 1.17599602e+02,
         5.25712967e-04, 9.74491775e-01]]], dtype=float32)]
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
tflite output @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
elapsed time: 0.7603168487548828ms
shape: (1, 1008, 6)
array([[[5.28040695e+00, 5.00483131e+00, 1.22718182e+01, 1.36867762e+01,
         9.46503831e-04, 9.85802889e-01],
        [1.34492149e+01, 5.09168625e+00, 1.37784395e+01, 1.37132292e+01,
         4.68593993e-04, 9.85692441e-01],
        [2.13420620e+01, 5.02886963e+00, 1.44927731e+01, 1.41707439e+01,
         4.17113508e-04, 9.85178471e-01],
        ...,
        [5.66155014e+01, 1.14571739e+02, 1.16065849e+02, 9.84796066e+01,
         6.06960442e-04, 9.69192386e-01],
        [8.76312714e+01, 1.16028381e+02, 1.04266632e+02, 8.57446976e+01,
         5.12930972e-04, 9.70427155e-01],
        [1.19529739e+02, 1.11679428e+02, 1.58085449e+02, 1.16086975e+02,
         5.34444698e-04, 9.73940670e-01]]], dtype=float32)


PINTO0309 commented 2 years ago

I found that the MaxPool 5x5 part breaks the output. I will continue to investigate. I have already confirmed that the output matches perfectly until the last Swish.

tflite output @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
elapsed time: 0.4906654357910156ms
shape: (1, 4, 4, 96)
array([[[[-0.23088458, -0.24766795, -0.23495872, ..., 0. , 1. , 0. ],
         [-0.2584279 , -0.2293107 , -0.20997883, ..., 0. , 2. , 0. ],
         [-0.24309352, -0.24637869, -0.23182848, ..., 0. , 3. , 0. ],
         [-0.02394821, -0.27733138, -0.2782271 , ..., 0. , 4. , 0. ]],

- OpenVINO MaxPool
- Padding issues
- OpenVINO is padded with something that is not zero.
  https://docs.openvino.ai/latest/openvino_docs_ops_pooling_MaxPool_1.html

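Why the pad value matters here can be shown with a small NumPy sketch. It assumes the padded positions should effectively be ignored by the max (modelled below as `-inf`); the thread only says the pad value is "not zero", so that choice is an assumption. The input values are illustrative negatives of the kind a Swish activation produces, and `max_pool_1d` is a hypothetical helper.

```python
import numpy as np


def max_pool_1d(x, kernel, pad_value):
    """Hypothetical 1-D max pool, stride 1, 'same' output size, explicit pad value."""
    pad = kernel // 2
    padded = np.pad(x, pad, mode='constant', constant_values=pad_value)
    return np.array([padded[i:i + kernel].max() for i in range(len(x))])


# Illustrative negative activations (Swish can output small negative values).
x = np.array([-0.23, -0.25, -0.21, -0.28], dtype=np.float32)

zero_padded = max_pool_1d(x, 5, 0.0)        # padding wins over every real value
ignore_padded = max_pool_1d(x, 5, -np.inf)  # padding never wins

print(zero_padded)
print(ignore_padded)
```

With zero padding every window near the border returns the pad value instead of a real activation, which is consistent with the unexpected `0.` entries in the output above.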
PINTO0309 commented 2 years ago

MaxPool 5x5 bug fixes. https://github.com/PINTO0309/openvino2tensorflow/releases/tag/v1.26.2

$ python3 onnx_tflite_test.py 
ONNX output @@@@@@@@@@@@@@@@@@@@@@@@@
elapsed time: 0.5414485931396484ms
shape: (1, 1008, 6)
[array([[[5.26182938e+00, 4.99601841e+00, 1.22561550e+01, 1.36645145e+01,
         9.45538282e-04, 9.85830545e-01],
        [1.34348326e+01, 5.08123970e+00, 1.37564259e+01, 1.36815252e+01,
         4.67479229e-04, 9.85731304e-01],
        [2.13319321e+01, 5.01298046e+00, 1.44537868e+01, 1.41183548e+01,
         4.15980816e-04, 9.85241652e-01],
        ...,
        [5.63437881e+01, 1.15099594e+02, 1.21474190e+02, 1.02216240e+02,
         6.14345074e-04, 9.69812930e-01],
        [8.74896545e+01, 1.16726608e+02, 1.10064934e+02, 8.91103439e+01,
         5.11735678e-04, 9.71362650e-01],
        [1.19493690e+02, 1.12301132e+02, 1.62882797e+02, 1.17599602e+02,
         5.25712967e-04, 9.74491775e-01]]], dtype=float32)]
OpenVINO output @@@@@@@@@@@@@@@@@@@@@@@@@
elapsed time: 0.8075237274169922ms
shape: (1, 1008, 6)
array([[[5.26182938e+00, 4.99601746e+00, 1.22561550e+01, 1.36645145e+01,
         9.45528154e-04, 9.85830545e-01],
        [1.34348335e+01, 5.08123970e+00, 1.37564287e+01, 1.36815252e+01,
         4.67499776e-04, 9.85731304e-01],
        [2.13319302e+01, 5.01298237e+00, 1.44537868e+01, 1.41183548e+01,
         4.15943301e-04, 9.85241652e-01],
        ...,
        [5.63437881e+01, 1.15099609e+02, 1.21474289e+02, 1.02216316e+02,
         6.14331686e-04, 9.69812930e-01],
        [8.74896545e+01, 1.16726608e+02, 1.10065056e+02, 8.91103973e+01,
         5.11766935e-04, 9.71362591e-01],
        [1.19493698e+02, 1.12301140e+02, 1.62882706e+02, 1.17599617e+02,
         5.25735551e-04, 9.74491775e-01]]], dtype=float32)
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
tflite output @@@@@@@@@@@@@@@@@@@@@@@@@
elapsed time: 0.8258819580078125ms
shape: (1, 1008, 6)
array([[[5.26183033e+00, 4.99601746e+00, 1.22561550e+01, 1.36645145e+01,
         9.45528154e-04, 9.85830545e-01],
        [1.34348335e+01, 5.08123970e+00, 1.37564287e+01, 1.36815252e+01,
         4.67499573e-04, 9.85731304e-01],
        [2.13319302e+01, 5.01298141e+00, 1.44537868e+01, 1.41183577e+01,
         4.15943301e-04, 9.85241652e-01],
        ...,
        [5.63437881e+01, 1.15099609e+02, 1.21474289e+02, 1.02216316e+02,
         6.14331977e-04, 9.69812930e-01],
        [8.74896545e+01, 1.16726601e+02, 1.10065056e+02, 8.91103973e+01,
         5.11766644e-04, 9.71362591e-01],
        [1.19493690e+02, 1.12301140e+02, 1.62882812e+02, 1.17599670e+02,
         5.25734213e-04, 9.74491835e-01]]], dtype=float32)


beekeeper23 commented 2 years ago

Thank you so much! Is there also any way to avoid the MirrorPad and replace it with Pad layers?

PINTO0309 commented 2 years ago

If the padding size is greater than or equal to 2, the padded values must be interpolated with REFLECT or SYMMETRIC mode. Please check the following specifications.

  1. ONNX: https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Pad-11
  2. OpenVINO: https://docs.openvino.ai/latest/openvino_docs_ops_movement_Pad_1.html
  3. TensorFlow: https://www.tensorflow.org/api_docs/python/tf/pad

Therefore, if you want to avoid generating MirrorPad layers, your only options are to pad with a CONSTANT value, or to use zero padding with a padding size less than or equal to 1 to reduce the error. In other words, you should change the structure of your successive MaxPool 5x5 layers.
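The difference between the three padding modes can be seen on a tiny array. The sketch below uses NumPy's `np.pad` modes `reflect`, `symmetric` and `constant` as stand-ins for the `REFLECT`, `SYMMETRIC` and `CONSTANT` modes of `tf.pad`; it is an illustration of the modes only, not of how openvino2tensorflow emits them.

```python
import numpy as np

x = np.array([1, 2, 3])

# REFLECT: mirror around the edge value, without repeating it.
reflect = np.pad(x, 2, mode='reflect')

# SYMMETRIC: mirror around the edge, repeating the edge value.
symmetric = np.pad(x, 2, mode='symmetric')

# CONSTANT: fill with a fixed value (zero here); this maps to a plain Pad layer.
constant = np.pad(x, 2, mode='constant', constant_values=0)

print(reflect)    # [3 2 1 2 3 2 1]
print(symmetric)  # [2 1 1 2 3 3 2]
print(constant)   # [0 0 1 2 3 0 0]
```

With a padding size of 1, `constant` differs from the mirrored modes only in a single border element per side, which is why small paddings keep the error low.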

PINTO0309 commented 2 years ago

The first issue has been resolved, and I'm closing this for now since there hasn't been a reply to my last post for a while.