alibaba / TinyNeuralNetwork

TinyNeuralNetwork is an efficient and easy-to-use deep learning model compression framework.

Mul + Add quantization #364

Open · spacycoder opened this issue 5 days ago

spacycoder commented 5 days ago

Hi, I have quantized a model to int8 and the converter produces the graph below (attached image: mul_add).

Why does it dequantize before the Mul op? Is this expected?

spacycoder commented 5 days ago

This reproduces it:

import torch.nn as nn
import torch
from tinynn.graph.quantization.quantizer import PostQuantizer
from tinynn.converter import TFLiteConverter

class Dummy(nn.Module):

    def __init__(self):
        super().__init__()
        self.max_pool = nn.MaxPool2d(kernel_size=4, stride=4)
        self.batch_norm = nn.BatchNorm2d(256)

    def forward(self, x):
        x = self.batch_norm(self.max_pool(x))
        return x

def _main():
    dummy_input0 = torch.rand(1, 256, 60, 60).float()
    model = Dummy()

    ptq_config = {
        "backend": "qnnpack",
        "per_tensor": False,
        "disable_requantization_for_cat": True
    }
    quantizer = PostQuantizer(
        model, (dummy_input0,), work_dir="mul_add_model", config=ptq_config
    )

    ptq_model = quantizer.quantize()
    # Calibration: run one forward pass so the observers collect activation statistics
    ptq_model(dummy_input0)

    with torch.no_grad():
        ptq_model.eval()
        ptq_model.cpu()

        ptq_model = quantizer.convert(ptq_model)
        torch.backends.quantized.engine = quantizer.backend
        converter = TFLiteConverter(
            ptq_model,
            (dummy_input0,),
            "mul_add_model.tflite",
            fuse_quant_dequant=True,
            quantize_target_type="int8"
        )
        converter.convert()

if __name__ == '__main__':
    _main()
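
For reference, one way to confirm which operators end up in float is to dump the op list of the generated file. The snippet below is a minimal sketch (not part of the original report) using TensorFlow's model analyzer; it assumes TensorFlow 2.7+ is installed and that mul_add_model.tflite sits in the working directory.

import tensorflow as tf

# Prints the operator list; a DEQUANTIZE -> MUL -> ADD -> QUANTIZE sequence would
# confirm that the Mul/Add pair falls back to float as described above.
tf.lite.experimental.Analyzer.analyze(model_path="mul_add_model.tflite")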
spacycoder commented 5 days ago

I get the same behavior with Mul -> Softmax (attached image: mul_softmax).

Conversion code:

import torch.nn as nn
import torch
from tinynn.graph.quantization.quantizer import PostQuantizer
from tinynn.converter import TFLiteConverter

class Dummy(nn.Module):

    def forward(self, x: torch.Tensor):
        scale_factor = 2.0
        weights = torch.softmax((x * scale_factor), dim=-1)
        return weights

def _main():
    dummy_input0 = torch.rand(1, 8, 225, 225).float()
    model = Dummy()

    ptq_config = {
        "backend": "qnnpack",
        "per_tensor": False,
        "disable_requantization_for_cat": True
    }
    quantizer = PostQuantizer(
        model, (dummy_input0,), work_dir="softmax", config=ptq_config
    )

    ptq_model = quantizer.quantize()
    # Calibration: run one forward pass so the observers collect activation statistics
    ptq_model(dummy_input0)

    with torch.no_grad():
        ptq_model.eval()
        ptq_model.cpu()

        ptq_model = quantizer.convert(ptq_model)
        torch.backends.quantized.engine = quantizer.backend
        converter = TFLiteConverter(
            ptq_model,
            (dummy_input0,),
            "softmax.tflite",
            fuse_quant_dequant=True,
            quantize_target_type="int8",
            rewrite_quantizable=True
        )
        converter.convert()

if __name__ == '__main__':
    _main()
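
A related check, again only a rough sketch and not part of the original report: the TFLite Interpreter can list each tensor's dtype and quantization parameters, which shows whether the tensors around the MUL and SOFTMAX ops are actually quantized. It assumes TensorFlow is installed and softmax.tflite is in the working directory.

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="softmax.tflite")
interpreter.allocate_tensors()

for detail in interpreter.get_tensor_details():
    # 'quantization' holds (scale, zero_point); (0.0, 0) means the tensor stays in float
    print(detail["name"], detail["dtype"], detail["quantization"])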