google-ai-edge / ai-edge-torch

Supporting PyTorch models with the Google AI Edge TFLite runtime.
Apache License 2.0

Support for Torchvision Detection Models #182

Closed chadrockey closed 2 months ago

chadrockey commented 2 months ago

Description of the bug:

Possibly related to https://github.com/google-ai-edge/ai-edge-torch/issues/103

Is there a chance that the torchvision detection models will be supported? Most of them appear to contain unsupported ops, and they take a list of tensors as input.

Actual vs expected behavior:

import torch
import torchvision
import ai_edge_torch

# Options are:
#   fasterrcnn_mobilenet_v3_large_320_fpn
#   fasterrcnn_mobilenet_v3_large_fpn
#   fasterrcnn_resnet50_fpn
#   fasterrcnn_resnet50_fpn_v2
#   fcos_resnet50_fpn
#   keypointrcnn_resnet50_fpn
#   maskrcnn_resnet50_fpn
#   maskrcnn_resnet50_fpn_v2
#   retinanet_resnet50_fpn
#   retinanet_resnet50_fpn_v2
#   ssd300_vgg16
#   ssdlite320_mobilenet_v3_large

detection_model = torchvision.models.get_model("ssd300_vgg16", weights=None, num_classes=2)
sample_inputs = (torch.rand(1, 3, 320, 320),)

# Convert and serialize PyTorch model to a tflite flatbuffer. Note that we
# are setting the model to evaluation mode prior to conversion.
edge_model = ai_edge_torch.convert(detection_model.eval(), sample_inputs)
edge_model.export("detection.tflite")

Any other information you'd like to share?

No response

pkgoogle commented 2 months ago

Hi @chadrockey, the root cause of most of these failures is that the models cannot be exported with torch.export. You can verify this yourself by running the following code:

import torch
import torchvision
from ai_edge_torch.debug import find_culprits

# Options are:
#   fasterrcnn_mobilenet_v3_large_320_fpn
#   fasterrcnn_mobilenet_v3_large_fpn
#   fasterrcnn_resnet50_fpn
#   fasterrcnn_resnet50_fpn_v2
#   fcos_resnet50_fpn
#   keypointrcnn_resnet50_fpn
#   maskrcnn_resnet50_fpn
#   maskrcnn_resnet50_fpn_v2
#   retinanet_resnet50_fpn
#   retinanet_resnet50_fpn_v2
#   ssd300_vgg16
#   ssdlite320_mobilenet_v3_large

detection_model = torchvision.models.get_model("ssd300_vgg16", weights=None, num_classes=2)
sample_inputs = (torch.rand(1, 3, 320, 320),)

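# Put the model in eval mode first: in training mode, torchvision detection
# models assert that targets are passed to forward().
culprits = find_culprits(detection_model.eval(), sample_inputs)

# Print the first culprit found as standalone code
# (next() raises StopIteration if no culprits are found).
culprit = next(culprits)
culprit.print_code()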

As such, you may need to raise an issue with PyTorch, or manually modify the torchvision model so that it is compatible with torch.export. You can find more information here: Debugging & Reporting Errors

github-actions[bot] commented 2 months ago

Marking this issue as stale since it has been open for 7 days with no activity. This issue will be closed if no further activity occurs.

github-actions[bot] commented 2 months ago

This issue was closed because it has been inactive for 14 days. Please post a new issue if you need further assistance. Thanks!