VeriSilicon / tflite-vx-delegate

TensorFlow Lite external delegate based on TIM-VX
MIT License

Convolution2DTransposeBias support #165

Open vsi-minggang opened 1 year ago

vsi-minggang commented 1 year ago

I am filing this Jira ticket separately and closing the originating one. The MediaPipe framework is very popular and has many useful models, and a lot of them use the Convolution2DTransposeBias custom operation. Please evaluate whether it would be possible to add support for this operation. It is effectively a transposed convolution layer plus a bias add, implemented as one operation. You can find more details in this post: https://github.com/google/mediapipe/issues/245
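To make "transposed convolution + add bias in one operation" concrete, here is a minimal C++ reference sketch of what the fused op computes. It is not MediaPipe's, TFLite's, or XNNPACK's implementation. It assumes batch size 1, NHWC activations, weights in the [out_c][kh][kw][in_c] layout TFLite uses for TRANSPOSE_CONV, and no dilation. `TransposeConvBias`, `Shape`, and the `pad_h`/`pad_w` crop parameters are hypothetical names chosen for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// One image in NHWC order (the batch dimension of 1 is omitted).
struct Shape {
  int h, w, c;
};

// Reference (slow, scatter-based) transposed convolution fused with a
// per-channel bias add. pad_h/pad_w crop rows/columns from the top/left of
// the full-size result; pass 0 for VALID padding.
std::vector<float> TransposeConvBias(const std::vector<float>& input, Shape in,
                                     const std::vector<float>& weights, int kh, int kw,
                                     const std::vector<float>& bias, Shape out,
                                     int stride_h, int stride_w, int pad_h, int pad_w) {
  assert(input.size() == static_cast<std::size_t>(in.h) * in.w * in.c);
  assert(weights.size() == static_cast<std::size_t>(out.c) * kh * kw * in.c);
  assert(bias.size() == static_cast<std::size_t>(out.c));

  // "+ bias": every output element starts from its channel's bias value.
  std::vector<float> output(static_cast<std::size_t>(out.h) * out.w * out.c);
  for (std::size_t i = 0; i < output.size(); ++i) output[i] = bias[i % out.c];

  // Transposed convolution: scatter each input value through the kernel into
  // the stride-upsampled output, skipping positions that fall outside it.
  for (int iy = 0; iy < in.h; ++iy)
    for (int ix = 0; ix < in.w; ++ix)
      for (int ic = 0; ic < in.c; ++ic) {
        const float v = input[(iy * in.w + ix) * in.c + ic];
        for (int ky = 0; ky < kh; ++ky)
          for (int kx = 0; kx < kw; ++kx) {
            const int oy = iy * stride_h + ky - pad_h;
            const int ox = ix * stride_w + kx - pad_w;
            if (oy < 0 || oy >= out.h || ox < 0 || ox >= out.w) continue;
            for (int oc = 0; oc < out.c; ++oc)
              output[(oy * out.w + ox) * out.c + oc] +=
                  v * weights[((oc * kh + ky) * kw + kx) * in.c + ic];
          }
      }
  return output;
}
```

For example, a 2x2 single-channel input {1, 2, 3, 4} with a 2x2 kernel of ones, stride 2, no crop, and bias 0.5 gives a 4x4 output ((2 - 1) * 2 + 2 = 4 per side) in which each input value plus 0.5 fills its own 2x2 block.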

The operation has been supported by TFLite for a while, in both the GPU delegate and the XNNPACK delegate. You may even support this operation already and only need code to recognize it and convert it into an Acuity operation.
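Recognizing the op mostly comes down to checking that a node is a custom op with this exact name, plus a sanity check of its inputs. The sketch below uses hypothetical stand-in structs (`OpRegistration`, `OpNode`, `IsConvolution2DTransposeBias`), not the real TFLite types; in TFLite the equivalent information is on `TfLiteRegistration` (`builtin_code`, `custom_name`) and in the node's input list. The three inputs (input, weights, bias) are an assumption of this sketch.

```cpp
#include <cstring>

// Hypothetical stand-ins, NOT the real TFLite types.
enum class OpKind { kBuiltin, kCustom };

struct OpRegistration {
  OpKind kind;
  const char* custom_name;  // only meaningful when kind == OpKind::kCustom
};

struct OpNode {
  int num_inputs;
};

// True when the node is MediaPipe's fused transposed-conv + bias custom op
// with the three inputs (input, weights, bias) this sketch assumes.
bool IsConvolution2DTransposeBias(const OpRegistration& reg, const OpNode& node) {
  if (reg.kind != OpKind::kCustom || reg.custom_name == nullptr) return false;
  if (std::strcmp(reg.custom_name, "Convolution2DTransposeBias") != 0) return false;
  return node.num_inputs == 3;
}
```

Once a node matches, it could be lowered either to a single deconvolution with bias (if the backend's deconvolution accepts a bias) or to a deconvolution followed by a per-channel add; both compute the same result.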

The GPU implementation in tflite is here: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/delegates/gpu/gl/kernels/transpose_conv.cc

The XNNPACK implementation is here: https://github.com/google/XNNPACK/blob/master/src/subgraph/deconvolution-2d.c#L356


The XNNPACK delegate implements the operation, but for some reason it does not register it. You need to add the following code to your application:

```cpp
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"

/* registration function */
TfLiteRegistration* RegisterConvolution2DTransposeBias() {
  // init, free, prepare and invoke are all left null: this registration
  // only makes the custom op name resolvable.
  static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
  return &reg;
}

// ...

/* somewhere down below where the resolver is created */

// Build the interpreter with the InterpreterBuilder.
// Note: all Interpreters should be built with the InterpreterBuilder,
// which allocates memory for the Interpreter and does various set up
// tasks so that the Interpreter can read the provided model.
tflite::ops::builtin::BuiltinOpResolver resolver;
/* must add custom op for it to be resolved */
resolver.AddCustom("Convolution2DTransposeBias", RegisterConvolution2DTransposeBias());

tflite::InterpreterBuilder builder(*model, resolver);
```
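Why a registration with only null function pointers is enough can be illustrated with a toy model. The types below (`ToyRegistration`, `ToyResolver`, `UnresolvedCustomOps`) are hypothetical and only mimic the idea of an op resolver; they are not TFLite's API. The assumption, taken from the comment "must add custom op for it to be resolved" in the snippet above, is that building fails when a custom op name is unknown, while a node that a delegate claims never runs the placeholder kernel.

```cpp
#include <map>
#include <string>
#include <vector>

// Hypothetical toy types that only mimic an op resolver; not TFLite's API.
struct ToyRegistration {
  void (*invoke)() = nullptr;  // kernel entry point; null in the dummy registration
};

class ToyResolver {
 public:
  void AddCustom(const std::string& name, const ToyRegistration* reg) { custom_[name] = reg; }

  // Returns nullptr when no registration exists for `name`.
  const ToyRegistration* FindCustom(const std::string& name) const {
    auto it = custom_.find(name);
    return it == custom_.end() ? nullptr : it->second;
  }

 private:
  std::map<std::string, const ToyRegistration*> custom_;
};

// Dummy registration in the spirit of the issue's snippet: it exists but has no kernel.
const ToyRegistration* RegisterConvolution2DTransposeBias() {
  static ToyRegistration reg;
  return &reg;
}

// Collects the custom op names in a model that the resolver cannot resolve.
// In this toy, building succeeds only when the result is empty.
std::vector<std::string> UnresolvedCustomOps(const std::vector<std::string>& model_ops,
                                             const ToyResolver& resolver) {
  std::vector<std::string> missing;
  for (const std::string& name : model_ops)
    if (resolver.FindCustom(name) == nullptr) missing.push_back(name);
  return missing;
}
```

In the toy, `UnresolvedCustomOps({"Convolution2DTransposeBias"}, resolver)` returns that name until `resolver.AddCustom("Convolution2DTransposeBias", RegisterConvolution2DTransposeBias())` is called, and an empty list afterwards.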