eragonruan / text-detection-ctpn

text detection mainly based on ctpn model in tensorflow, id card detect, connectionist text proposal network
MIT License

Model Quantization #490

Open · leeshien opened this issue 2 years ago


The pretrained model works well: it detects text with high accuracy. However, inference on CPU takes a long time, which is too slow for production deployment. Is there a way to quantize the model?
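
For context, here is a minimal sketch of what int8 quantization does to a single tensor. It assumes the asymmetric per-tensor scheme that the TensorFlow Lite 8-bit quantization spec describes for activations, where each float maps to an int8 `q` via `real = scale * (q - zero_point)` (weights are typically quantized symmetrically per channel instead). The function names are illustrative helpers, not TFLite APIs.

```python
def choose_int8_params(x_min, x_max):
    """Pick (scale, zero_point) so that [x_min, x_max] maps onto int8 [-128, 127]."""
    # The range must contain 0.0 so that zero is represented exactly.
    x_min = min(x_min, 0.0)
    x_max = max(x_max, 0.0)
    qmin, qmax = -128, 127
    scale = (x_max - x_min) / (qmax - qmin)
    if scale == 0.0:
        scale = 1.0  # degenerate all-zero range
    zero_point = int(round(qmin - x_min / scale))
    zero_point = max(qmin, min(qmax, zero_point))
    return scale, zero_point


def quantize(x, scale, zero_point):
    """Map a float to int8, clipping values outside the calibrated range."""
    # Python's round() rounds half to even; real kernels may round differently.
    q = int(round(x / scale)) + zero_point
    return max(-128, min(127, q))


def dequantize(q, scale, zero_point):
    """Recover the approximate float value from an int8 code."""
    return scale * (q - zero_point)


scale, zero_point = choose_int8_params(-1.0, 3.0)
for x in [-1.0, -0.5, 0.0, 1.234, 3.0]:
    q = quantize(x, scale, zero_point)
    print(x, q, dequantize(q, scale, zero_point))
```

Inside the calibrated range the round-trip error is at most `scale / 2`, and values outside it are clipped. This is why full integer quantization needs sample inputs to choose activation ranges.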

What I've tried is converting the checkpoint to the SavedModel format, then loading the SavedModel and quantizing it with the TFLite converter. The code snippet is as follows:

```python
# Load checkpoint and convert to saved_model
import tensorflow as tf

trained_checkpoint_prefix = "checkpoints_mlt/ctpn_50000.ckpt"
export_dir = "exported_model"

graph = tf.Graph()
with tf.compat.v1.Session(graph=graph) as sess:
    # Restore the graph definition and weights from the checkpoint
    loader = tf.compat.v1.train.import_meta_graph(trained_checkpoint_prefix + ".meta")
    loader.restore(sess, trained_checkpoint_prefix)

    # Export to SavedModel while the session (and its restored variables) is still open
    builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(export_dir)
    builder.add_meta_graph_and_variables(sess,
                                         [tf.saved_model.TRAINING, tf.saved_model.SERVING],
                                         strip_default_attrs=True)
    builder.save()
```

As a result, I got a .pb file and a variables folder containing the checkpoint data and index files. Then an error was raised when I tried to perform quantization:

```python
converter = tf.lite.TFLiteConverter.from_saved_model(export_dir)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8  # or tf.uint8
converter.inference_output_type = tf.int8  # or tf.uint8
tflite_quant_model = converter.convert()
```
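
As far as I know, restricting the converter to `TFLITE_BUILTINS_INT8` (full integer quantization) also requires setting `converter.representative_dataset`, a callable that returns a generator of sample inputs used to calibrate activation ranges; the converter snippet in this issue does not set one. The sketch below is a minimal version under stated assumptions: inputs are fed in the order `[input_image, input_im_info]`, `input_image` is float32 NHWC with pixel values in 0-255, and `input_im_info` is a `[1, 3]` array of `[height, width, channels]`. These shapes and the image size are hypothetical and should be checked against how the repo's demo feeds the model. Random data is only a placeholder; real calibration should use resized samples of the target documents.

```python
import numpy as np


def make_representative_dataset(num_samples=8, height=600, width=900, seed=0):
    """Build a generator function for converter.representative_dataset.

    Hypothetical assumptions: inputs are ordered [input_image, input_im_info],
    input_image is float32 NHWC in 0-255, and input_im_info is
    [[height, width, channels]].
    """
    rng = np.random.default_rng(seed)

    def representative_dataset():
        for _ in range(num_samples):
            # Placeholder data: replace with real, resized document images.
            image = rng.uniform(0.0, 255.0, size=(1, height, width, 3)).astype(np.float32)
            im_info = np.array([[height, width, 3]], dtype=np.float32)
            yield [image, im_info]

    return representative_dataset
```

It would be attached with `converter.representative_dataset = make_representative_dataset()` before calling `converter.convert()`.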

This is the error:

```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-4-03205673177f> in <module>
     11 converter.inference_input_type = tf.int8  # or tf.uint8
     12 converter.inference_output_type = tf.int8  # or tf.uint8
---> 13 tflite_quant_model = converter.convert()

~/virtualenvironment/tf2/lib/python3.6/site-packages/tensorflow/lite/python/lite.py in convert(self)
    450     # TODO(b/130297984): Add support for converting multiple function.
    451     if len(self._funcs) != 1:
--> 452       raise ValueError("This converter can only convert a single "
    453                        "ConcreteFunction. Converting multiple functions is "
    454                        "under development.")

ValueError: This converter can only convert a single ConcreteFunction. Converting multiple functions is under development.
```

My understanding is that this error is raised because the model requires two inputs, `input_image` and `input_im_info`. I'd appreciate it if anyone could help.
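
One thing worth ruling out before blaming the two inputs: the export snippet calls `add_meta_graph_and_variables` without a `signature_def_map`, so the SavedModel may not contain any serving signature at all, while a single signature (which becomes one ConcreteFunction) can take several inputs. The sketch below is a hypothetical toy graph, not CTPN. It has two placeholders named after the issue's `input_image` and `input_im_info`, a made-up `score` output, and is exported with one `serving_default` signature and only the `tf.saved_model.SERVING` tag (the tag set `from_saved_model` looks for by default). It assumes TensorFlow 2.x with `tf.compat.v1` available.

```python
import os
import tempfile

import tensorflow as tf

# Fresh (non-existent) directory, since SavedModelBuilder refuses existing ones.
export_dir = os.path.join(tempfile.mkdtemp(), "toy_saved_model")

graph = tf.Graph()
with tf.compat.v1.Session(graph=graph) as sess:
    # Hypothetical stand-ins for the two CTPN inputs named in the issue.
    input_image = tf.compat.v1.placeholder(
        tf.float32, shape=[None, None, None, 3], name="input_image")
    input_im_info = tf.compat.v1.placeholder(
        tf.float32, shape=[None, 3], name="input_im_info")
    # Hypothetical output that depends on both inputs and one variable.
    scale = tf.compat.v1.Variable(1.0, name="scale")
    score = tf.identity(
        tf.reduce_mean(input_image, axis=[1, 2, 3]) * scale
        + tf.reduce_sum(input_im_info, axis=1),
        name="score")
    sess.run(tf.compat.v1.global_variables_initializer())

    # One signature with two inputs -> one ConcreteFunction after loading.
    signature = tf.compat.v1.saved_model.signature_def_utils.predict_signature_def(
        inputs={"input_image": input_image, "input_im_info": input_im_info},
        outputs={"score": score})
    builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(export_dir)
    builder.add_meta_graph_and_variables(
        sess,
        [tf.saved_model.SERVING],
        signature_def_map={
            tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature},
        strip_default_attrs=True)
    builder.save()

# Load back in eager mode and list the exported signatures.
loaded = tf.saved_model.load(export_dir, tags=[tf.saved_model.SERVING])
print(list(loaded.signatures.keys()))
```

For the real model, the same `signature_def_map` could be built inside the original export session by looking up tensors with `graph.get_tensor_by_name("input_image:0")` and `graph.get_tensor_by_name("input_im_info:0")`, assuming the placeholders carry exactly those names. The output tensor names are not given in this issue and would have to be found in the restored graph. Even with a single signature, conversion restricted to `TFLITE_BUILTINS_INT8` may still fail if the graph contains ops without an int8 TFLite kernel.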