PINTO0309 / PINTO_model_zoo

A repository for storing models that have been inter-converted between various frameworks. Supported frameworks are TensorFlow, PyTorch, ONNX, OpenVINO, TFJS, TFTRT, TensorFlowLite (Float32/16/INT8), EdgeTPU, CoreML.
https://qiita.com/PINTO
MIT License

Request: FaceMesh-with-Attention model conversion (unsupported custom ops) #143

Closed: vladmandic closed this issue 2 years ago

vladmandic commented 2 years ago

MediaPipe has released a new FaceMesh-with-Attention model.

It is basically the old FaceMesh model augmented with three additional attention models that refine the results, all inside a single TFLite model.

I've tried converting it:

tflite2tensorflow --model_path face_landmark_with_attention.tflite --flatc_path ./flatc --schema_path schema.fbs --output_pb

but it fails with:

RuntimeError: Encountered unresolved custom op: Landmarks2TransformMatrix.
Node number 192 (Landmarks2TransformMatrix) failed to prepare.

It seems that the TFLite model uses custom ops to link the different execution paths inside it; that is beyond me...

PINTO0309 commented 2 years ago

I am looking forward to receiving proposals that express the following three custom operations with standard TensorFlow operations:

  1. transform_landmarks https://github.com/google/mediapipe/blob/master/mediapipe/util/tflite/operations/transform_landmarks.cc

  2. transform_tensor_bilinear https://github.com/google/mediapipe/blob/master/mediapipe/util/tflite/operations/transform_tensor_bilinear.cc

  3. landmarks_to_transform_matrix https://github.com/google/mediapipe/blob/master/mediapipe/util/tflite/operations/landmarks_to_transform_matrix.cc

vladmandic commented 2 years ago

@PINTO0309 thanks for looking at this!

imo, matching the core functionality is not the toughest part; it's matching the input and output signatures, as they are very C-oriented,
and testing this is a nightmare

i don't know why mediapipe is heading in the direction of more and more proprietary C code instead of clean models and contributing to TF

all-in-all, i'm not sure if it's worth it...

PINTO0309 commented 2 years ago

I agree with you. I have already concluded that generalizing this model is not worth the effort.

vladmandic commented 2 years ago

i thought so too, thanks for confirming. i'm closing this request.

mayerjTNG commented 2 years ago

Hi, first of all, thanks @PINTO0309 for all the work you've done for the community; you've been a real lifesaver for me more than once! :)

I'm currently attempting to build something similar to MediaPipe Holistic in pure TensorFlow so that it will run properly on desktop GPUs. Along the way, I've reverse engineered a lot of MediaPipe functionality and reimplemented a bunch of things that come awfully close to the missing operations in question. I've also found that the standard face mesh without attention just doesn't make the cut for me in terms of quality.

I found this issue a couple of weeks ago while trying to convert the tflite model. While this is obviously not great news for me, I was still determined to get it to work: first compile tflite with the custom ops, then write custom layers for the ops and use trial and error until it finally converts properly. I followed your tutorial to add the three layers to tflite. I've been at it for a bit over a week now, getting the ops to compile and registering them with the runtime, but no matter how I tried, the final .whl still seems to be missing the custom ops. The issue is that I don't really know much about what I'm doing, and I'm starting to think about throwing in the towel on this one, having failed at the supposedly "easy" part of integrating existing layers from MediaPipe back into TensorFlow.

So my question is: is there any chance you could compile a version of tensorflow/tflite-runtime that supports the custom ops? I reckon you know quite a lot about all of this by now. :) I'd be more than happy to take it from there and file a PR with the custom layer implementations in TensorFlow. To show you that I'm not bullshitting, here's a draft of the layers:

transform_landmarks

    import tensorflow as tf

    def transform_landmarks_2d(landmarks, transformation):
        # landmarks: [B, N, C] with x, y first; transformation: [B, 3, 3] homogeneous 2D matrix
        landmarks_xy, landmarks_residual = landmarks[..., :2], landmarks[..., 2:]
        # Append w = 1 to every landmark: [B, N, 3]
        landmarks_xyw = tf.pad(landmarks_xy, [[0, 0], [0, 0], [0, 1]], constant_values=1)

        number_of_landmarks = tf.shape(landmarks_xyw)[-2]
        # One copy of the matrix per landmark: [B, N, 3, 3]
        broadcasted_matrix = tf.repeat(tf.expand_dims(transformation, axis=1), number_of_landmarks, axis=1)

        # [B, N, 3, 3] @ [B, N, 3, 1] -> [B, N, 3, 1] -> [B, N, 3]
        transformed_landmarks_xyw = tf.reshape(tf.matmul(broadcasted_matrix, tf.expand_dims(landmarks_xyw, axis=-1)),
                                               (-1, number_of_landmarks, 3))

        transformed_landmarks_xy = transformed_landmarks_xyw[..., :2]
        return tf.concat([transformed_landmarks_xy, landmarks_residual], axis=-1)
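A side note on the tf.repeat step in the draft above: batched matrix multiplication broadcasts over leading dimensions, so copying the matrix once per landmark should not be necessary. The NumPy sketch below (helper names are hypothetical; like the draft, it assumes a [B, 3, 3] homogeneous 2D matrix) checks that the repeat-based form and a broadcast form agree:

```python
import numpy as np


def append_w(landmarks_xy):
    """(B, N, 2) -> (B, N, 3) homogeneous coordinates with w = 1."""
    return np.concatenate([landmarks_xy, np.ones_like(landmarks_xy[..., :1])], axis=-1)


def transform_repeat(landmarks_xy, matrix_3x3):
    """Mirrors the draft: one copy of the matrix per landmark, then matmul."""
    n = landmarks_xy.shape[1]
    repeated = np.repeat(matrix_3x3[:, None], n, axis=1)  # (B, N, 3, 3)
    out = repeated @ append_w(landmarks_xy)[..., None]  # (B, N, 3, 1)
    return out[..., :2, 0]  # (B, N, 2)


def transform_broadcast(landmarks_xy, matrix_3x3):
    """Same result without copies: (B, N, 3) @ (B, 3, 3)^T -> (B, N, 3)."""
    out = append_w(landmarks_xy) @ np.swapaxes(matrix_3x3, -1, -2)
    return out[..., :2]


rng = np.random.default_rng(0)
landmarks = rng.uniform(0, 192, size=(2, 80, 2))
matrices = rng.normal(size=(2, 3, 3))
print(np.abs(transform_repeat(landmarks, matrices) - transform_broadcast(landmarks, matrices)).max())
```

In TensorFlow the broadcast form would presumably be tf.matmul(landmarks_xyw, transformation, transpose_b=True), which avoids the tf.repeat and tf.reshape ops in the converted graph.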

transform_tensor_bilinear (using tensorflow-addons)

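    import tensorflow as tf
    import tensorflow_addons as tfa

    def crop_image(self, image, resolution):
        # Method of a larger class; self.get_crop_matrix(resolution) is defined elsewhere.
        # tfa.image.transform maps output pixels to input pixels, so the crop matrix is
        # inverted (assuming get_crop_matrix maps input coordinates to crop coordinates).
        crop_transformation = tfa.image.transform_ops.matrices_to_flat_transforms(
            tf.linalg.inv(self.get_crop_matrix(resolution)))
        batch_dimension = tf.shape(crop_transformation)[:-1]

        # One copy of the image per crop transformation
        broadcasted_image = tf.broadcast_to(image,
                                            tf.concat([batch_dimension, tf.shape(image)[-3:]], axis=-1))

        cropped_image = tfa.image.transform(
            broadcasted_image,
            crop_transformation,
            interpolation="BILINEAR",
            output_shape=resolution,
            fill_mode="nearest",
        )
        return cropped_image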
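For checking tfa.image.transform (or the custom TransformTensorBilinear op) outside TensorFlow, a plain NumPy reference of inverse-mapped bilinear sampling may help. This is a sketch under assumptions, not MediaPipe's exact op: the matrix maps output pixel coordinates to input pixel coordinates (the direction tfa.image.transform expects), pixel centers sit at integer coordinates, and out-of-range samples are clamped to the border, roughly like fill_mode="nearest". MediaPipe's op may use different pixel-center and border conventions.

```python
import numpy as np


def affine_crop_bilinear(image, out_to_in, out_h, out_w):
    """Sample an (H, W, C) image at out_to_in @ (x, y, 1) for every output pixel.

    out_to_in: 3x3 matrix mapping output (x, y) pixel coordinates to input coordinates.
    Samples outside the image are clamped to the border.
    """
    h, w = image.shape[:2]
    ys, xs = np.meshgrid(np.arange(out_h), np.arange(out_w), indexing="ij")
    coords = np.stack([xs, ys, np.ones_like(xs)], axis=-1).astype(np.float64)  # (out_h, out_w, 3)
    src = coords @ out_to_in.T  # input-space (x, y, w) for every output pixel
    sx = np.clip(src[..., 0], 0, w - 1)
    sy = np.clip(src[..., 1], 0, h - 1)

    x0 = np.floor(sx).astype(int)
    y0 = np.floor(sy).astype(int)
    x1 = np.minimum(x0 + 1, w - 1)
    y1 = np.minimum(y0 + 1, h - 1)
    fx = (sx - x0)[..., None]  # horizontal interpolation weight
    fy = (sy - y0)[..., None]  # vertical interpolation weight

    top = image[y0, x0] * (1 - fx) + image[y0, x1] * fx
    bottom = image[y1, x0] * (1 - fx) + image[y1, x1] * fx
    return top * (1 - fy) + bottom * fy


img = np.arange(16, dtype=np.float64).reshape(4, 4, 1)
half_pixel_shift = np.array([[1.0, 0.0, 0.5], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
print(affine_crop_bilinear(img, half_pixel_shift, 4, 4)[0, :, 0])
```

With a half-pixel shift, each output pixel in the first row averages two horizontal neighbours, except the last one, which is clamped to the border.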

landmarks_to_transform_matrix

This one is interesting: it first estimates an axis-aligned (AA) bounding box of the landmarks, then rotates it, applies some scaling, and computes the corresponding transformation matrix from it. I have a couple hundred lines of code that do all of this, but I don't really see the benefit of posting them here in their rudimentary state.
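The steps described above can be sketched in NumPy. Everything here is an assumption for illustration, not MediaPipe's exact algorithm: the rotation is taken from the line between two hypothetical reference landmarks (left_idx, right_idx), the box is measured in the de-rotated frame, made square, and enlarged by a hypothetical scale factor, and the result is returned as a 4x4 matrix (translation in column 3) that maps crop pixel coordinates to input pixel coordinates.

```python
import numpy as np


def landmarks_to_transform_matrix(landmarks_xy, left_idx, right_idx, out_size, scale=1.5):
    """Hypothetical sketch; left_idx, right_idx and scale are illustrative parameters.

    landmarks_xy: (N, 2) landmark coordinates in input-image pixels.
    Returns a 4x4 matrix mapping crop pixel (u, v, 0, 1) to input pixel coordinates.
    """
    delta = landmarks_xy[right_idx] - landmarks_xy[left_idx]
    theta = np.arctan2(delta[1], delta[0])  # roll angle of the reference line
    cos, sin = np.cos(theta), np.sin(theta)
    rot = np.array([[cos, -sin], [sin, cos]])  # rotates the crop frame into the image frame

    # Axis-aligned box in the de-rotated frame (each row equals rot.T @ point).
    aligned = landmarks_xy @ rot
    lo, hi = aligned.min(axis=0), aligned.max(axis=0)
    center = rot @ ((lo + hi) / 2)  # box center back in image coordinates
    side = (hi - lo).max() * scale  # square box, enlarged by `scale`

    # crop (u, v) -> image: center + rot @ ((u, v) / out_size - 0.5) * side
    linear = rot * (side / out_size)
    translation = center - linear @ np.array([out_size / 2, out_size / 2])

    matrix = np.eye(4)
    matrix[:2, :2] = linear
    matrix[:2, 3] = translation
    return matrix


square = np.array([[10.0, 10.0], [30.0, 10.0], [30.0, 30.0], [10.0, 30.0]])
m = landmarks_to_transform_matrix(square, left_idx=0, right_idx=1, out_size=64)
print(m @ np.array([0.0, 0.0, 0.0, 1.0]))  # top-left corner of the enlarged crop
```

For the axis-aligned square above, the crop is 30 px wide (20 px times 1.5) and centered on the square, so the crop origin lands at (5, 5) in the input image.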

KenjiAsaba commented 2 years ago

Hi @mayerjTNG, I've been working on the same topic for a while, and just yesterday I managed to compile tflite with the custom operators.

My TensorFlow source code can be found in this branch: https://github.com/KenjiAsaba/tensorflow/tree/mediapipe_20220320_customOp
My compiled tflite can be found here: https://github.com/KenjiAsaba/tensorflow/releases/tag/v2.8.0-withMediaPipeCustomOp

PINTO0309 commented 2 years ago

@mayerjTNG @KenjiAsaba Thanks so much for all your efforts, you guys are very much appreciated. I will first try to compile a custom OP for TensorFlow Lite by merging it into a .whl.

https://github.com/PINTO0309/TensorflowLite-bin#2-tensorflow-v230-version-or-later
https://github.com/PINTO0309/Tensorflow-bin#build-parameter

After successfully building the .whl, I will read the .tflite of the MediaPipe containing the custom OP with tflite2tensorflow and try to convert it to a semi-standard OP. :smile:

PINTO0309 commented 2 years ago

@mayerjTNG First, I tried to define layers based on the logic you suggested so that I could test each component. I tried to convert transform_landmarks, which sits at the end of the model structure, to saved_model and got a shape mismatch error. I suspect that filling every tensor with 1 via np.ones may not be a good idea; can you provide a sample of the specific values you expect?

import numpy as np
import tensorflow as tf

dummy_input1 = np.ones([1,80,2], dtype=np.float32)
dummy_input2 = np.ones([1,4,4], dtype=np.float32)

# Create a model
i1 = tf.keras.layers.Input(
    shape=[
        dummy_input1.shape[1],
        dummy_input1.shape[2],
    ],
    batch_size=dummy_input1.shape[0],
    dtype=tf.float32,
)
i2 = tf.keras.layers.Input(
    shape=[
        dummy_input2.shape[1],
        dummy_input2.shape[2],
    ],
    batch_size=dummy_input2.shape[0],
    dtype=tf.float32,
)

def transform_landmarks_2d(landmarks, transformation):
    landmarks_xy, landmarks_residual = landmarks[..., :2], landmarks[..., 2:]
    landmarks_xyw = tf.pad(landmarks_xy, [[0, 0], [0, 0], [0, 1]], constant_values=1)
    number_of_landmarks = tf.shape(landmarks_xyw)[-2]
    broadcasted_matrix = tf.repeat(
        tf.expand_dims(transformation, axis=1),
        number_of_landmarks,
        axis=1
    )
    transformed_landmarks_xyw = tf.reshape(
        tf.matmul(
            broadcasted_matrix,
            tf.expand_dims(landmarks_xyw, axis=-1)
        ),
        (-1, number_of_landmarks, 3)
    )
    transformed_landmarks_xy = transformed_landmarks_xyw[..., :2]
    return tf.concat([transformed_landmarks_xy, landmarks_residual], axis=-1)

o = transform_landmarks_2d(i1,i2)

model = tf.keras.models.Model(inputs=[i1,i2], outputs=[o])
model.summary()
output_path = 'saved_model'
tf.saved_model.save(model, output_path)
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS
]
tflite_model = converter.convert()
open(f"{output_path}/test.tflite", "wb").write(tflite_model)

To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
Traceback (most recent call last):
  File "test.py", line 51, in <module>
    o = transform_landmarks_2d(i1,i2)
  File "test.py", line 42, in transform_landmarks_2d
    tf.matmul(
  File "/usr/local/lib/python3.8/dist-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/usr/local/lib/python3.8/dist-packages/keras/layers/core/tf_op_layer.py", line 107, in handle
    return TFOpLambda(op)(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 67, in error_handler
    raise e.with_traceback(filtered_tb) from None
ValueError: Exception encountered when calling layer "tf.linalg.matmul" (type TFOpLambda).

Dimensions must be equal, but are 4 and 3 for '{{node tf.linalg.matmul/MatMul}} = BatchMatMulV2[T=DT_FLOAT, adj_x=false, adj_y=false](Placeholder, Placeholder_1)' with input shapes: [1,80,4,4], [1,80,3,1].

Call arguments received:
  • a=tf.Tensor(shape=(1, 80, 4, 4), dtype=float32)
  • b=tf.Tensor(shape=(1, 80, 3, 1), dtype=float32)
  • transpose_a=False
  • transpose_b=False
  • adjoint_a=False
  • adjoint_b=False
  • a_is_sparse=False
  • b_is_sparse=False
  • output_type=None
  • name=None
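On the np.ones question: an all-ones 4x4 matrix makes every column identical, so a wrong row or column slice still produces the same numbers and only shape errors surface. A small sketch of less degenerate dummy inputs, keeping the [1,80,2] / [1,4,4] shapes used above (make_dummy_inputs is a hypothetical helper; the angle, scale, and translation values are arbitrary, and putting the translation in column 3 assumes the usual homogeneous layout):

```python
import numpy as np


def make_dummy_inputs(num_landmarks=80, angle_deg=20.0, scale=1.3, tx=12.0, ty=-7.0, seed=0):
    """Return ([1, N, 2] landmarks, [1, 4, 4] matrix) whose matrix entries differ from each other."""
    rng = np.random.default_rng(seed)
    landmarks = rng.uniform(0.0, 192.0, size=(1, num_landmarks, 2)).astype(np.float32)

    a = np.deg2rad(angle_deg)
    matrix = np.eye(4, dtype=np.float32)
    matrix[0, :2] = scale * np.cos(a), -scale * np.sin(a)  # rotation + scale, row 0
    matrix[1, :2] = scale * np.sin(a), scale * np.cos(a)   # rotation + scale, row 1
    matrix[:2, 3] = tx, ty  # translation in column 3 (assumed layout)
    return landmarks, matrix[None]


dummy_input1, dummy_input2 = make_dummy_inputs()
# Unlike np.ones, columns 2 and 3 now differ, so slicing mistakes change the output.
print(dummy_input2[0, :2, 2], dummy_input2[0, :2, 3])
```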

mayerjTNG commented 2 years ago

Hi @PINTO0309, the issue here seems to be that my implementation expects a 3x3 (2D) transformation, while the MediaPipe implementation uses 4x4 (3D) transformation matrices (most likely for compatibility). The original implementation then discards the last two rows and does the matrix multiplication via individual dot products. A workaround is to slice away the third dimension and ignore shearing transformations, like this:

def transform_landmarks_2d(landmarks, transformation_3d):

    transformation_2d = transformation_3d[...,:2,:3]

    landmarks_xy, landmarks_residual = landmarks[..., :2], landmarks[..., 2:]
    landmarks_xyw = tf.pad(landmarks_xy, [[0, 0], [0, 0], [0, 1]], constant_values=1)

    number_of_landmarks = tf.shape(landmarks_xy)[-2]
    broadcasted_matrix = tf.repeat(
        tf.expand_dims(transformation_2d, axis=1),
        number_of_landmarks, axis=1
    )
    transformed_landmarks_xyw = tf.reshape(
        tf.matmul(
            broadcasted_matrix,
            tf.expand_dims(landmarks_xyw, axis=-1)
        ),
        (-1, number_of_landmarks, 2)
    )
    transformed_landmarks_xy = transformed_landmarks_xyw[..., :2]
    return tf.concat([transformed_landmarks_xy, landmarks_residual], axis=-1)

Using the function above gives me the following output:

$ python test.py 
2022-03-25 14:42:05.424769: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_1 (InputLayer)           [(1, 80, 2)]         0           []                               

 input_2 (InputLayer)           [(1, 4, 4)]          0           []                               

 tf.__operators__.getitem_1 (Sl  (1, 80, 2)          0           ['input_1[0][0]']                
 icingOpLambda)                                                                                   

 tf.__operators__.getitem (Slic  (1, 2, 3)           0           ['input_2[0][0]']                
 ingOpLambda)                                                                                     

 tf.compat.v1.shape (TFOpLambda  (3,)                0           ['tf.__operators__.getitem_1[0][0
 )                                                               ]']                              

 tf.expand_dims (TFOpLambda)    (1, 1, 2, 3)         0           ['tf.__operators__.getitem[0][0]'
                                                                 ]                                

 tf.__operators__.getitem_3 (Sl  ()                  0           ['tf.compat.v1.shape[0][0]']     
 icingOpLambda)                                                                                   

 tf.compat.v1.pad (TFOpLambda)  (1, 80, 3)           0           ['tf.__operators__.getitem_1[0][0
                                                                 ]']                              

 tf.repeat (TFOpLambda)         (1, 80, 2, 3)        0           ['tf.expand_dims[0][0]',         
                                                                  'tf.__operators__.getitem_3[0][0
                                                                 ]']                              

 tf.expand_dims_1 (TFOpLambda)  (1, 80, 3, 1)        0           ['tf.compat.v1.pad[0][0]']       

 tf.linalg.matmul (TFOpLambda)  (1, 80, 2, 1)        0           ['tf.repeat[0][0]',              
                                                                  'tf.expand_dims_1[0][0]']       

 tf.reshape (TFOpLambda)        (1, 80, 2)           0           ['tf.linalg.matmul[0][0]',       
                                                                  'tf.__operators__.getitem_3[0][0
                                                                 ]']                              

 tf.__operators__.getitem_4 (Sl  (1, 80, 2)          0           ['tf.reshape[0][0]']             
 icingOpLambda)                                                                                   

 tf.__operators__.getitem_2 (Sl  (1, 80, 0)          0           ['input_1[0][0]']                
 icingOpLambda)                                                                                   

 tf.concat (TFOpLambda)         (1, 80, 2)           0           ['tf.__operators__.getitem_4[0][0
                                                                 ]',                              
                                                                  'tf.__operators__.getitem_2[0][0
                                                                 ]']                              

==================================================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
__________________________________________________________________________________________________

Which looks about right. This solution still has a lot of room for improvement, though.

As mentioned previously, my suggestions were pretty much taken from our current code base, with no attempt yet to adjust them to the exact input specs of the MediaPipe ops. I also honestly didn't expect you guys to come up with a solution this quickly, so again, you are the best. :) I'll happily whip up clean implementations of the three missing layers, but unfortunately I'm away from my proper workstation for the weekend and won't be able to work on this properly until Monday. :)
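One thing that may be worth double-checking against transform_landmarks.cc: if the 4x4 matrix follows the usual homogeneous layout with the translation in column 3, and each landmark is treated as (x, y, 0, 1) (both are assumptions, not confirmed in this thread), then the 2D part is columns 0, 1 and 3 rather than the [..., :2, :3] slice used above. An all-ones test matrix cannot tell the two apart. A NumPy reference under those assumptions, for comparing against the converted model (transform_landmarks_ref is a hypothetical name):

```python
import numpy as np


def transform_landmarks_ref(landmarks, matrix_4x4):
    """landmarks: (B, N, C>=2); matrix_4x4: (B, 4, 4). Only x and y are transformed.

    Assumes each landmark is treated as (x, y, 0, 1): column 2 of the matrix is
    multiplied by zero and column 3 carries the translation.
    """
    xy, rest = landmarks[..., :2], landmarks[..., 2:]
    affine = matrix_4x4[:, :2][..., [0, 1, 3]]  # (B, 2, 3): rows 0-1, columns x, y, translation
    xyw = np.concatenate([xy, np.ones_like(xy[..., :1])], axis=-1)  # (B, N, 3)
    new_xy = xyw @ np.swapaxes(affine, -1, -2)  # (B, N, 2)
    return np.concatenate([new_xy, rest], axis=-1)


translate = np.eye(4)[None]  # (1, 4, 4) pure translation
translate[0, :2, 3] = [10.0, -5.0]
points = np.array([[[1.0, 2.0, 0.7], [3.0, 4.0, -0.2]]])  # (1, 2, 3): x, y, z
print(transform_landmarks_ref(points, translate))
```

Here x and y are shifted by (10, -5) and z passes through unchanged; with the [..., :2, :3] slice the translation would be read from column 2 (zeros in this matrix), leaving x and y unshifted.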

PINTO0309 commented 2 years ago

@mayerjTNG Thanks! It's wonderful; it certainly worked! test.tflite.zip

I believe the switch between 2D and 3D processing can be made by checking the size of each dimension. Rather than forcing a single method to handle everything, it is easier to simply create two or three functions and branch according to the shape of the input tensor. That said, I do like clean code.

I believe the only end goal we must not forget is to generate a clean TensorFlow model. :smile_cat: I believe a runtime build using @KenjiAsaba's custom OP combined with your workaround could work.

After all this work, all I have to do is make sure that the values output by KenjiAsaba's runtime match the values output by the ops you suggested. :+1:

mayerjTNG commented 2 years ago

@PINTO0309 Yeah, I agree. I don't think we'd profit from over-engineering here. I'll write some minimal, clean functions that replicate the exact functionality of the tflite ops. More on that on Monday, though. The plan sounds good; I think we're onto something here. :) Take care!

KenjiAsaba commented 2 years ago

Hi @PINTO0309, @mayerjTNG, I am also trying to implement the custom ops. It is still a work in progress, but here is my code for your information. So far, I have successfully exported saved_model.pb and ONNX, but have not tested them yet.

PINTO0309 commented 2 years ago

I also successfully loaded face_landmark_with_attention.tflite.

The process of applying the MediaPipe patch is a bit laborious. :sweat_drops:

PINTO0309 commented 2 years ago

I have only updated the TensorFlow Lite wheel to the latest version and committed it. I have not yet updated the tflite2tensorflow script.

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
from tflite_runtime.interpreter import Interpreter
import numpy as np
from pprint import pprint
interpreter = Interpreter('face_landmark_with_attention.tflite')
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
dummy_input = np.ones([1,192,192,3], dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], dummy_input)
interpreter.invoke()
ret1 = interpreter.get_tensor(output_details[0]['index'])
ret2 = interpreter.get_tensor(output_details[1]['index'])
ret3 = interpreter.get_tensor(output_details[2]['index'])
ret4 = interpreter.get_tensor(output_details[3]['index'])
ret5 = interpreter.get_tensor(output_details[4]['index'])
ret6 = interpreter.get_tensor(output_details[5]['index'])
ret7 = interpreter.get_tensor(output_details[6]['index'])
pprint(ret1.shape)
pprint(ret2.shape)
pprint(ret3.shape)
pprint(ret4.shape)
pprint(ret5.shape)
pprint(ret6.shape)
pprint(ret7.shape)
pprint(ret1)

"""
(1, 1, 1, 1404)
(1, 1, 1, 160)
(1, 1, 1, 142)
(1, 1, 1, 142)
(1, 1, 1, 10)
(1, 1, 1, 10)
(1, 1, 1, 1)
array(
    [
        [
            [
                [
                    92.86058  , 122.887276 , -12.578423 , ..., 138.47002  , 71.958855 ,   2.5877006
                ]
            ]
        ]
    ],
    dtype=float32)
"""

KenjiAsaba commented 2 years ago

I completed the implementation of the custom ops. Here is a demo video. It works, but the performance is not very good... Inference took 50 ms with ONNX Runtime using CUDA on an RTX 3070.

PINTO0309 commented 2 years ago

I read the MediaPipe .tflite with the TFLite runtime and incorporated KenjiAsaba's logic into tflite2tensorflow to generate the reverse-transformed .tflite (float32).

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
import tensorflow as tf
import numpy as np
from pprint import pprint

# original
interpreter = tf.lite.Interpreter('face_landmark_with_attention.tflite')
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
dummy_input = np.ones([1,192,192,3], dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], dummy_input)
interpreter.invoke()
ret1 = interpreter.get_tensor(output_details[0]['index'])
ret2 = interpreter.get_tensor(output_details[1]['index'])
ret3 = interpreter.get_tensor(output_details[2]['index'])
ret4 = interpreter.get_tensor(output_details[3]['index'])
ret5 = interpreter.get_tensor(output_details[4]['index'])
ret6 = interpreter.get_tensor(output_details[5]['index'])
ret7 = interpreter.get_tensor(output_details[6]['index'])
print('@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ Original')
pprint(ret1.shape)
pprint(ret2.shape)
pprint(ret3.shape)
pprint(ret4.shape)
pprint(ret5.shape)
pprint(ret6.shape)
pprint(ret7.shape)
pprint(ret1)

# After reverse transformation
interpreter = tf.lite.Interpreter('model_float32.tflite')
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
dummy_input = np.ones([1,192,192,3], dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], dummy_input)
interpreter.invoke()
ret1 = interpreter.get_tensor(output_details[0]['index'])
ret2 = interpreter.get_tensor(output_details[1]['index'])
ret3 = interpreter.get_tensor(output_details[2]['index'])
ret4 = interpreter.get_tensor(output_details[3]['index'])
ret5 = interpreter.get_tensor(output_details[4]['index'])
ret6 = interpreter.get_tensor(output_details[5]['index'])
ret7 = interpreter.get_tensor(output_details[6]['index'])
print('@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ After reverse transformation')
pprint(ret1.shape)
pprint(ret2.shape)
pprint(ret3.shape)
pprint(ret4.shape)
pprint(ret5.shape)
pprint(ret6.shape)
pprint(ret7.shape)
pprint(ret4)
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ Original
(1, 1, 1, 1404)
(1, 1, 1, 160)
(1, 1, 1, 142)
(1, 1, 1, 142)
(1, 1, 1, 10)
(1, 1, 1, 10)
(1, 1, 1, 1)
array([[[[ 92.86058  , 122.887276 , -12.578423 , ..., 138.47002  ,
           71.958855 ,   2.5877006]]]], dtype=float32)

@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ After reverse transformation
(1, 1, 1, 142)
(1, 1, 1, 142)
(1, 1, 1, 1)
(1, 1, 1, 1404)
(1, 1, 1, 10)
(1, 1, 1, 160)
(1, 1, 1, 10)
array([[[[ 92.86058  , 122.887276 , -12.578423 , ..., 138.47002  ,
           71.958855 ,   2.5877006]]]], dtype=float32)
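As the printout shows, the converted model returns the same seven tensors in a different order, and two pairs share a shape ((1, 1, 1, 142) and (1, 1, 1, 10)), so shape alone cannot pair them. A hedged sketch that pairs each original output with the same-shaped converted output that is numerically closest (match_outputs is a hypothetical helper; matching by tensor name from get_output_details() is preferable whenever the names survive conversion, and the greedy choice assumes the converted model is reasonably accurate):

```python
import numpy as np


def match_outputs(original, converted):
    """Return, for each original output, the index of the matching converted output.

    Candidates must have the same shape; among those, the one with the smallest
    max absolute difference wins. Each converted output is used at most once.
    """
    used = set()
    mapping = []
    for ref in original:
        candidates = [
            (float(np.max(np.abs(ref - out))), j)
            for j, out in enumerate(converted)
            if j not in used and out.shape == ref.shape
        ]
        if not candidates:
            raise ValueError(f"no unused converted output with shape {ref.shape}")
        _, best = min(candidates)
        used.add(best)
        mapping.append(best)
    return mapping


# Usage with the two interpreters above (hypothetical variable names):
# original_out = [interpreter_a.get_tensor(d['index']) for d in interpreter_a.get_output_details()]
# converted_out = [interpreter_b.get_tensor(d['index']) for d in interpreter_b.get_output_details()]
# print(match_outputs(original_out, converted_out))
```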

PINTO0309 commented 2 years ago

Thank you. @KenjiAsaba I will also convert to ONNX and measure the performance. By the way, the sample code you shared has an "Akiya Research Institute, Inc." copyright notice. If I am allowed to incorporate your valuable source code into my tool, how should I include your notice? Or should I refrain from including it in the tool?

I have made a few adjustments to your source code, for example:

  1. tf.add -> tf.math.add

  2. tf.floor -> tf.math.floor

  3. All of mediapipeCustomOp.py has been imported into the body of tflite2tensorflow.py.

If, by any chance, the license does not permit me to modify and quote the program, please let me know.

PINTO0309 commented 2 years ago

@KenjiAsaba I think this model is fast enough.

array([[[[ 92.86058  , 122.887276 , -12.578423 , ..., 138.47002  ,
           71.958855 ,   2.5877006]]]], dtype=float32)
elapsed_time avg: 3.879547119140625 ms
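import time

import numpy as np
import onnxruntime

# NHWC dummy input, as in the TFLite check above; transposed to NCHW for ONNX below
dummy_input = np.ones([1,192,192,3], dtype=np.float32)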
model_path = 'saved_model/model_float32.onnx'
model_file_name = model_path.split(".")[0]
session_option = onnxruntime.SessionOptions()
session_option.optimized_model_filepath = f"{model_file_name}_cudaopt.onnx"
session_option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
session = onnxruntime.InferenceSession(
    model_path,
    session_option,
    providers=['CUDAExecutionProvider']
)
input_name = session.get_inputs()[0].name
output_names = [o.name for o in session.get_outputs()]
input_shape = session.get_inputs()[0].shape
# Warmup
output = session.run(
    output_names,
    {input_name: dummy_input.transpose((0,3,1,2))}
)
# Inference
start = time.time()
ITER=10
for i in range(ITER):
    output = session.run(
        output_names,
        {input_name: dummy_input.transpose((0,3,1,2))}
    )
print(f'elapsed_time avg: {(time.time()-start)/ITER*1000} ms')
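The loop above times 10 runs with time.time(). A slightly more robust pattern, shown here as a hypothetical helper, uses time.perf_counter() (a monotonic, high-resolution clock), a few untimed warmup calls to absorb one-time initialization costs, and reports the median, which is less sensitive to one-off stalls:

```python
import statistics
import time


def benchmark(fn, iterations=100, warmup=5):
    """Call fn() `warmup` times untimed, then return the median latency in milliseconds."""
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(iterations):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1000.0)
    return statistics.median(samples)


# Usage with the session above (names taken from the script):
# feed = {input_name: dummy_input.transpose((0, 3, 1, 2))}
# print(f"median: {benchmark(lambda: session.run(output_names, feed)):.3f} ms")
print(f"median: {benchmark(lambda: sum(range(10_000)), iterations=20):.3f} ms")
```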

PINTO0309 commented 2 years ago

BTW, ONNX 1.11.0 + TensorRT 8.2.3 (RTX 3070)

array([[[[ 92.86058  , 122.887276 , -12.578423 , ..., 138.47002  ,
           71.958855 ,   2.5877006]]]], dtype=float32)
elapsed_time avg: 1.6012907028198242 ms
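import time

import numpy as np
import onnxruntime

# NHWC dummy input, as in the TFLite check above; transposed to NCHW for ONNX below
dummy_input = np.ones([1,192,192,3], dtype=np.float32)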
model_path = 'saved_model/model_float32.onnx'
model_file_name = model_path.split(".")[0]
session_option = onnxruntime.SessionOptions()
# session_option.optimized_model_filepath = f"{model_file_name}_cudaopt.onnx"
# session_option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
session_option.log_severity_level = 3
session = onnxruntime.InferenceSession(
    model_path,
    session_option,
    providers = [
        (
            'TensorrtExecutionProvider', {
                'trt_engine_cache_enable': True,
                'trt_engine_cache_path': model_path.split("/")[0],
                'trt_fp16_enable': True,
            }
        ),
    ]
)
input_name = session.get_inputs()[0].name
output_names = [o.name for o in session.get_outputs()]
input_shape = session.get_inputs()[0].shape
# Warmup
output = session.run(
    output_names,
    {input_name: dummy_input.transpose((0,3,1,2))}
)
# Inference
start = time.time()
ITER=10
for i in range(ITER):
    output = session.run(
        output_names,
        {input_name: dummy_input.transpose((0,3,1,2))}
    )
print(f'elapsed_time avg: {(time.time()-start)/ITER*1000} ms')
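The script above lists only the TensorRT provider, which assumes the installed onnxruntime build ships it. A sketch of building the providers list from what the build actually reports, keeping the TensorRT options and adding CUDA and CPU as fallbacks (pick_providers is a hypothetical helper; onnxruntime.get_available_providers() is the call that reports the build's providers):

```python
def pick_providers(preferred, available):
    """Keep entries of `preferred` (names or (name, options) tuples) whose name is available."""
    chosen = []
    for entry in preferred:
        name = entry[0] if isinstance(entry, tuple) else entry
        if name in available:
            chosen.append(entry)
    return chosen


preferred = [
    ("TensorrtExecutionProvider", {"trt_fp16_enable": True}),
    "CUDAExecutionProvider",
    "CPUExecutionProvider",
]
# With onnxruntime installed:
# providers = pick_providers(preferred, onnxruntime.get_available_providers())
# session = onnxruntime.InferenceSession(model_path, session_option, providers=providers)
print(pick_providers(preferred, ["CUDAExecutionProvider", "CPUExecutionProvider"]))
```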

PINTO0309 commented 2 years ago

BTW, ONNX 1.11.0 + OpenVINO Execution Provider + CPU Intel(R) Core(TM) i9-10900K CPU @ 3.70GHz

array([[[[ 92.86058  , 122.887276 , -12.578423 , ..., 138.47002  ,
           71.958855 ,   2.5877006]]]], dtype=float32)
elapsed_time avg: 6.2346696853637695 ms
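import time

import numpy as np
import onnxruntime

# NHWC dummy input, as in the TFLite check above; transposed to NCHW for ONNX below
dummy_input = np.ones([1,192,192,3], dtype=np.float32)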
model_path = 'saved_model/model_float32.onnx'
model_file_name = model_path.split(".")[0]
session_option = onnxruntime.SessionOptions()
# session_option.optimized_model_filepath = f"{model_file_name}_cudaopt.onnx"
# session_option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
session_option.log_severity_level = 3
session_option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
onnxruntime.capi._pybind_state.set_openvino_device('CPU_FP32')
session = onnxruntime.InferenceSession(
    model_path,
    session_option,
    providers = [
        # (
        #     'TensorrtExecutionProvider', {
        #         'trt_engine_cache_enable': True,
        #         'trt_engine_cache_path': model_path.split("/")[0],
        #         'trt_fp16_enable': True,
        #     }
        # ),
        "OpenVINOExecutionProvider",
        # "CPUExecutionProvider"
    ]
)
input_name = session.get_inputs()[0].name
output_names = [o.name for o in session.get_outputs()]
input_shape = session.get_inputs()[0].shape
# Warmup
output = session.run(
    output_names,
    {input_name: dummy_input.transpose((0,3,1,2))}
)
# Inference
start = time.time()
ITER=10
for i in range(ITER):
    output = session.run(
        output_names,
        {input_name: dummy_input.transpose((0,3,1,2))}
    )
print(f'elapsed_time avg: {(time.time()-start)/ITER*1000} ms')

PINTO0309 commented 2 years ago

@vladmandic @KenjiAsaba @mayerjTNG The tflite2tensorflow logic fix is still pending, but I have committed the converted and optimized model here. If you run into any problems, please contact me and I will fix them.

KenjiAsaba commented 2 years ago

Thank you for the performance test. It seems the problem is just my environment 😃

I added a license notice to my code. It's MIT. So, please merge my code to your tool. It would be a great honor for me!

Thank you also for adding optimised models to the Zoo. I would be grateful if you could add the above licence notice here as well.

PINTO0309 commented 2 years ago

@KenjiAsaba @mayerjTNG Thank you for your cooperation. All have been merged into the main branch. :+1:

https://github.com/PINTO0309/tflite2tensorflow

mayerjTNG commented 2 years ago

Just tested the model; it runs like a charm! Love it! Great work, guys! :) Again, thank you very much for your help. I hope I'll be of more use and can repay the favour next time! Take care!

PINTO0309 commented 2 years ago

https://github.com/iwatake2222/play_with_tflite/tree/master/pj_tflite_face_landmark_with_attention

wwdok commented 2 years ago

Hello, @KenjiAsaba, I am going to modify TransformTensorBilinear and Landmarks2TransformMatrix a little, because I want to deploy face_landmark_with_attention.tflite to the ncnn framework, but ncnn does not support the gather and gather_nd operators (the operators that ncnn supports), so I have decided to replace them. In your test_mediapipeCustomOp.py there is an actual.bmp. Could you please tell me what actual.bmp is (its content and size)? I would like to use it for my own tests later on.

Hello, @PINTO0309, I also tried to use tflite2tensorflow to convert the original face_landmark_with_attention.tflite to a TF pb model, but the terminal reports RuntimeError: Encountered unresolved custom op: Landmarks2TransformMatrix. Node number 192 (Landmarks2TransformMatrix) failed to prepare. Does this mean I need to install a custom build of TensorFlow instead of the official one, like this one or this one? BTW, if you have any tips, please tell me; I am new to model conversion and TF. Thanks in advance!

PINTO0309 commented 2 years ago

First, read the README.

https://github.com/PINTO0309/tflite2tensorflow#3-1-environment-construction-pattern-1-execution-by-docker-strongly-recommended

https://github.com/PINTO0309/tflite2tensorflow/releases/tag/v1.20.8

docker run -it --rm \
-v `pwd`:/home/user/workdir \
ghcr.io/pinto0309/tflite2tensorflow:latest

KenjiAsaba commented 2 years ago

Hi, @wwdok "actual.bmp" is a 192x192 pixel color image I used as input to the original model to generate the test landmark data at line 49. For your test, you can use any 192x192 pixel image, and compare the result before and after your modification.

vladmandic commented 2 years ago

fyi, i've just added facemesh attention to https://vladmandic.github.io/human,
works nicely, but it is ~25% slower than the facemesh + iris models combined

i've also added keypoint mapping of new attention keypoints back to original mesh keypoints
plus remapping of z-coord since augmented data is 2d only
(in https://github.com/vladmandic/human/blob/main/src/face/attention.ts)

anyhow, results from https://github.com/vladmandic/human-motion
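One way to read the z remapping described above, as a sketch under assumptions: each refined 2D point replaces the x/y of a corresponding mesh landmark, and that landmark's original z is kept. The index mapping (refined point to mesh index) is hypothetical here; the mapping actually used lives in attention.ts.

```python
import numpy as np


def merge_refined_points(mesh_xyz, refined_xy, mesh_indices):
    """Overwrite x/y of mesh landmarks with refined 2D points; keep the mesh's z.

    mesh_xyz: (468, 3) mesh landmarks; refined_xy: (K, 2) refined points;
    mesh_indices: (K,) hypothetical index of the mesh landmark each refined point maps to.
    """
    merged = mesh_xyz.copy()  # leave the input mesh untouched
    merged[mesh_indices, :2] = refined_xy
    return merged


mesh = np.arange(12, dtype=np.float64).reshape(4, 3)  # toy 4-point "mesh"
refined = np.array([[100.0, 200.0], [300.0, 400.0]])
print(merge_refined_points(mesh, refined, np.array([1, 3])))
```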

qhanson commented 2 years ago

The refined landmarks of the eyes, iris, and lips differ slightly from those of the original MediaPipe v0.8.9. For example:

Comparison for idx-33:
0.4246020269255817 0.4250536262989044  # x
0.396430768219194 0.3977607190608978  # y 
0.05556639647457429 0.05556516349315643 # z

If the reverse conversion reproduced the original exactly, the absolute error should be less than 1e-4.
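For the comparison itself, a small sketch that reports the worst absolute error per axis and checks it against the 1e-4 target. The numbers are the idx-33 values quoted above; which column is original and which is converted is not stated, but absolute error is symmetric. max_abs_error is a hypothetical helper that also works on full (N, 3) landmark arrays.

```python
import numpy as np


def max_abs_error(a, b):
    """Largest absolute difference per coordinate axis (last dimension)."""
    diff = np.abs(np.asarray(a, dtype=np.float64) - np.asarray(b, dtype=np.float64))
    return diff.reshape(-1, diff.shape[-1]).max(axis=0)


# idx-33 (x, y, z) from the two runs quoted above
first = [0.4246020269255817, 0.396430768219194, 0.05556639647457429]
second = [0.4250536262989044, 0.3977607190608978, 0.05556516349315643]
err = max_abs_error(first, second)
print(err)                       # x ≈ 4.5e-4, y ≈ 1.3e-3, z ≈ 1.2e-6
print(bool((err < 1e-4).all()))  # False: x and y exceed the 1e-4 target
```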

PINTO0309 commented 2 years ago

Pull requests are welcome.