I followed a MediaPipe tutorial, and their model, https://storage.googleapis.com/mediapipe-models/image_classifier/efficientnet_lite0/float32/1/efficientnet_lite0.tflite, runs inference in milliseconds.
I used ai-edge-torch to convert a PyTorch EfficientNet to TFLite, but the converted model's inference time is 2-3 seconds.
Here's my code:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

import ai_edge_torch
from ai_edge_torch.quantize.pt2e_quantizer import PT2EQuantizer
from ai_edge_torch.quantize.pt2e_quantizer import get_symmetric_quantization_config
from ai_edge_torch.quantize.quant_config import QuantConfig
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

# `weights` is keyword-only in recent torchvision releases
efficientnet = torchvision.models.efficientnet_b3(
    weights=torchvision.models.EfficientNet_B3_Weights.IMAGENET1K_V1
).eval()


class PermuteInput(nn.Module):
    """Permute from (batch, height, width, channels) to (batch, channels, height, width)."""

    def forward(self, x):
        return x.permute(0, 3, 1, 2)


class PermuteOutput(nn.Module):
    """L2-normalize the output along the class dimension."""

    def forward(self, x):
        return F.normalize(x)


efficientnet_with_reshape = nn.Sequential(
    PermuteInput(),
    efficientnet,
    PermuteOutput(),
)

edge_model = efficientnet_with_reshape.eval()
sample_input = (torch.rand((1, 224, 224, 3), dtype=torch.float32),)
edge_model = ai_edge_torch.convert(edge_model, sample_input)
edge_model.export("/home/user/efficientnet.tflite")

# Quantize the TFLite model (dynamic-range, per-channel symmetric)
pt2e_quantizer = PT2EQuantizer().set_global(
    get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True)
)
pt2e_torch_model = capture_pre_autograd_graph(efficientnet_with_reshape.eval(), sample_input)
pt2e_torch_model = prepare_pt2e(pt2e_torch_model, pt2e_quantizer)

# Run the prepared model with sample input data to ensure that internal
# observers are populated with correct values
pt2e_torch_model(*sample_input)

# Convert the prepared model to a quantized model
pt2e_torch_model = convert_pt2e(pt2e_torch_model, fold_quantize=False)

# Convert to an ai_edge_torch model
pt2e_drq_model = ai_edge_torch.convert(
    pt2e_torch_model, sample_input, quant_config=QuantConfig(pt2e_quantizer=pt2e_quantizer)
)
pt2e_drq_model.export("/home/user/efficientnet_quantized.tflite")
```
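To check whether the slowdown comes from the model itself rather than from the bucket, I can time the exported model locally. A minimal benchmark sketch (assuming TensorFlow is installed; the model path matches the export above):

```python
import time

import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="/home/user/efficientnet.tflite")
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()

dummy = np.random.rand(1, 224, 224, 3).astype(np.float32)
interpreter.set_tensor(input_details[0]["index"], dummy)
interpreter.invoke()  # warm-up run

start = time.perf_counter()
for _ in range(50):
    interpreter.set_tensor(input_details[0]["index"], dummy)
    interpreter.invoke()
elapsed = (time.perf_counter() - start) / 50
print(f"mean inference time: {elapsed * 1000:.1f} ms")
```

If this local number is already in the seconds range, the bucket is not the bottleneck.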
I added metadata and labels to the TFLite model and set a CORS policy on the bucket.
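For completeness, the metadata and labels were attached roughly like this (a sketch using the TFLite Support metadata writers; the label file path and the normalization mean/std values here are placeholders, not the exact values I used):

```python
from tflite_support.metadata_writers import image_classifier
from tflite_support.metadata_writers import writer_utils

MODEL_PATH = "/home/user/efficientnet.tflite"
LABELS_PATH = "/home/user/labels.txt"  # placeholder label file
EXPORT_PATH = "/home/user/efficientnet_meta.tflite"

# Normalization parameters are placeholders for a [0, 255] -> [-1, 1]
# mapping; they should match the model's real preprocessing.
writer = image_classifier.MetadataWriter.create_for_inference(
    writer_utils.load_file(MODEL_PATH),
    input_norm_mean=[127.5],
    input_norm_std=[127.5],
    label_file_paths=[LABELS_PATH],
)
writer_utils.save_file(writer.populate(), EXPORT_PATH)
```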
Is this a quantization issue or a bucket bandwidth issue? With the supported model, inference is really fast.