tensorflow / tflite-support

TFLite Support is a toolkit that helps users develop ML and deploy TFLite models onto mobile / IoT devices.
Apache License 2.0

Inference time for tflite quantized model is high #980

Open RubensZimbres opened 2 months ago

RubensZimbres commented 2 months ago

I followed a MediaPipe tutorial, and their model, https://storage.googleapis.com/mediapipe-models/image_classifier/efficientnet_lite0/float32/1/efficientnet_lite0.tflite, has an inference time of milliseconds.

I used ai-edge-torch to convert a PyTorch EfficientNet to TFLite, but the inference time is 2-3 seconds.

Here's my code:

import torch
import torch.nn as nn
import torchvision
import ai_edge_torch

# Load a pretrained EfficientNet-B3 in eval mode (weights must be passed as a keyword argument).
efficientnet = torchvision.models.efficientnet_b3(
    weights=torchvision.models.EfficientNet_B3_Weights.IMAGENET1K_V1
).eval()

class PermuteInput(nn.Module):
    def __init__(self):
        super(PermuteInput, self).__init__()

    def forward(self, x):
        # Permute from (batch, height, width, channels) to (batch, channels, height, width)
        return x.permute(0, 3, 1, 2)

import torch.nn.functional as F

class PermuteOutput(nn.Module):
    def __init__(self):
        super(PermuteOutput, self).__init__()

    def forward(self, x):
        # L2-normalize the output along dim=1 (no permutation happens here).
        return F.normalize(x)

efficientnet_with_reshape = nn.Sequential(
    PermuteInput(),
    efficientnet,
    PermuteOutput()
)

edge_model = efficientnet_with_reshape.eval()

sample_input = (torch.rand((1, 224, 224, 3), dtype=torch.float32),)

edge_model = ai_edge_torch.convert(edge_model.eval(), sample_input)

edge_model.export("/home/user/efficientnet.tflite")
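
For reference, a minimal sketch of how the exported model's latency can be checked with the TFLite Python interpreter (one warm-up run, then a single timed invoke; the model path matches the export above):

import time
import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="/home/user/efficientnet.tflite")
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Random NHWC image, matching the sample_input shape used for conversion.
dummy = np.random.rand(1, 224, 224, 3).astype(np.float32)

interpreter.set_tensor(input_details[0]["index"], dummy)
interpreter.invoke()  # warm-up run

interpreter.set_tensor(input_details[0]["index"], dummy)
start = time.perf_counter()
interpreter.invoke()
print(f"invoke() took {(time.perf_counter() - start) * 1000:.1f} ms")
_ = interpreter.get_tensor(output_details[0]["index"])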

# QUANTIZE TFLITE MODEL

# Imports for the PT2E quantization path (module paths follow the
# ai-edge-torch quantization example).
from ai_edge_torch.quantize.pt2e_quantizer import PT2EQuantizer
from ai_edge_torch.quantize.pt2e_quantizer import get_symmetric_quantization_config
from ai_edge_torch.quantize.quant_config import QuantConfig
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

pt2e_quantizer = PT2EQuantizer().set_global(
    get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True)
)

pt2e_torch_model = capture_pre_autograd_graph(efficientnet_with_reshape.eval(), sample_input)
pt2e_torch_model = prepare_pt2e(pt2e_torch_model, pt2e_quantizer)

# Run the prepared model with sample input data to ensure that internal observers are populated with correct values
pt2e_torch_model(*sample_input)

# Convert the prepared model to a quantized model
pt2e_torch_model = convert_pt2e(pt2e_torch_model, fold_quantize=False)

# Convert to an ai_edge_torch model
pt2e_drq_model = ai_edge_torch.convert(
    pt2e_torch_model, sample_input, quant_config=QuantConfig(pt2e_quantizer=pt2e_quantizer)
)

pt2e_drq_model.export("/home/user/efficientnet_quantized.tflite")
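
To see whether the converter actually emitted quantized ops, the exported file can be inspected; a sketch, assuming a recent TensorFlow build that ships tf.lite.experimental.Analyzer:

import tensorflow as tf

# Prints the ops and tensors inside the quantized model, to confirm that
# quantized kernels are actually present in the exported file.
tf.lite.experimental.Analyzer.analyze(model_path="/home/user/efficientnet_quantized.tflite")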

I properly added metadata and labels to the TFLite model, and I also added a CORS policy to the bucket.
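
The metadata and labels were added roughly along these lines, as a sketch based on the tflite-support ImageClassifier metadata writer; the label file path and normalization values here are placeholders:

from tflite_support.metadata_writers import image_classifier
from tflite_support.metadata_writers import writer_utils

ImageClassifierWriter = image_classifier.MetadataWriter

# Placeholder label file and normalization stats -- adjust to the real model.
writer = ImageClassifierWriter.create_for_inference(
    writer_utils.load_file("/home/user/efficientnet.tflite"),
    [127.5],  # input_norm_mean
    [127.5],  # input_norm_std
    ["/home/user/labels.txt"],
)
writer_utils.save_file(writer.populate(), "/home/user/efficientnet_with_metadata.tflite")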

Is this a quantization issue or a bucket-bandwidth issue? With the supported model, inference is really fast.