microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

MaxPool not getting quantized when preceded by Relu #9428

Open Silvan-K opened 3 years ago

Silvan-K commented 3 years ago

Describe the bug

MaxPool nodes are not getting quantized if the preceding Relu is not getting quantized.

Urgency

Development of a backend is blocked by this, so it would be great if someone could provide some insights as soon as possible.

System information

To Reproduce

Run the script below.

Expected behavior

Would expect a QuantizeLinear node before the MaxPool node in the quantized model.

import torch
import numpy as np
import onnx
import onnxruntime
from onnxruntime import quantization

IMAGE_SHAPE  = (1, 1, 16, 16)
KERNEL_SHAPE = (1, 1,  2,  2)

class ToyModel(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(out_channels = KERNEL_SHAPE[0],
                                    in_channels  = KERNEL_SHAPE[1],
                                    kernel_size  = KERNEL_SHAPE[2:],
                                    bias         = False)
        weight = torch.tensor(data  = 127*np.ones(KERNEL_SHAPE).astype("float32"),
                              dtype = torch.float32)
        self.conv.weight = torch.nn.Parameter(weight, requires_grad = False)
        self.max_pool = torch.nn.MaxPool2d(kernel_size = (2,2))

    def forward(self, input):
        return self.max_pool(self.conv(input).relu())

class ToyDataProvider(onnxruntime.quantization.CalibrationDataReader):

    def __init__(self, input_name):
        self.data = ({ input_name: prefac*np.ones(IMAGE_SHAPE).astype("float32") } for prefac in [-127,+127])

    def get_next(self):
        try: return next(self.data)
        except StopIteration: return None

def CreateToyModels(unquantized_path, quantized_path):

    # Save toy model to onnx file
    model = ToyModel()
    torch.onnx.export(model, (torch.empty(IMAGE_SHAPE, dtype=torch.float32)), unquantized_path)
    session = onnxruntime.InferenceSession(unquantized_path)
    input_name = session.get_inputs()[0].name

    # Quantize model with Relu using dummy data
    onnxruntime.quantization.quantize_static(model_input = unquantized_path,
                                             model_output = quantized_path,
                                             calibration_data_reader = ToyDataProvider(input_name),
                                             activation_type = onnxruntime.quantization.QuantType.QUInt8,
                                             weight_type = onnxruntime.quantization.QuantType.QUInt8,
                                             op_types_to_quantize = ["Conv", "MaxPool"],
                                             extra_options = { "WeightSymmetric" : False,
                                                               "ActivationSymmetric" : False })

if __name__ == "__main__":

    # Create quantized and unquantized models
    CreateToyModels(unquantized_path = "model.onnx", quantized_path = "quantized-model.onnx")
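
To check the result, the quantized graph can be inspected with the onnx package; a QuantizeLinear node should appear between the Relu and the MaxPool, but does not in the reported case. A minimal inspection sketch (not part of the original script; file path taken from it):

import onnx

# Load the quantized model produced by the script above and print the
# operator types in graph order, to see whether a QuantizeLinear node
# feeds the MaxPool node.
quantized_model = onnx.load("quantized-model.onnx")
for node in quantized_model.graph.node:
    print(node.op_type, node.name)
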
stale[bot] commented 2 years ago

This issue has been automatically marked as stale due to inactivity and will be closed in 7 days if no further activity occurs. If further support is needed, please provide an update and/or more details.

Silvan-K commented 2 years ago

@yufenglee, I was just wondering what the status of this is. Has anyone been able to reproduce the problem?

jl749 commented 1 year ago

Hello @Silvan-K,

It seems like MaxPool skips QDQ insertion when the previous activation is not quantized:

https://github.com/microsoft/onnxruntime/blob/3d7518762ace6929be98e1203174c2dbf1ac094e/onnxruntime/python/tools/quantization/operators/direct_q8.py#L72-L78

In the most recent version you can set the ForceQuantizeNoInputCheck flag to True to avoid this behaviour:

https://github.com/microsoft/onnxruntime/blob/3d7518762ace6929be98e1203174c2dbf1ac094e/onnxruntime/python/tools/quantization/quantize.py#L302-L305
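
For reference, a minimal sketch of that workaround, reusing ToyDataProvider and input_name from the repro script above and assuming an onnxruntime version that supports the flag:

from onnxruntime import quantization

# Same call as in the repro script, but with ForceQuantizeNoInputCheck set,
# which forces "direct" ops such as MaxPool to be quantized even when their
# input activation was not quantized.
quantization.quantize_static(model_input = "model.onnx",
                             model_output = "quantized-model.onnx",
                             calibration_data_reader = ToyDataProvider(input_name),
                             activation_type = quantization.QuantType.QUInt8,
                             weight_type = quantization.QuantType.QUInt8,
                             op_types_to_quantize = ["Conv", "MaxPool"],
                             extra_options = { "WeightSymmetric" : False,
                                               "ActivationSymmetric" : False,
                                               "ForceQuantizeNoInputCheck" : True })
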