sony / model_optimization

Model Compression Toolkit (MCT) is an open source project for neural network model optimization under efficient, constrained hardware. This project provides researchers, developers, and engineers advanced quantization and compression tools for deploying state-of-the-art neural networks.
https://sony.github.io/model_optimization/
Apache License 2.0

quantize object detection model from PyTorch #1261

Closed Saalsifis closed 6 days ago

Saalsifis commented 2 weeks ago

Issue Type

Bug

Source

source

MCT Version

2.2.1

OS Platform and Distribution

WSL Ubuntu 24.04.1 LTS

Python version

3.12.3

Describe the issue

The mct.ptq.pytorch_post_training_quantization function always crashes with the error "fx error: Proxy object cannot be iterated". I am trying to quantize an object detection model trained with PyTorch for a camera with an IMX500.

It seems torch.fx cannot trace a model whose forward pass iterates over a Proxy. I tried several model types (yolov8s, ssdlite320_mobilenet_v3_large, and FasterRCNN_ResNet50_FPN_V2), but the result is always the same.
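The failure can be reproduced outside MCT. The following is a minimal sketch, not code from MCT or torchvision: `LoopsOverInput`, `Vectorized`, and `try_trace` are hypothetical names, and the toy module only imitates the `for img in images` loop that appears in the traceback below.

```python
import torch
from torch import fx, nn
from torch.fx.proxy import TraceError


class LoopsOverInput(nn.Module):
    """Hypothetical toy module that iterates over its input tensor."""

    def forward(self, images):
        outputs = []
        # During symbolic tracing, `images` is a Proxy, and a Proxy cannot be iterated.
        for img in images:
            outputs.append(img * 2)
        return torch.stack(outputs)


class Vectorized(nn.Module):
    """Same computation expressed without a Python loop over the input."""

    def forward(self, images):
        return images * 2


def try_trace(module):
    """Return None if symbolic tracing succeeds, else the TraceError message."""
    try:
        fx.symbolic_trace(module)
        return None
    except TraceError as err:
        return str(err)


print(try_trace(Vectorized()))
print(try_trace(LoopsOverInput()))
```

The vectorized module traces cleanly, while the looping one raises the same TraceError as the detection model, which suggests the problem lies in the model's forward pass rather than in the quantization settings.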

Expected behaviour

Quantize the model and export it as an .onnx file.

Code to reproduce the issue

import torch
import torchvision
from torchvision.models.detection import ssdlite320_mobilenet_v3_large
from torchvision.models.detection import SSDLite320_MobileNet_V3_Large_Weights
from model_compression_toolkit.core import QuantizationErrorMethod
from torchvision.transforms import functional as F
from torchvision.datasets import ImageNet
from torch.utils.data import DataLoader
from torchvision import transforms
import numpy as np
import random
import model_compression_toolkit as mct
import sys

batch_size = 16
n_iter = 10

def get_transform(train):
    transforms = []
    transforms.append(torchvision.transforms.Resize((320, 320)))
    transforms.append(F.to_tensor)  # Convert the image to a Tensor here
    if train:
        transforms.append(torchvision.transforms.RandomHorizontalFlip(0.5))
    return torchvision.transforms.Compose(transforms)

dataset = ImageNet(root='./imagenet', split='val', transform=get_transform(False))

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

def representative_dataset_gen():
    dataloader_iter = iter(dataloader)
    for _ in range(n_iter):
        yield [next(dataloader_iter)[0]]

tcp = mct.get_target_platform_capabilities("pytorch", 'imx500', target_platform_version='v1')

model = ssdlite320_mobilenet_v3_large(num_classes = 10)
model.load_state_dict(torch.load('model.pth', weights_only=True),strict=False)
model.eval()

q_config = mct.core.QuantizationConfig(activation_error_method=QuantizationErrorMethod.MSE,
                                       weights_error_method=QuantizationErrorMethod.MSE,
                                       weights_bias_correction=True,
                                       shift_negative_activation_correction=True,
                                       z_threshold=16)

ptq_config = mct.core.CoreConfig(quantization_config=q_config)

quantized_model, quantization_info = mct.ptq.pytorch_post_training_quantization(
        in_module=model,
        representative_data_gen=representative_dataset_gen,
        core_config=ptq_config,
        target_platform_capabilities=tcp
)

mct.exporter.pytorch_export_model(quantized_model, save_model_path='qmodel.onnx', repr_dataset=representative_dataset_gen)

Log output

CRITICAL:Model Compression Toolkit:Error parsing model with torch.fx
fx error: Proxy object cannot be iterated. This can be attempted when the Proxy is used in a loop or as a *args or **kwargs function argument. See the torch.fx docs on pytorch.org for a more detailed explanation of what types of control flow can be traced, and check out the Proxy docstring for help troubleshooting Proxy iteration errors
Traceback (most recent call last):
  File "/home/oui/trainer/lib/python3.12/site-packages/model_compression_toolkit/core/pytorch/reader/reader.py", line 90, in fx_graph_module_generation
    symbolic_traced = symbolic_trace(pytorch_model)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/oui/trainer/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 1281, in symbolic_trace
    graph = tracer.trace(root, concrete_args)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/oui/trainer/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 823, in trace
    (self.create_arg(fn(*args)),),
                     ^^^^^^^^^
  File "/home/oui/trainer/lib/python3.12/site-packages/torchvision/models/detection/ssd.py", line 345, in forward
    for img in images:
  File "/home/oui/trainer/lib/python3.12/site-packages/torch/fx/proxy.py", line 456, in __iter__
    return self.tracer.iter(self)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/oui/trainer/lib/python3.12/site-packages/torch/fx/proxy.py", line 327, in iter
    raise TraceError('Proxy object cannot be iterated. This can be '
torch.fx.proxy.TraceError: Proxy object cannot be iterated. This can be attempted when the Proxy is used in a loop or as a *args or **kwargs function argument. See the torch.fx docs on pytorch.org for a more detailed explanation of what types of control flow can be traced, and check out the Proxy docstring for help troubleshooting Proxy iteration errors

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/oui/quantizeMTC.py", line 50, in <module>
    quantized_model, quantization_info = mct.ptq.pytorch_post_training_quantization(
                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/oui/trainer/lib/python3.12/site-packages/model_compression_toolkit/ptq/pytorch/quantization_facade.py", line 111, in pytorch_post_training_quantization
    tg, bit_widths_config, _, scheduling_info = core_runner(in_model=in_module,
                                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/oui/trainer/lib/python3.12/site-packages/model_compression_toolkit/core/runner.py", line 114, in core_runner
    graph = graph_preparation_runner(in_model,
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/oui/trainer/lib/python3.12/site-packages/model_compression_toolkit/core/graph_prep_runner.py", line 72, in graph_preparation_runner
    graph = read_model_to_graph(in_model,
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/oui/trainer/lib/python3.12/site-packages/model_compression_toolkit/core/graph_prep_runner.py", line 207, in read_model_to_graph
    graph = fw_impl.model_reader(in_model,
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/oui/trainer/lib/python3.12/site-packages/model_compression_toolkit/core/pytorch/pytorch_implementation.py", line 149, in model_reader
    return model_reader(_module, representative_data_gen, self.to_numpy, self.to_tensor)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/oui/trainer/lib/python3.12/site-packages/model_compression_toolkit/core/pytorch/reader/reader.py", line 153, in model_reader
    fx_model = fx_graph_module_generation(model, representative_data_gen, to_tensor)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/oui/trainer/lib/python3.12/site-packages/model_compression_toolkit/core/pytorch/reader/reader.py", line 92, in fx_graph_module_generation
    Logger.critical(f'Error parsing model with torch.fx\n'
  File "/home/oui/trainer/lib/python3.12/site-packages/model_compression_toolkit/logger.py", line 117, in critical
    raise Exception(msg)
Exception: Error parsing model with torch.fx
fx error: Proxy object cannot be iterated. This can be attempted when the Proxy is used in a loop or as a *args or **kwargs function argument. See the torch.fx docs on pytorch.org for a more detailed explanation of what types of control flow can be traced, and check out the Proxy docstring for help troubleshooting Proxy iteration errors

Idan-BenAmi commented 2 weeks ago

Hi @Saalsifis, thank you for your feedback.

The error you're encountering originates from torch.fx, which is the initial step MCT uses to simplify PyTorch models. See the documentation here: torch.fx. To use MCT, your model must be compatible with torch.fx.

We're currently preparing a tutorial that will help you update your model for torch.fx compatibility, addressing common issues we've encountered. We’ll notify you as soon as the tutorial is available.

In the meantime, I’d like to share that we're collaborating with Ultralytics to simplify model export, offering a much easier way to export Ultralytics models (like yolov8n) to IMX500. For more information, please visit IMX500 Export for Ultralytics YOLOv8 - Ultralytics YOLO Docs.

Thanks, Idan

Saalsifis commented 2 weeks ago

Thanks. I could not find a list of object detection models compatible with torch.fx, but I will look into the collaboration with Ultralytics.
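In the absence of such a list, one way to gauge compatibility is to call torch.fx.symbolic_trace directly on a model's submodules and see which ones trace. The sketch below is only an illustration under assumptions: `traceable_children`, `PerImageScale`, and `ToyDetector` are hypothetical names, the toy detector stands in for a real one, and a submodule that traces in isolation does not guarantee that the full model will pass MCT.

```python
from torch import fx, nn


def traceable_children(model: nn.Module) -> dict:
    """Map each direct child name to True if torch.fx can trace it on its own."""
    report = {}
    for name, child in model.named_children():
        try:
            fx.symbolic_trace(child)
            report[name] = True
        except Exception:
            # torch.fx usually raises TraceError here, but other errors are possible.
            report[name] = False
    return report


class PerImageScale(nn.Module):
    """Hypothetical post-processing step that loops over the input batch."""

    def forward(self, images):
        outputs = []
        for img in images:  # iterating over a Proxy breaks symbolic tracing
            outputs.append(img * 2)
        return outputs


class ToyDetector(nn.Module):
    """Hypothetical stand-in for a detector: traceable backbone, untraceable tail."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
        self.postprocess = PerImageScale()

    def forward(self, images):
        return self.postprocess(self.backbone(images))


print(traceable_children(ToyDetector()))
```

A report like this narrows down which part of a model needs rewriting, or which part could be kept outside the quantized graph.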