Xilinx / brevitas

Brevitas: neural network quantization in PyTorch
https://xilinx.github.io/brevitas/

Example: How to use merge_bn correctly #542

Open g12bftd opened 1 year ago

g12bftd commented 1 year ago

There is an architecture I would like to quantise and retrain from its floating-point counterpart, and I would like to incorporate the merge_bn operation supported by Brevitas. How exactly would I do this? An overview is good, but some code would be better. Note that I only want to merge/fuse the Conv + BN + ReLU components. Here is my architecture:

import torch.nn as nn
import brevitas.nn as qnn


class QuantizedModel(nn.Module):
    def __init__(self, config):
        super(QuantizedModel, self).__init__()

        self.weight_config = config

        k = 1
        self.quant_inp = qnn.QuantIdentity(
            bit_width=16, return_quant_tensor=True)
        self.conv1 = qnn.QuantConv2d(in_channels=3, out_channels=int(k * 128), kernel_size=3, padding=1, weight_bit_width=self.weight_config[0], return_quant_tensor=True, bias=True)
        self.bn1 = nn.BatchNorm2d(int(k * 128))
        self.relu1 = qnn.QuantReLU(bit_width=self.weight_config[0], return_quant_tensor=True)
        self.conv2 = qnn.QuantConv2d(int(k * 128), int(k * 128), kernel_size=3, padding=1, weight_bit_width=self.weight_config[1], return_quant_tensor=True, bias=True)
        self.bn2 = nn.BatchNorm2d(int(k * 128))
        self.relu2 = qnn.QuantReLU(bit_width=self.weight_config[1], return_quant_tensor=True)
        self.max_pool1 = qnn.QuantMaxPool2d(kernel_size=2, stride=2, return_quant_tensor=True)
        self.conv3 = qnn.QuantConv2d(int(k * 128), int(k * 256), kernel_size=3, padding=1, weight_bit_width=self.weight_config[2], return_quant_tensor=True, bias=True)
        self.bn3 = nn.BatchNorm2d(int(k * 256))
        self.relu3 = qnn.QuantReLU(bit_width=self.weight_config[2], return_quant_tensor=True)
        self.conv4 = qnn.QuantConv2d(int(k * 256), int(k * 256), kernel_size=3, padding=1, weight_bit_width=self.weight_config[3], return_quant_tensor=True, bias=True)
        self.bn4 = nn.BatchNorm2d(int(k * 256))
        self.relu4 = qnn.QuantReLU(bit_width=self.weight_config[3], return_quant_tensor=True)
        self.max_pool2 = qnn.QuantMaxPool2d(kernel_size=2, stride=2, return_quant_tensor=True)
        self.conv5 = qnn.QuantConv2d(int(k * 256), int(k * 512), kernel_size=3, padding=1, weight_bit_width=self.weight_config[4], return_quant_tensor=True, bias=True)
        self.bn5 = nn.BatchNorm2d(int(k * 512))
        self.relu5 = qnn.QuantReLU(bit_width=self.weight_config[4], return_quant_tensor=True)
        self.conv6 = qnn.QuantConv2d(int(k * 512), int(k * 512), kernel_size=3, padding=1, weight_bit_width=self.weight_config[5], return_quant_tensor=True, bias=True)
        self.bn6 = nn.BatchNorm2d(int(k * 512))
        self.relu6 = qnn.QuantReLU(bit_width=self.weight_config[5], return_quant_tensor=True)
        self.max_pool3 = qnn.QuantMaxPool2d(kernel_size=2, stride=2, return_quant_tensor=True)

        # 512 channels * 4 * 4, assuming a 32x32 input (three 2x2 max-pools)
        input_feats = 8192

        self.fc1 = qnn.QuantLinear(input_feats, int(k * 1024), weight_bit_width=self.weight_config[6], return_quant_tensor=True, bias=True)
        self.fc2 = qnn.QuantLinear(int(k * 1024), 10, weight_bit_width=self.weight_config[7], bias=True)

    def forward(self, x):
        out = self.quant_inp(x)  # quantize the input before the first conv
        out = self.relu1(self.bn1(self.conv1(out)))
        out = self.relu2(self.bn2(self.conv2(out)))
        out = self.max_pool1(out)
        out = self.relu3(self.bn3(self.conv3(out)))
        out = self.relu4(self.bn4(self.conv4(out)))
        out = self.max_pool2(out)
        out = self.relu5(self.bn5(self.conv5(out)))
        out = self.relu6(self.bn6(self.conv6(out)))
        out = self.max_pool3(out)
        out = out.reshape(out.shape[0], -1)
        out = self.fc1(out)
        out = self.fc2(out)
        return out

MohamedA95 commented 1 year ago

Hi @g12bftd, as far as I understand, merging batch normalization layers is usually a post-training optimization. So I would train the model first, then write a script that creates two instances of the same model, one with the batch norm layers and one without. I would then loop over the model with batch norm, merge each conv layer with the batch norm layer that follows it, and copy the results into the new model (the one without batch norm layers). The code should look roughly as follows:

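import torch
from brevitas.nn.utils import merge_bn

# `bn` is a hypothetical constructor flag that adds/omits the BatchNorm2d layers;
# bn_model is assumed to already hold the trained weights
bn_model = QuantizedModel(config, bn=True).eval()
model = QuantizedModel(config, bn=False)
for i in range(1, 7):
    # fold bn{i} into conv{i} in place
    merge_bn(getattr(bn_model, f'conv{i}'), getattr(bn_model, f'bn{i}'))
# copy every layer that exists in the BN-free model (the BN layers are skipped)
for name, layer in model.named_children():
    layer.load_state_dict(getattr(bn_model, name).state_dict())
torch.save(model.state_dict(), 'fused_QuantizedModel.pth')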

Also, I would recommend defining the model using nn.Sequential, which makes it easier to loop over the layers.
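A minimal sketch of the loop this enables, assuming the model is an nn.Sequential in which each BatchNorm2d directly follows its Conv2d. It uses plain torch.nn layers and a hypothetical helper, `fold_bn_into_conv`, that applies the same per-channel scale-and-shift as merge_bn so that it runs without Brevitas; with QuantConv2d layers you would call brevitas.nn.utils.merge_bn(conv, bn) in its place. Only eval-mode BN (running statistics) can be folded this way.

```python
import torch
import torch.nn as nn


def fold_bn_into_conv(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> None:
    """Hypothetical helper: fold eval-mode BN into the preceding conv, in place."""
    with torch.no_grad():
        mul = bn.weight / torch.sqrt(bn.running_var + bn.eps)  # per-channel scale
        add = bn.bias - bn.running_mean * mul                   # per-channel shift
        conv.weight.mul_(mul.view(-1, 1, 1, 1))                 # scale each output channel
        if conv.bias is None:
            conv.bias = nn.Parameter(add.clone())
        else:
            conv.bias.mul_(mul).add_(add)


def fuse_sequential(model: nn.Sequential) -> nn.Sequential:
    """Return a new Sequential in which every Conv2d -> BatchNorm2d pair is merged."""
    layers = list(model)
    fused = []
    i = 0
    while i < len(layers):
        layer = layers[i]
        nxt = layers[i + 1] if i + 1 < len(layers) else None
        if isinstance(layer, nn.Conv2d) and isinstance(nxt, nn.BatchNorm2d):
            fold_bn_into_conv(layer, nxt)
            fused.append(layer)
            i += 2  # skip the BN that was just folded
        else:
            fused.append(layer)
            i += 1
    return nn.Sequential(*fused)
```

Because the Conv2d modules are modified in place, the returned Sequential shares them with the original model; the original should not be run afterwards, since it would apply BN a second time.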

wilfredkisku commented 1 year ago

@MohamedA95 I am new to Brevitas: does this mean we need to train with the classical BN layers? Could you elaborate for a newbie like me? All the models I am trying to export need intermediate BN layers.

MohamedA95 commented 1 year ago

Hi @wilfredkisku, what do you mean by classical BN layers? If you mean torch.nn.BatchNorm2d, then if your model requires it you will have to use it. AFAIK Brevitas does not have a quantized BN yet, so you can use the torch BN and then fuse it with the preceding conv, since both layers are linear transformations; check this link.
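A short derivation of why the fusion is exact once BN uses its frozen running statistics (eval mode). For output channel $c$, let the conv output be $z_c = W_c * x + b_c$, and let BN have running mean $\mu_c$, running variance $\sigma_c^2$, affine parameters $\gamma_c, \beta_c$ and stability constant $\epsilon$:

```latex
\begin{aligned}
y_c &= \gamma_c \frac{z_c - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}} + \beta_c = m_c z_c + a_c, \\
m_c &= \frac{\gamma_c}{\sqrt{\sigma_c^2 + \epsilon}}, \qquad a_c = \beta_c - \mu_c m_c, \\
y_c &= (m_c W_c) * x + (m_c b_c + a_c).
\end{aligned}
```

So the fused layer is a single convolution with weights $m_c W_c$ and bias $m_c b_c + a_c$. This matches how merge_bn applies its factors: the weight is scaled by mul_factor, and the bias is scaled by mul_factor and then shifted by add_factor. During training BN normalizes with batch statistics instead, which is why the fold is applied to the trained model.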

wilfredkisku commented 1 year ago

@MohamedA95 Thank you for the reply. Yes, the models I am using require torch.nn.BatchNorm2d layers. Can they also be fused with quantized layers? Thanks again.

MohamedA95 commented 1 year ago

Yes, they can be fused. Brevitas even has a function for it, merge_bn, under brevitas.nn.utils.

wilfredkisku commented 1 year ago

@MohamedA95 Thanks for all the help. I now understand the idea behind fusing the layers. I have created two otherwise identical models, one with CONV + BN and one without BN:

###########################
#### MODEL 1 ##############
###########################

from torch.nn import Module
import torch.nn.functional as F

import torch.nn as nn

import brevitas.nn as qnn
from brevitas.quant import Int8Bias as BiasQuant

class QuantWeightActLeNet(Module):
    def __init__(self):
        super(QuantWeightActLeNet, self).__init__()
        self.quant_inp = qnn.QuantIdentity(bit_width=4)
        self.conv1 = qnn.QuantConv2d(3, 6, 5, bias=True, weight_bit_width=4)
        self.relu1 = qnn.QuantReLU(bit_width=4)
        self.bn = nn.BatchNorm2d(6)

    def forward(self, x):
        out = self.quant_inp(x)
        out = self.relu1(self.bn(self.conv1(out)))
        return out

###########################
#### MODEL 2 ##############
###########################
class QuantWeightActLeNet_wo(Module):
    def __init__(self):
        super(QuantWeightActLeNet_wo, self).__init__()
        self.quant_inp = qnn.QuantIdentity(bit_width=4)
        self.conv1 = qnn.QuantConv2d(3, 6, 5, bias=True, weight_bit_width=4)
        self.relu1 = qnn.QuantReLU(bit_width=4)

    def forward(self, x):
        out = self.quant_inp(x)
        out = self.relu1(self.conv1(out))
        return out

quant_weight_act_lenet_wo = QuantWeightActLeNet_wo()
quant_weight_act_lenet = QuantWeightActLeNet()

I am using the following merge_bn function to merge the CONV and BN layers:

#######################
###### MERGE ##########
#######################

from torch.nn import Parameter
# mul_add_from_bn, compute_channel_view_shape, WeightQuantProxyFromInjector and
# BiasQuantProxyFromInjector are Brevitas internals; this function mirrors
# brevitas.nn.utils.merge_bn, so importing that directly is simpler than copying it.

def merge_bn(layer, bn, output_channel_dim=0):
    # per-channel scale (mul_factor) and shift (add_factor) equivalent to eval-mode BN
    out = mul_add_from_bn(
        bn_mean=bn.running_mean,
        bn_var=bn.running_var,
        bn_eps=bn.eps,
        bn_weight=bn.weight.data.clone(),
        bn_bias=bn.bias.data.clone())
    mul_factor, add_factor = out

    # broadcast shape that places the per-channel factors on the output-channel dim
    out_ch_weight_shape = compute_channel_view_shape(layer.weight, output_channel_dim)

    # in-place operation: multiply the layer weights by the BN mul_factor
    # without making a new copy of the tensor
    layer.weight.data.mul_(mul_factor.view(out_ch_weight_shape))

    # existing bias: scale it and shift it; no bias: create one from add_factor
    if layer.bias is not None:
        out_ch_bias_shape = compute_channel_view_shape(layer.bias, channel_dim=0)
        layer.bias.data.mul_(mul_factor.view(out_ch_bias_shape))
        layer.bias.data.add_(add_factor.view(out_ch_bias_shape))
    else:
        layer.bias = Parameter(add_factor)
    # rebuild the weight/bias quantizers now that the tensors have changed
    if (hasattr(layer, 'weight_quant') and
            isinstance(layer.weight_quant, WeightQuantProxyFromInjector)):
        layer.weight_quant.init_tensor_quant()
    if (hasattr(layer, 'bias_quant') and isinstance(layer.bias_quant, BiasQuantProxyFromInjector)):
        layer.bias_quant.init_tensor_quant()

However, I am having trouble copying the trained weights and biases, plus the additional quantization parameters present in quantized layers such as QuantConv2d. I use concise code like the following to build a state dict that contains only the CONV entries and skips BN (which has already been fused into the CONV):

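# pretrained_dict: state dict of the trained model after merge_bn was applied
processed_dict = {}
for key in pretrained_dict.keys():
    if key.split('.')[0] != 'bn':  # skip the BN entries, already fused into conv1
        processed_dict[key] = pretrained_dict[key]

quant_weight_act_lenet_wo.load_state_dict(processed_dict, strict=False)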

The weights are copied, but a few parameters belonging to the Brevitas quantizers are not. load_state_dict reports:

_IncompatibleKeys(missing_keys=['quant_inp.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value', 'relu1.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value'], unexpected_keys=[])
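When load_state_dict(..., strict=False) reports missing keys, one way to find where they went missing is to compare the key sets of the two models' state dicts directly: a key the BN-free model expects but the source model never exposes shows up on one side only. A minimal sketch with standard PyTorch calls; `state_dict_key_diff` is a hypothetical helper, and the toy Sequential models stand in for the two LeNet variants.

```python
import torch.nn as nn


def state_dict_key_diff(src: nn.Module, dst: nn.Module):
    """Hypothetical helper: (keys only in src, keys only in dst), each sorted."""
    src_keys = set(src.state_dict().keys())
    dst_keys = set(dst.state_dict().keys())
    return sorted(src_keys - dst_keys), sorted(dst_keys - src_keys)


# toy stand-ins for the models with and without BN
with_bn = nn.Sequential(nn.Conv2d(3, 6, 5), nn.BatchNorm2d(6), nn.ReLU())
without_bn = nn.Sequential(nn.Conv2d(3, 6, 5), nn.ReLU())
only_src, only_dst = state_dict_key_diff(with_bn, without_bn)
```

Here only_src holds the five BatchNorm2d entries and only_dst is empty. Running the same comparison on quant_weight_act_lenet and quant_weight_act_lenet_wo shows whether the scaling_impl.value keys exist in the source state dict at all.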

I would be thankful for any help in this regard. Thanks again.

MohamedA95 commented 1 year ago

Hi @wilfredkisku, I am not sure about your way of copying the state dict. I would do something like the following:
1. Define the two models, one with batch norm and one without.
2. Loop over the model with batch norm, fusing each BN into its conv.
3. Loop over the model without batch norm, copying each layer's state dict from the other model, e.g. quant_weight_act_lenet_wo.conv1.load_state_dict(quant_weight_act_lenet.conv1.state_dict())
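The three steps can be sketched end to end with float stand-ins so that the example runs without Brevitas: WithBN and WithoutBN are hypothetical plain-PyTorch versions of the two LeNet variants, and PyTorch's torch.nn.utils.fusion.fuse_conv_bn_eval takes the place of merge_bn (unlike merge_bn, it returns a new fused conv instead of editing the layer in place).

```python
import torch
import torch.nn as nn
from torch.nn.utils.fusion import fuse_conv_bn_eval


class WithBN(nn.Module):  # hypothetical float stand-in for QuantWeightActLeNet
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5, bias=True)
        self.bn = nn.BatchNorm2d(6)
        self.relu1 = nn.ReLU()

    def forward(self, x):
        return self.relu1(self.bn(self.conv1(x)))


class WithoutBN(nn.Module):  # hypothetical float stand-in for QuantWeightActLeNet_wo
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5, bias=True)
        self.relu1 = nn.ReLU()

    def forward(self, x):
        return self.relu1(self.conv1(x))


def fuse_and_transfer(src: WithBN, dst: WithoutBN) -> None:
    src.eval()  # fuse_conv_bn_eval requires both layers in eval mode
    # Step 2: fold src.bn into a fused copy of src.conv1
    sources = dict(src.named_children())
    sources['conv1'] = fuse_conv_bn_eval(src.conv1, src.bn)
    # Step 3: copy layer by layer; dst has no BN, so src.bn is never copied
    for name, module in dst.named_children():
        module.load_state_dict(sources[name].state_dict())
```

Copying per layer rather than through one filtered state dict means each load_state_dict call is strict for that layer, so a missing quantizer entry would raise an error instead of being silently skipped.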

DDDDDY1 commented 9 months ago

> Hi @wilfredkisku, what do you mean by classical BN layers? If you mean torch.nn.BatchNorm2d, then if your model requires it you will have to use it. AFAIK Brevitas does not have a quantized BN yet, so you can use the torch BN and then fuse it with the preceding conv, since both layers are linear transformations; check this link.

Hi, this reply mentions that fixed-point batch norm can be trained using BatchNorm2dToQuantScaleBias with power-of-two scale factors. Doesn't that mean quantized BN is supported?

However, it is not clear to me how this works for batch norm. Do the scale and bias change during training, or is it effectively a post-training step for batch norm?