Xilinx / finn

Dataflow compiler for QNN inference on FPGAs
https://xilinx.github.io/finn
BSD 3-Clause "New" or "Revised" License
681 stars 218 forks source link

Drop in output accuracy between the quantized Brevitas model and the streamlined FINN ONNX model #1088

Open Arbiter-glitch opened 1 month ago

Arbiter-glitch commented 1 month ago

Discussed in https://github.com/Xilinx/finn/discussions/1086

Originally posted by **Arbiter-glitch** May 22, 2024

I am a research student working on image processing. While using FINN, I expected the accuracy of the trained Brevitas model to be preserved in the FINN ONNX model after streamlining. Although I do get an output, there is a considerable drop in the image quality metrics. Is this drop in accuracy expected, or should the outputs be the same? My files have been posted in earlier closed issues [here](https://github.com/Xilinx/finn/issues/1055) and [here](https://github.com/Xilinx/finn/issues/1060).

### Update

I have found that the output tensor already differs starting from the first step, the QONNX-to-FINN conversion. ONNX execution does not give the same image quality as the quantized Brevitas software model. Is this normal? I see a 5 dB drop in PSNR between the software model and the FINN output.

Is it because the quantizers I am using in Brevitas are not compatible with FINN? I took this model file from the Brevitas examples and modified it accordingly.

As a separate experiment, I also tried using bias quantization, since I thought it might be needed for an accurate output in the FINN flow. However, when I exported that model to FINN and ran ONNX execution, the outputs were all zeros, or they gradually became zeros through the layers. My model files are below. Could the weights be inaccurate after `FoldQuantWeights()`?

Model.py

import torch
import brevitas.nn as qnn
from brevitas.quant import Int8WeightPerTensorFloat
from brevitas.quant.scaled_int import Int8ActPerTensorFloat
from brevitas.quant.scaled_int import Int8Bias
from brevitas.nn.quant_layer import WeightQuantType

from common import CommonIntAccumulatorAwareWeightQuant
from common import CommonIntWeightPerChannelQuant
from common import CommonIntActQuant
from common import CommonUintActQuant
from common import ConstUint8ActQuant
from common import QuantNearestNeighborConvolution
from torch import nn
import math
IO_DATA_BIT_WIDTH = 8
IO_ACC_BIT_WIDTH = 32

class FSRCNN(nn.Module):
    """Quantized FSRCNN-style super-resolution network built from Brevitas layers.

    Stages:

    * ``first_part``: 5x5 feature-extraction conv (``num_channels`` -> ``d``) + ReLU.
    * ``mid_part``: 1x1 shrinking conv (``d`` -> ``s``), ``m`` 3x3 mapping convs
      (``s`` -> ``s``) and a 1x1 expanding conv (``s`` -> ``d``), each followed by a ReLU.
    * ``upsample``: quantized nearest-neighbor resize convolution (``d`` ->
      ``num_channels``), then a ReLU and a constant-scale unsigned output quantizer.

    Args:
        scale_factor: Spatial upscaling factor of the resize convolution.
        num_channels: Number of input (and output) image channels.
        d: Channel count of the feature-extraction / expansion layers.
        s: Channel count of the shrunk mapping layers.
        m: Number of 3x3 mapping layers.
        weight_bit_width: Weight bit width of every convolution.
        act_bit_width: Input-activation bit width of the inner convolutions and of
            the resize convolution.
        acc_bit_width: Passed as ``weight_accumulator_bit_width`` to every convolution.
        weight_quant: Weight quantizer of the feature/mapping convolutions; the resize
            convolution keeps its own default weight quantizer.
    """

    def __init__(self, scale_factor, num_channels=1, d=16, s=12, m=4, weight_bit_width: int = 8,
            act_bit_width: int = 8,
            acc_bit_width: int = 32,
            weight_quant: WeightQuantType = CommonIntAccumulatorAwareWeightQuant):
        super().__init__()

        # Keyword arguments shared by every QuantConv2d of the feature/mapping stages.
        conv_kwargs = dict(
            weight_bit_width=weight_bit_width,
            weight_accumulator_bit_width=acc_bit_width,
            weight_quant=weight_quant,
            bias=True)

        self.first_part = nn.Sequential(
            # NOTE(review): the signed 12-bit input quantizer is hard-coded rather than
            # derived from act_bit_width — confirm it matches the input image range.
            qnn.QuantConv2d(
                num_channels, d, kernel_size=5, padding=5 // 2,
                input_quant=CommonIntActQuant, input_bit_width=12, **conv_kwargs),
            nn.ReLU(inplace=True))

        # Every inner convolution follows a ReLU, so its input is quantized as unsigned.
        inner_kwargs = dict(
            conv_kwargs, input_quant=CommonUintActQuant, input_bit_width=act_bit_width)
        mid_layers = [
            qnn.QuantConv2d(d, s, kernel_size=1, **inner_kwargs),
            nn.ReLU(inplace=True)]
        for _ in range(m):
            mid_layers.extend([
                qnn.QuantConv2d(s, s, kernel_size=3, padding=3 // 2, **inner_kwargs),
                nn.ReLU(inplace=True)])
        mid_layers.extend([
            qnn.QuantConv2d(s, d, kernel_size=1, **inner_kwargs),
            nn.ReLU(inplace=True)])
        self.mid_part = nn.Sequential(*mid_layers)

        # The layer-wide bit widths are forwarded so the resize convolution is
        # configured like the rest of the network, and it emits num_channels
        # outputs so the image has as many channels as the input.
        self.upsample = QuantNearestNeighborConvolution(
            d, num_channels, kernel_size=3, stride=1, padding=3 // 2,
            upscale_factor=scale_factor,
            acc_bit_width=acc_bit_width,
            act_bit_width=act_bit_width,
            weight_bit_width=weight_bit_width)
        # FINN requires a ReLU node to precede an unsigned int quant node, hence this
        # explicit ReLU in front of the unsigned output quantizer.
        self.relu = nn.ReLU(inplace=True)
        # Quantize the output to an unsigned IO_DATA_BIT_WIDTH-bit image with a
        # constant unit scale; a plain tensor (not a QuantTensor) is returned.
        self.out = qnn.QuantIdentity(
            act_quant=ConstUint8ActQuant, return_quant_tensor=False,
            bit_width=IO_DATA_BIT_WIDTH)

    def forward(self, x):
        """Run the network on ``x`` (presumably N x num_channels x H x W — confirm)."""
        x = self.first_part(x)
        x = self.mid_part(x)

        x = self.upsample(x)
        x = self.relu(x)
        x = self.out(x)
        return x

common.py (imported by Model.py as `common`)


from typing import Optional

from torch import Tensor
import torch.nn as nn

from brevitas.core.restrict_val import FloatRestrictValue
from brevitas.core.restrict_val import RestrictValueType
from brevitas.core.scaling import ScalingImplType
import brevitas.nn as qnn
from brevitas.nn.quant_layer import WeightQuantType
from brevitas.quant import Int8AccumulatorAwareWeightQuant
from brevitas.quant import Int8AccumulatorAwareZeroCenterWeightQuant
from brevitas.quant import Int8ActPerTensorFloat
from brevitas.quant import Int8WeightPerTensorFloat
from brevitas.quant import Uint8ActPerTensorFloat

class CommonIntWeightPerChannelQuant(Int8WeightPerTensorFloat):
    """
    Common signed weight quantizer with one scaling factor per output channel.

    Derived from ``Int8WeightPerTensorFloat`` with ``scaling_per_output_channel``
    enabled. Unlike the other ``Common*`` quantizers in this module it does NOT reset
    ``bit_width`` to None: it keeps the bit width inherited from
    ``Int8WeightPerTensorFloat`` unless a layer overrides it (e.g. ``weight_bit_width``).
    """
    scaling_per_output_channel = True

class CommonIntAccumulatorAwareWeightQuant(Int8AccumulatorAwareWeightQuant):
    """
    A2Q: Accumulator-Aware Quantization with Guaranteed Overflow Avoidance.

    Signed accumulator-aware weight quantizer whose bit width is left unset (None),
    so each layer has to specify it.
    """
    bit_width = None
    # Scaling restriction pinned to FloatRestrictValue for backwards compatibility.
    restrict_scaling_impl = FloatRestrictValue

class CommonIntAccumulatorAwareZeroCenterWeightQuant(Int8AccumulatorAwareZeroCenterWeightQuant):
    """
    A2Q+: Improving Accumulator-Aware Weight Quantization.

    Zero-centered accumulator-aware weight quantizer; the bit width is reset to None
    so that every layer must provide it.
    """
    bit_width = None

class CommonIntActQuant(Int8ActPerTensorFloat):
    """
    Shared signed per-tensor activation quantizer.

    The bit width is reset to None, which forces each layer to state it explicitly,
    and the scaling factor is restricted with ``RestrictValueType.LOG_FP``.
    """
    restrict_scaling_type = RestrictValueType.LOG_FP
    bit_width = None
    restrict_scaling_type = RestrictValueType.LOG_FP

class CommonUintActQuant(Uint8ActPerTensorFloat):
    """
    Shared unsigned per-tensor activation quantizer.

    The bit width is reset to None, so each layer has to specify it, and the scaling
    factor is restricted with ``RestrictValueType.LOG_FP``.
    """
    restrict_scaling_type = RestrictValueType.LOG_FP
    bit_width = None

class ConstUint8ActQuant(CommonUintActQuant):
    """
    8-bit unsigned integer activation quantizer with a constant unit scaling factor.

    Used by the models to quantize outputs into the image space. ``CommonUintActQuant``
    resets ``bit_width`` to None; it is restored to 8 here so the quantizer matches its
    name even when a layer does not pass ``bit_width``. A layer-level ``bit_width``
    argument still overrides it.
    """
    bit_width = 8
    scaling_impl_type = ScalingImplType.CONST
    scaling_init = 1.

class QuantNearestNeighborConvolution(nn.Module):
    """Quantized nearest-neighbor resize convolution.

    The input activation is quantized, upsampled by ``upscale_factor`` with
    nearest-neighbor interpolation, and then fed to a quantized 2D convolution.

    Args:
        in_channels: Input channels of the convolution.
        out_channels: Output channels of the convolution.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution zero padding.
        upscale_factor: Spatial scale factor of the upsampling.
        signed_act: Use the signed ``CommonIntActQuant`` input quantizer when True,
            otherwise the unsigned ``CommonUintActQuant``.
        bias: Whether the convolution has a bias.
        weight_quant: Weight quantizer of the convolution.
        acc_bit_width: Passed to the convolution as ``weight_accumulator_bit_width``.
        act_bit_width: Bit width of the input activation quantizer.
        weight_bit_width: Bit width of the convolution weights.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int = 5,
            stride: int = 1,
            padding: int = 0,
            upscale_factor: int = 2,
            signed_act: bool = False,
            bias: bool = True,
            weight_quant: WeightQuantType = CommonIntWeightPerChannelQuant,
            acc_bit_width: int = 32,
            act_bit_width: int = 8,
            weight_bit_width: int = 8):
        super().__init__()

        # Unsigned activation quantization when the preceding layer has a
        # non-negative range (e.g. after a ReLU activation function).
        act_quant = CommonIntActQuant if signed_act else CommonUintActQuant

        self.upscale_factor = upscale_factor
        # The quantization node must come BEFORE the upsampling node for FINN
        # compatibility. The FINN compiler streamlines the quant node with the
        # preceding monotonic activation function (a ReLU in the models using this
        # block). The QuantTensor is returned so the conv2d knows the input bit width
        # needed for accumulator-aware quantization (A2Q); see
        # https://arxiv.org/abs/2301.13376.
        self.input_quant = qnn.QuantIdentity(
            act_quant=act_quant, return_quant_tensor=True, bit_width=act_bit_width)
        self.interp = qnn.QuantUpsample(scale_factor=upscale_factor, mode="nearest")
        # input_quant=None: the input presumably arrives already quantized, as the
        # QuantTensor from input_quant carried through interp — confirm for the
        # Brevitas version in use.
        self.conv = qnn.QuantConv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            padding=padding,
            stride=stride,
            bias=bias,
            input_quant=None,
            weight_accumulator_bit_width=acc_bit_width,
            weight_bit_width=weight_bit_width,
            weight_quant=weight_quant)

    def forward(self, inp: Tensor) -> Tensor:
        """Quantize ``inp``, upsample it, then convolve it."""
        return self.conv(self.interp(self.input_quant(inp)))