tensorflow / tfjs

A WebGL accelerated JavaScript library for training and deploying ML models.
Apache License 2.0

tf-js tflite: use_regular_nms=True for SSD model leads to a freeze #8124

Open Jove125 opened 8 months ago

Jove125 commented 8 months ago


An SSD model converted to TFLite with the flag use_regular_nms=True doesn't work with tfjs-tflite. Calling model.predict causes the web page to freeze without returning an error or a result. The same model works well on Android and on PC/Python.

The same SSD model converted to TFLite with the flag use_regular_nms=False works fine with tfjs-tflite (but that model produces a large number of false positives, so it is not suitable).

Is this an error or is the functionality not implemented? Is there any way to solve the problem?
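For readers unfamiliar with the flag: "regular" multi-class NMS runs suppression separately for every class. The following is a minimal pure-Python sketch of that idea, not the TFLite op's actual implementation; the function names, thresholds, and sample boxes are all assumptions made for illustration.

```python
from collections import defaultdict


def iou(a, b):
    """Intersection over union of two boxes given as (ymin, xmin, ymax, xmax)."""
    inter_h = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_w = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_h * inter_w
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def per_class_nms(detections, iou_threshold=0.5, score_threshold=0.3):
    """Greedy NMS run separately for every class.

    detections: list of (box, class_id, score) tuples.
    Returns the kept detections sorted by descending score.
    """
    # Drop low-scoring detections and group the rest by class.
    by_class = defaultdict(list)
    for box, class_id, score in detections:
        if score >= score_threshold:
            by_class[class_id].append((box, class_id, score))

    kept = []
    for candidates in by_class.values():
        # Visit boxes from highest to lowest score; keep a box only if it
        # does not overlap too much with an already selected box of its class.
        candidates.sort(key=lambda d: d[2], reverse=True)
        selected = []
        for det in candidates:
            if all(iou(det[0], s[0]) <= iou_threshold for s in selected):
                selected.append(det)
        kept.extend(selected)
    return sorted(kept, key=lambda d: d[2], reverse=True)


# Hypothetical sample data, not taken from the attached models.
detections = [
    ((0.10, 0.10, 0.50, 0.50), 0, 0.90),
    ((0.12, 0.11, 0.52, 0.49), 0, 0.80),  # overlaps the first box, same class
    ((0.12, 0.11, 0.52, 0.49), 1, 0.70),  # same location, different class
    ((0.60, 0.60, 0.90, 0.90), 0, 0.20),  # below the score threshold
]
kept = per_class_nms(detections)
```

In this sample, the second class-0 box is suppressed by the first, while the class-1 box at the same location survives because suppression never crosses class boundaries.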

gaikwadrahul8 commented 8 months ago

Hi, @Jove125

Thank you for bringing this issue to our attention. Could you please share your GitHub repo or a code snippet so that we can replicate the same behavior on our end? Thank you for your understanding and patience.

Jove125 commented 8 months ago

Hi, @gaikwadrahul8

See the code below. I can export and attach both .tflite files if you need them: one with use_regular_nms=True and one with use_regular_nms=False.

import tensorflow as tf
import numpy as np

import os
import argparse
import json
import glob
import sys

from keras.models import load_model
from model.ModelBuilder import ModelBuilder
from utils_train.Encoder import AnchorBox
import functools

_DETECTION_POSTPROCESS_FUNC = 'TFLite_Detection_PostProcess'
class SSDModule(tf.Module):
    """Inference Module for TFLite-friendly SSD models."""
    def __init__(self, config, detection_model, max_detections=4, use_regular_nms=False):
        """Initializes the module.

        Args:
          config: The model configuration (a dict loaded from config.json).
          detection_model: The detection model to use for inference.
          max_detections: Max detections desired from the TFLite model.
          use_regular_nms: If True, TFLite model uses the (slower) multi-class NMS.
        """
        # Populate _num_classes and _scale_values before they are used.
        self._process_config(config)
        self._model = detection_model
        self._max_detections = max_detections
        self._use_regular_nms = use_regular_nms
        self._Anchors = AnchorBox(config).get_anchors()

    def _process_config(self, config):
        self._num_classes = config['training_config']['num_classes']
        self._scale_values = {}

    def input_shape(self):
        """Returns shape of TFLite model input."""
        return [1, config["model_config"]["target_height"], config["model_config"]["target_width"], 3]

    def postprocess_implements_signature(self):
        """Returns tf.implements signature for MLIR legalization of TFLite NMS."""
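        # NOTE: the right-hand operands of most '%' expressions were lost from
        # the original post. The attribute names below are assumptions and must
        # be set in _process_config (as posted, it only sets _num_classes).
        implements_signature = [
            'name: "%s"' % _DETECTION_POSTPROCESS_FUNC,
            'attr { key: "max_detections" value { i: %d } }' % self._max_detections,
            'attr { key: "max_classes_per_detection" value { i: %d } }' %
            self._max_classes_per_detection,
            'attr { key: "use_regular_nms" value { b: %s } }' %
            str(self._use_regular_nms).lower(),
            'attr { key: "nms_score_threshold" value { f: %f } }' %
            self._nms_score_threshold,
            'attr { key: "nms_iou_threshold" value { f: %f } }' %
            self._nms_iou_threshold,
            'attr { key: "y_scale" value { f: %f } }' %
            self._scale_values['y_scale'],
            'attr { key: "x_scale" value { f: %f } }' %
            self._scale_values['x_scale'],
            'attr { key: "h_scale" value { f: %f } }' %
            self._scale_values['h_scale'],
            'attr { key: "w_scale" value { f: %f } }' %
            self._scale_values['w_scale'],
            'attr { key: "num_classes" value { i: %d } }' % self._num_classes
        ]
        implements_signature = ' '.join(implements_signature)
        return implements_signature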

    def _get_postprocess_fn(self, num_anchors, num_classes):
        # There is no TF equivalent for TFLite's custom post-processing op.
        # So we add an 'empty' composite function here, that is legalized to the
        # custom op with MLIR.
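        # The decorator attaches the signature so that MLIR can legalize this
        # function to the custom op.
        @tf.function(experimental_implements=self.postprocess_implements_signature())
        # pylint: disable=g-unused-argument,unused-argument
        def dummy_post_processing(box_encodings, class_predictions, anchors):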
            boxes = tf.constant(0.0, dtype=tf.float32, name='boxes')
            scores = tf.constant(0.0, dtype=tf.float32, name='scores')
            classes = tf.constant(0.0, dtype=tf.float32, name='classes')
            num_detections = tf.constant(0.0, dtype=tf.float32, name='num_detections')
            return boxes, classes, scores, num_detections

        return dummy_post_processing

    @tf.function
    def inference_fn(self, image):
        """Encapsulates SSD inference for TFLite conversion.

        NOTE: The Args & Returns sections below indicate the TFLite model signature,
        and not what the TF graph does (since the latter does not include the custom
        NMS op used by TFLite)

        Args:
          image: a float32 tensor of shape [1, height, width, 3] containing the
            input image.

        Returns:
          num_detections: a float32 scalar denoting number of total detections.
          classes: a float32 tensor denoting class ID for each detection.
          scores: a float32 tensor denoting score for each detection.
          boxes: a float32 tensor denoting coordinates of each detected box.
        """
        predicted_tensors = self._model(image)
        class_predictions = tf.sigmoid(predicted_tensors[..., 4:])
        class_predictions = tf.identity(class_predictions, name='class_predictions')

        box_encodings = tf.identity(predicted_tensors[..., :4], name='box_encodings')

        anchors = tf.identity(self._Anchors, name='anchors')

        # tf.function seems to reverse order of inputs, so reverse them here.
        # Use self._Anchors rather than the global detection_module instance.
        postprocess_fn = self._get_postprocess_fn(self._Anchors.shape[0], self._num_classes)
        return postprocess_fn(box_encodings, class_predictions, anchors)[::-1]

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Export TFLite")
    parser.add_argument("--path", type=str, default="logs/MobileNetV3_PFH_SSD_320_240")
    parser.add_argument("--gpus", type=str, default="2")
    args = parser.parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus

    with open(os.path.join(args.path, "config.json"), "r") as config_file:
        config = json.load(config_file)

    model = ModelBuilder(config = config)

    latest = tf.train.latest_checkpoint(os.path.join(args.path, "weights"))

    detection_module = SSDModule(config, model)
    concrete_function = detection_module.inference_fn.get_concrete_function(tf.TensorSpec(shape=detection_module.input_shape(), dtype=tf.float32, name='image_tensor'))
    tf.saved_model.save(detection_module, "logs/tflite_Test", signatures=concrete_function)

    converter = tf.lite.TFLiteConverter.from_saved_model("logs/tflite_Test", signature_keys=['serving_default'])
    converter.optimizations = [None] 
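    converter.target_spec.supported_ops = [
        # The entries of this list were lost from the original post;
        # TFLITE_BUILTINS (the converter default) is an assumption.
        tf.lite.OpsSet.TFLITE_BUILTINS,
    ]

    tflite_model = converter.convert()
    with tf.io.gfile.GFile(os.path.join(args.path, "detect.tflite"), 'wb') as f:
        f.write(tflite_model)
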
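
The way postprocess_implements_signature encodes the use_regular_nms flag can be sketched without TensorFlow. The helper below is hypothetical (it is not part of the posted script or of any library), and the attribute values are made up apart from max_detections=4, the default used above. It only illustrates how typed attributes are formatted into the signature string, with booleans written as lowercase true/false.

```python
def build_implements_signature(name, attrs):
    """Formats a signature string from typed attributes.

    attrs: list of (key, type_tag, value) tuples, where type_tag is
    "i" (integer), "f" (float) or "b" (boolean).
    """
    parts = ['name: "%s"' % name]
    for key, type_tag, value in attrs:
        if type_tag == "i":
            rendered = "%d" % value
        elif type_tag == "f":
            rendered = "%f" % value
        elif type_tag == "b":
            # Booleans are written in lowercase: true / false.
            rendered = str(bool(value)).lower()
        else:
            raise ValueError("unsupported type tag: %r" % type_tag)
        parts.append('attr { key: "%s" value { %s: %s } }' % (key, type_tag, rendered))
    return " ".join(parts)


# Illustrative values only; the thresholds of the real model are not known here.
signature = build_implements_signature(
    "TFLite_Detection_PostProcess",
    [
        ("max_detections", "i", 4),
        ("use_regular_nms", "b", True),
        ("nms_iou_threshold", "f", 0.5),
    ],
)
```

Printing signature for the use_regular_nms=True and use_regular_nms=False cases is a quick way to confirm that the two exported models differ only in that one attribute.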
gaikwadrahul8 commented 8 months ago

Hi, @Jove125

I apologize for the delayed response. Could you please export the TensorFlow Lite model with both use_regular_nms=True and use_regular_nms=False and attach the two files as a zip archive, together with the code snippet in which you call model.predict() with the @tensorflow/tfjs-tflite package and the complete steps to replicate the behavior on my end?

Thank you for your cooperation and patience.

Jove125 commented 8 months ago

Hi, @gaikwadrahul8

Attached are the two .tflite files and the simplest script for checking these models. I noticed that although predict freezes with the nms_true.tflite model, the CPU remains under load.
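
When cross-checking a model outside the browser, a hang can be told apart from an error by running the check in a child process with a timeout. The sketch below is an assumption-laden illustration using only the Python standard library: run_with_timeout is a hypothetical helper, and the code strings are stand-ins for a real inference script.

```python
import subprocess
import sys


def run_with_timeout(code, timeout_s):
    """Runs Python source in a child process and reports how it ended.

    Returns "finished" on exit code 0, "error" on a non-zero exit code,
    and "timeout" if the child is still running after timeout_s seconds.
    """
    try:
        result = subprocess.run(
            [sys.executable, "-c", code],
            capture_output=True,
            text=True,
            timeout=timeout_s,
        )
    except subprocess.TimeoutExpired:
        # subprocess.run kills the child before raising TimeoutExpired.
        return "timeout"
    return "finished" if result.returncode == 0 else "error"


# Stand-ins for a real inference script: one returns, one never does in time.
quick_status = run_with_timeout("print('ok')", 30)
hung_status = run_with_timeout("import time; time.sleep(60)", 1)
```

A "timeout" result together with sustained CPU usage points to a busy loop rather than a blocked wait.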


gaikwadrahul8 commented 8 months ago

Hi, @Jove125

Thank you for providing the TensorFlow Lite models with use_regular_nms=True and use_regular_nms=False. I tried them on my end and I observe the same behavior with the use_regular_nms=True model, so we'll have to dig further into this issue and will update you soon.

Here is the output log for reference with the TensorFlow Lite model with use_regular_nms=False:


Here is the output log for reference with the TensorFlow Lite model with use_regular_nms=True:


Thank you for bringing this issue to our attention. I really appreciate your valuable efforts and time. Thank you for your cooperation.