tensorflow / tflite-support

TFLite Support is a toolkit that helps users to develop ML and deploy TFLite models onto mobile / IoT devices.
Apache License 2.0

Model conversion with ai-edge-torch + Metadata presents long inference time #982

Open RubensZimbres opened 3 months ago

RubensZimbres commented 3 months ago

I developed a customized ResNet in PyTorch, converted it to TFLite with ai-edge-torch, added metadata, and put the TFLite model in a Google Cloud bucket for inference in the MediaPipe Image Classification example. Explanation and code here.

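# Imports assumed from the ai-edge-torch quantization examples; the original post omitted them.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import ai_edge_torch
from ai_edge_torch.quantize.pt2e_quantizer import get_symmetric_quantization_config, PT2EQuantizer
from ai_edge_torch.quantize.quant_config import QuantConfig
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

resnet50 = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1).eval()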

class PermuteInput(nn.Module):
    def __init__(self):
        super(PermuteInput, self).__init__()

    def forward(self, x):
        # Permute from (batch, height, width, channels) to (batch, channels, height, width)
        return x.permute(0, 3, 1, 2)

class HandleOutput(nn.Module):
    def __init__(self):
        super(HandleOutput, self).__init__()

    def forward(self, x):
        return F.normalize(x) 

# Wrap the model so it accepts NHWC input and L2-normalizes its output.
# A Sequential container places the permute before ResNet-50 and the normalization after it.
resnet50_with_reshape = nn.Sequential(
    PermuteInput(),
    resnet50,
    HandleOutput()
)

edge_model = resnet50_with_reshape.eval()
sample_input = (torch.rand((1, 224, 224, 3), dtype=torch.float32),)

edge_model = ai_edge_torch.convert(edge_model.eval(), sample_input)

pt2e_quantizer = PT2EQuantizer().set_global(
    get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True))

pt2e_torch_model = capture_pre_autograd_graph(resnet50_with_reshape.eval(), sample_input)
pt2e_torch_model = prepare_pt2e(pt2e_torch_model, pt2e_quantizer)

# Run the prepared model with sample input data to ensure that internal observers are populated with correct values
pt2e_torch_model(*sample_input)

# Convert the prepared model to a quantized model
pt2e_torch_model = convert_pt2e(pt2e_torch_model, fold_quantize=False)

# Convert to an ai_edge_torch model
pt2e_drq_model = ai_edge_torch.convert(pt2e_torch_model, sample_input, quant_config=QuantConfig(pt2e_quantizer=pt2e_quantizer))

However, with the default (supported) model, inference takes milliseconds. With the customized TFLite model, inference takes about 2 seconds.
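To put numbers on the gap between the two models, a small timing harness can help. The following is only a sketch: `time_inference` and `summarize` are hypothetical helper names, and the stand-in workload must be replaced with a function that runs one inference on the model under test. Warm-up calls are discarded because the first calls often include one-time initialization.

```python
import statistics
import time
from typing import Callable, Dict, List


def time_inference(run_once: Callable[[], object],
                   warmup: int = 3,
                   runs: int = 20) -> List[float]:
    """Return per-call latencies in milliseconds, excluding warm-up calls."""
    for _ in range(warmup):
        run_once()  # discarded: may include one-time setup cost
    latencies_ms = []
    for _ in range(runs):
        start = time.perf_counter()
        run_once()
        latencies_ms.append((time.perf_counter() - start) * 1000.0)
    return latencies_ms


def summarize(latencies_ms: List[float]) -> Dict[str, float]:
    """Summarize latencies with median, min, and max (all in ms)."""
    return {
        "median_ms": statistics.median(latencies_ms),
        "min_ms": min(latencies_ms),
        "max_ms": max(latencies_ms),
    }


if __name__ == "__main__":
    # Stand-in workload (assumption): swap in a callable that runs
    # a single inference on the model being measured.
    print(summarize(time_inference(lambda: sum(range(10_000)))))
```

Running the same harness against both the supported model and the customized one would show whether the 2 seconds is a per-inference cost or a one-time startup cost.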

The customized model was developed and converted with TensorFlow 2.17.0. However, the tflite library only works with TensorFlow 2.13.0. I am using the latest version of tflite-support in an Anaconda environment on Ubuntu 22.04.
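Because the problem involves several package versions at once, it can help to record exactly what each environment has installed. This is a minimal sketch using only the standard library; the helper name `installed_versions` is hypothetical and the package list is an assumption based on the libraries mentioned in this issue.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Package names assumed from the libraries discussed in this issue.
    wanted = ["tensorflow", "tflite-support", "ai-edge-torch", "torch"]
    for name, version in installed_versions(wanted).items():
        print(f"{name}: {version or 'not installed'}")
```

Running it once in the conversion environment and once in the metadata environment gives a side-by-side record of the mismatch.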

When I add metadata, the only TensorFlow version that runs is 2.13.0. This gap in TensorFlow versions seems to be causing a problem: after the message "Created TensorFlow Lite XNNPACK delegate for CPU.", the model takes a long time and inference does not run in real time in the browser, unlike with the supported TFLite model at:

https://storage.googleapis.com/mediapipe-models/image_classifier/efficientnet_lite0/float32/1/efficientnet_lite0.tflite

With my model in the web page, "Graph successfully started running", but inference takes too long.

With tflite-support 0.4.4 I get this error:

ImportError: generic_type: cannot initialize type "StatusCode": an object with that name is already defined

It only works with tflite-support==0.1.0a1, but then inference takes too long and I get this message in the debug console:

W0724 16:08:28.765000 1880752 inference_feedback_manager.cc:121] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.

I noticed that I successfully create a TFLite model with a serving_default signature. However, when I add metadata, this signature disappears. Here's the code I'm using to add metadata:

from tflite_support import flatbuffers
from tflite_support import metadata as _metadata
from tflite_support import metadata_schema_py_generated as _metadata_fb
import os
import tensorflow as tf

model_file="efficientnet_quantized.tflite"
os.chdir("/folder")

# Creates the metadata for an image classifier.

# Creates model info.
model_meta = _metadata_fb.ModelMetadataT()
model_meta.name = "Resnet image classifier"
model_meta.description = ("Identify the most prominent object in the "
                          "image from a set of 1,000 categories such as "
                          "trees, animals, food, vehicles, person etc.")
model_meta.version = "v1"
model_meta.author = "Rubens Zimbres"
model_meta.license = ("Apache License. Version 2.0 "
                      "http://www.apache.org/licenses/LICENSE-2.0.")

# Creates input info.
input_meta = _metadata_fb.TensorMetadataT()

input_meta.name = "Image"
input_meta.description = (
    "Input image to be classified. The expected image is {0} x {1}, with "
    "three channels (red, green, and blue) per pixel. Each value in the "
    "tensor is a single byte between 0 and 255.".format(224, 224))
input_meta.content = _metadata_fb.ContentT()
input_meta.content.contentProperties = _metadata_fb.ImagePropertiesT()
input_meta.content.contentProperties.colorSpace = (
    _metadata_fb.ColorSpaceType.RGB)
input_meta.content.contentPropertiesType = (
    _metadata_fb.ContentProperties.ImageProperties)
input_normalization = _metadata_fb.ProcessUnitT()
input_normalization.optionsType = (
    _metadata_fb.ProcessUnitOptions.NormalizationOptions)
input_normalization.options = _metadata_fb.NormalizationOptionsT()
input_normalization.options.mean = [127.5]
input_normalization.options.std = [127.5]
input_meta.processUnits = [input_normalization]
input_stats = _metadata_fb.StatsT()
input_stats.max = [1]
input_stats.min = [0]
# StatsT only defines max and min in the metadata schema; the width, height,
# and num_classes attributes set here originally were not written to the file.
input_meta.stats = input_stats

# Creates output info.
output_meta = _metadata_fb.TensorMetadataT()
output_meta.name = "probability"
output_meta.description = "Probabilities of the 1000 labels respectively."
output_meta.content = _metadata_fb.ContentT()
output_meta.content.contentProperties = _metadata_fb.FeaturePropertiesT()
output_meta.content.contentPropertiesType = (
    _metadata_fb.ContentProperties.FeatureProperties)
output_stats = _metadata_fb.StatsT()
output_stats.max = [1.0]
output_stats.min = [0.0]
output_meta.stats = output_stats
label_file = _metadata_fb.AssociatedFileT()
label_file.name = os.path.basename("resnet_labels2.txt")
label_file.description = "Labels for objects that the model can recognize."
label_file.type = _metadata_fb.AssociatedFileType.TENSOR_AXIS_LABELS
output_meta.associatedFiles = [label_file]

# Creates subgraph info.
subgraph = _metadata_fb.SubGraphMetadataT()
subgraph.inputTensorMetadata = [input_meta]
subgraph.outputTensorMetadata = [output_meta]
model_meta.subgraphMetadata = [subgraph]

b = flatbuffers.Builder(0)
b.Finish(
    model_meta.Pack(b),
    _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
metadata_buf = b.Output()

populator = _metadata.MetadataPopulator.with_model_file(model_file)
populator.load_metadata_buffer(metadata_buf)
populator.load_associated_files(["resnet_labels2.txt"])
populator.populate()

displayer = _metadata.MetadataDisplayer.with_model_file('/folder/efficientnet_quantized.tflite')
export_json_file = "/folder/metadata.json"
json_file = displayer.get_metadata_json()
# Optional: write out the metadata as a json file
with open(export_json_file, "w") as f:
  f.write(json_file)
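As a check on the normalization options in the metadata above: with mean 127.5 and std 127.5, a consumer that applies `(value - mean) / std` maps byte values in [0, 255] to [-1, 1]. The sketch below only works through that arithmetic; the function name `normalize` is hypothetical, and whether this range matches what the converted ResNet expects is not established in this issue.

```python
def normalize(pixel: float, mean: float = 127.5, std: float = 127.5) -> float:
    """Apply the (value - mean) / std normalization declared in the metadata."""
    return (pixel - mean) / std


# Darkest, middle, and brightest byte values.
print(normalize(0.0), normalize(127.5), normalize(255.0))  # -1.0 0.0 1.0
```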

Any ideas on how to solve this issue?

UPDATE

I found the problem: the TFLite quantization generates an additional signature, and the web page debugger then shows:

Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.

This makes XNNPACK take longer to load and delays inference.
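One way to confirm how many signatures a given .tflite file carries is to ask the TFLite interpreter directly. This is a sketch under assumptions: it relies on `tf.lite.Interpreter.get_signature_list()`, the helper names `list_signatures` and `signature_report` are hypothetical, and the file name in the usage comment is the one from the metadata script above.

```python
from typing import Dict


def list_signatures(model_path: str) -> Dict[str, dict]:
    """Return the signature map of a .tflite file (requires TensorFlow)."""
    # Imported lazily so signature_report can be used without TensorFlow.
    import tensorflow as tf

    interpreter = tf.lite.Interpreter(model_path=model_path)
    return interpreter.get_signature_list()


def signature_report(signatures: Dict[str, dict]) -> str:
    """Describe whether a model has zero, one, or several signatures."""
    names = sorted(signatures)
    if not names:
        return "no signatures"
    if len(names) == 1:
        return f"single signature: {names[0]}"
    return f"{len(names)} signatures: {', '.join(names)}"


# Example usage (path is an assumption taken from the script above):
# print(signature_report(list_signatures("efficientnet_quantized.tflite")))
```

Running the report on the float model, the quantized model, and the model after metadata is added would show at which step the signature count changes.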