daquexian / onnx-simplifier

Simplify your onnx model
Apache License 2.0
3.67k stars, 377 forks

Support Conv-BN folding with QDQ nodes inserted, and BN-Conv folding #297

Closed: tp-nan closed this issue 11 months ago

tp-nan commented 1 year ago

Is it possible to support Conv-BN folding when QDQ nodes are inserted, as well as BN-Conv folding?

```python
import torch.nn as nn
from pytorch_quantization import quant_modules
from pytorch_quantization import nn as quant_nn  # only needed for the commented-out residual quantizer

# Monkey-patch torch.nn layers (e.g. nn.Conv2d -> QuantConv2d) before the model is built
quant_modules.initialize()

class PreBN(nn.Module):
    def __init__(self, num_channels):
        super(PreBN, self).__init__()

        self.conv1 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(num_channels)

        self.relu1 = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=True)

        # self.residual_quantizer = quant_nn.TensorQuantizer(quant_nn.QuantConv2d.default_quant_desc_input)

    def forward(self, x):
        out = self.bn1(self.conv1(x))                # Conv -> BN ("conv-bn" fold)

        out = self.relu1(x + out)                    # residual connection
        out = self.relu2(self.conv2(self.bn2(out)))  # BN -> Conv ("bn-conv" fold)
        return out
```
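The second half of PreBN (bn2 feeding conv2) is the BN-Conv case. In eval mode a BatchNorm is a per-channel affine map BN(x) = x * a + b with a = gamma / sqrt(var + eps) and b = beta - mean * a, so a BN that precedes a Conv can in principle be absorbed into that Conv's weights and bias. The pure-Python sketch below checks this on a single pixel of a 1x1 convolution; the helper names bn_affine, conv1x1 and fold_bn_into_next_conv are hypothetical, and this is not onnx-simplifier's code. One caveat the sketch deliberately avoids: with padding=1 as in PreBN, the fold is not exact at the image borders, because the folded bias assumes every kernel tap sees the BN shift, whereas in the original model the zero-padded taps contribute nothing.

```python
import math

def bn_affine(gamma, beta, mean, var, eps=1e-5):
    """Eval-mode BN as a per-channel affine map: BN(x) = x * scale + shift."""
    scale = [g / math.sqrt(v + eps) for g, v in zip(gamma, var)]
    shift = [b - m * s for b, m, s in zip(beta, mean, scale)]
    return scale, shift

def conv1x1(weight, bias, x):
    """1x1 conv evaluated at one pixel: y[o] = sum_i weight[o][i] * x[i] + bias[o]."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]

def fold_bn_into_next_conv(weight, bias, scale, shift):
    """Absorb a preceding BN: W'[o][i] = W[o][i] * scale[i], b'[o] = b[o] + sum_i W[o][i] * shift[i]."""
    new_weight = [[w * s for w, s in zip(row, scale)] for row in weight]
    new_bias = [b + sum(w * t for w, t in zip(row, shift)) for row, b in zip(weight, bias)]
    return new_weight, new_bias

gamma, beta = [1.5, 0.5], [0.1, -0.2]
mean, var = [0.3, -1.0], [4.0, 0.25]
weight = [[0.2, -0.4], [1.0, 0.7]]
bias = [0.0, 0.0]  # conv2 in PreBN has bias=False, so the folded conv gains a bias
x = [0.9, -2.0]

scale, shift = bn_affine(gamma, beta, mean, var)
reference = conv1x1(weight, bias, [xi * s + t for xi, s, t in zip(x, scale, shift)])
folded_weight, folded_bias = fold_bn_into_next_conv(weight, bias, scale, shift)
folded = conv1x1(folded_weight, folded_bias, x)
print(all(abs(r - f) < 1e-9 for r, f in zip(reference, folded)))  # True
```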
[Screenshot, 2023-07-03: the exported ONNX graph]
yiliu30 commented 1 year ago

Hi @ShiyangZhang, if you export a PyTorch model to ONNX in TrainingMode.EVAL mode (the default), the exporter fuses Conv+BN: torch.onnx.export(..., training=TrainingMode.EVAL, ...). See also the related discussion: https://github.com/pytorch/pytorch/issues/49226#issuecomment-764983568
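For reference, the fusion performed for the Conv -> BN pattern is a per-output-channel rescaling: with a_o = gamma_o / sqrt(var_o + eps), the fused weight is W'_o = a_o * W_o and the fused bias is b'_o = a_o * (b_o - mean_o) + beta_o. A minimal pure-Python sketch of that arithmetic on a 1x1 conv, with hypothetical helper names; it illustrates the math only and is not the exporter's implementation. Unlike the BN-Conv case, this fold is exact for any padding, since BN is applied after the convolution.

```python
import math

def conv1x1(weight, bias, x):
    """1x1 conv evaluated at one pixel: y[o] = sum_i weight[o][i] * x[i] + bias[o]."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]

def batchnorm_eval(y, gamma, beta, mean, var, eps=1e-5):
    """Eval-mode BN using the running statistics."""
    return [g * (yi - m) / math.sqrt(v + eps) + b
            for yi, g, b, m, v in zip(y, gamma, beta, mean, var)]

def fold_bn_into_prev_conv(weight, bias, gamma, beta, mean, var, eps=1e-5):
    """Fuse a following BN into the conv, one output channel at a time."""
    new_weight, new_bias = [], []
    for row, b0, g, b, m, v in zip(weight, bias, gamma, beta, mean, var):
        a = g / math.sqrt(v + eps)           # per-output-channel BN multiplier
        new_weight.append([w * a for w in row])
        new_bias.append((b0 - m) * a + b)
    return new_weight, new_bias

weight = [[0.2, -0.4], [1.0, 0.7]]
bias = [0.05, -0.1]
gamma, beta = [1.5, 0.5], [0.1, -0.2]
mean, var = [0.3, -1.0], [4.0, 0.25]
x = [0.9, -2.0]

reference = batchnorm_eval(conv1x1(weight, bias, x), gamma, beta, mean, var)
fused_weight, fused_bias = fold_bn_into_prev_conv(weight, bias, gamma, beta, mean, var)
fused = conv1x1(fused_weight, fused_bias, x)
print(all(abs(r - f) < 1e-9 for r, f in zip(reference, fused)))  # True
```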

Was the above ONNX model generated by pytorch_quantization?

tp-nan commented 1 year ago

Was the above ONNX model generated by pytorch_quantization?

The above ONNX model was generated by torch.onnx.export.

We changed the PyTorch model's definition: pytorch_quantization replaces torch.nn.Conv2d with a quantized version, QuantConv2d, which inserts QuantizeLinear and DequantizeLinear nodes (exported from torch.fake_quantize_per_tensor_affine, or its per-channel counterpart for the per-channel weight quantizer). So both the Conv weight and its input are quantized.


```python
# Excerpt from pytorch_quantization/nn/modules/quant_conv.py
import torch.nn.functional as F
from torch.nn.modules.utils import _pair

from pytorch_quantization import tensor_quant
from pytorch_quantization.nn.modules import _utils
from pytorch_quantization.nn.modules.quant_conv import _QuantConvNd


class QuantConv2d(_QuantConvNd):
    """Quantized 2D conv"""

    default_quant_desc_weight = tensor_quant.QUANT_DESC_8BIT_CONV2D_WEIGHT_PER_CHANNEL

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 padding_mode='zeros',
                 **kwargs):

        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)

        quant_desc_input, quant_desc_weight = _utils.pop_quant_desc_in_kwargs(self.__class__, **kwargs)
        super(QuantConv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, False,
                                          _pair(0), groups, bias, padding_mode,
                                          quant_desc_input=quant_desc_input, quant_desc_weight=quant_desc_weight)

    def forward(self, input):
        # the actual quantization happens in the next level of the class hierarchy
        quant_input, quant_weight = self._quant(input)

        if self.padding_mode == 'circular':
            expanded_padding = ((self.padding[1] + 1) // 2, self.padding[1] // 2,
                                (self.padding[0] + 1) // 2, self.padding[0] // 2)
            output = F.conv2d(F.pad(quant_input, expanded_padding, mode='circular'),
                              quant_weight, self.bias, self.stride,
                              _pair(0), self.dilation, self.groups)
        else:
            output = F.conv2d(quant_input, quant_weight, self.bias, self.stride, self.padding, self.dilation,
                              self.groups)

        return output
```
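For intuition about what the QuantizeLinear/DequantizeLinear nodes mentioned above do numerically: a per-tensor Q/DQ pair, like torch.fake_quantize_per_tensor_affine, computes (clamp(round(x / scale) + zero_point, quant_min, quant_max) - zero_point) * scale. A minimal pure-Python sketch of that formula, assuming the full int8 range [-128, 127] as defaults (pytorch_quantization's own range settings may differ); the function name is hypothetical.

```python
def fake_quantize_per_tensor(x, scale, zero_point=0, quant_min=-128, quant_max=127):
    """Quantize each value onto an integer grid, saturate, then dequantize back to float."""
    out = []
    for v in x:
        q = round(v / scale) + zero_point      # QuantizeLinear: round half to even
        q = max(quant_min, min(quant_max, q))  # saturate to [quant_min, quant_max]
        out.append((q - zero_point) * scale)   # DequantizeLinear
    return out

values = [0.1, 0.26, -5.0, 100.0]
result = fake_quantize_per_tensor(values, scale=0.1)
print([round(v, 6) for v in result])  # [0.1, 0.3, -5.0, 12.7]
```

The last value shows saturation: 100.0 maps to code 1000, which is clamped to 127 before dequantization.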
yiliu30 commented 1 year ago

I understand. It seems that the torch.onnx Conv+BN fusion pass inspects the Conv's inputs and merges the BN into the Conv only if the Conv's weight is a constant. The expected pattern is:

```
  input    conv_weight (constant)
      \     /
      [ Conv ]
         |
      conv_out
         |
      [ BN ]
         |
  conv_out_after_bn
         |
        ...
```

Once QDQ nodes are inserted, the Conv weight becomes the output tensor of a DequantizeLinear node, which no longer matches this pattern. Conv+BN fusion rewrites the weight using the BN parameters, but with QDQ in place the weight depends on the Q/DQ pair, so it makes sense not to fuse it.
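The matching condition described above can be illustrated with a toy graph. This is not torch.onnx's actual pass: nodes are plain dicts (a hypothetical format loosely modelled on ONNX node inputs and outputs), and find_conv_bn_pairs is a hypothetical helper that accepts a Conv -> BatchNormalization pair only when the Conv weight is an initializer and the Conv output feeds nothing but the BN. Putting a QuantizeLinear/DequantizeLinear pair on the weight makes the match fail.

```python
def find_conv_bn_pairs(nodes, initializers):
    """Return (conv, bn) name pairs where a Conv with a constant weight feeds only a BN."""
    consumers = {}
    for node in nodes:
        for name in node["inputs"]:
            consumers.setdefault(name, []).append(node)

    pairs = []
    for node in nodes:
        if node["op"] != "Conv":
            continue
        weight = node["inputs"][1]
        if weight not in initializers:  # weight produced by another node, e.g. DequantizeLinear
            continue
        users = consumers.get(node["outputs"][0], [])
        if len(users) == 1 and users[0]["op"] == "BatchNormalization":
            pairs.append((node["name"], users[0]["name"]))
    return pairs

bn = {"name": "bn1", "op": "BatchNormalization",
      "inputs": ["conv_out", "scale", "B", "mean", "var"], "outputs": ["conv_out_after_bn"]}

plain_nodes = [
    {"name": "conv1", "op": "Conv", "inputs": ["input", "conv_weight"], "outputs": ["conv_out"]},
    bn,
]
plain_inits = {"conv_weight", "scale", "B", "mean", "var"}

qdq_nodes = [
    {"name": "q_w", "op": "QuantizeLinear", "inputs": ["conv_weight", "w_scale", "w_zp"], "outputs": ["w_q"]},
    {"name": "dq_w", "op": "DequantizeLinear", "inputs": ["w_q", "w_scale", "w_zp"], "outputs": ["w_dq"]},
    {"name": "conv1", "op": "Conv", "inputs": ["input", "w_dq"], "outputs": ["conv_out"]},
    bn,
]
qdq_inits = plain_inits | {"w_scale", "w_zp"}

print(find_conv_bn_pairs(plain_nodes, plain_inits))  # [('conv1', 'bn1')]
print(find_conv_bn_pairs(qdq_nodes, qdq_inits))      # []
```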

To fuse BN + Conv, how about exporting the PyTorch model to ONNX first (the Conv+BN fusion happens at this stage) and then using ONNX quantization tools to quantize it?

tp-nan commented 1 year ago

To fuse BN + Conv, how about exporting the PyTorch model to ONNX first (the Conv+BN fusion happens at this stage) and then using ONNX quantization tools to quantize it?

Thanks for your reply! That's good advice for post-training quantization (PTQ). In the context of quantization-aware training (QAT), can the Batch Normalization scale factor be merged into the scale factor of the QuantizeLinear/DequantizeLinear nodes for the Conv weight at inference time?
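The question can be made concrete with a little arithmetic. Assuming symmetric per-output-channel weight quantization (zero point 0) and a positive BN multiplier a_c = gamma_c / sqrt(var_c + eps) for every channel, the dequantized weight of channel c is s_c * q_c, and folding BN gives a_c * (s_c * q_c) = (a_c * s_c) * q_c. So keeping the integer codes q and replacing the DQ scale with a_c * s_c reproduces the folded weights. The pure-Python sketch below (hypothetical helper names) only checks this identity; it does not model the BN shift, which would go into the Conv bias, nor the case a_c < 0, where a positive scale would require negating the codes.

```python
import math

def quantize_per_channel(weight, scales, qmax=127):
    """Symmetric per-output-channel quantization (zero point 0, narrow range [-qmax, qmax])."""
    return [[max(-qmax, min(qmax, round(w / s))) for w in row] for row, s in zip(weight, scales)]

def dequantize_per_channel(codes, scales):
    """DequantizeLinear with zero point 0: w = q * scale of the row's output channel."""
    return [[q * s for q in row] for row, s in zip(codes, scales)]

def bn_scale(gamma, var, eps=1e-5):
    """Per-channel BN multiplier a_c = gamma_c / sqrt(var_c + eps)."""
    return [g / math.sqrt(v + eps) for g, v in zip(gamma, var)]

weight = [[0.50, -0.25, 0.10], [1.20, 0.30, -0.90]]           # two output channels
w_scales = [max(abs(w) for w in row) / 127 for row in weight]  # amax / 127 per channel
gamma, var = [0.8, 2.0], [0.5, 4.0]

codes = quantize_per_channel(weight, w_scales)
a = bn_scale(gamma, var)

# Option 1: fold BN into the fake-quantized (dequantized) weight.
folded = [[w * a_c for w in row] for row, a_c in zip(dequantize_per_channel(codes, w_scales), a)]

# Option 2: keep the integer codes and fold BN into the DequantizeLinear scale instead.
merged_scales = [s * a_c for s, a_c in zip(w_scales, a)]
merged = dequantize_per_channel(codes, merged_scales)

max_diff = max(abs(x - y) for r1, r2 in zip(folded, merged) for x, y in zip(r1, r2))
print(max_diff < 1e-12)  # True
```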