Xilinx / brevitas

Brevitas: neural network quantization in PyTorch
https://xilinx.github.io/brevitas/

AttributeError: 'QuantTensor' object has no attribute 'dim' #358

Closed · lovodkin93 closed this issue 3 years ago

lovodkin93 commented 3 years ago

Hello, I am trying to use your toolkit to perform QAT. Unlike the examples in your README, my model contains Batch Normalization layers. When I pass return_quant_tensor=True to the QuantConv2d layers that are followed by Batch Normalization layers, I keep getting the following error:

Traceback (most recent call last):
  File "/home/taaviv/dl-quantization/t-aslobodkin/different_quantization_schemes/brevitas/resnet50_brevitas.py", line 227, in <module>
    test()
  File "/home/taaviv/dl-quantization/t-aslobodkin/different_quantization_schemes/brevitas/resnet50_brevitas.py", line 225, in test
    y = net(torch.randn(4, 3, 32, 32).to(device))
  File "/data/anaconda/envs/resnet50_pytorch2/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/taaviv/dl-quantization/t-aslobodkin/different_quantization_schemes/brevitas/resnet50_brevitas.py", line 191, in forward
    x = self.layer1(x)
  File "/data/anaconda/envs/resnet50_pytorch2/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/data/anaconda/envs/resnet50_pytorch2/lib/python3.6/site-packages/torch/nn/modules/container.py", line 100, in forward
    input = module(input)
  File "/data/anaconda/envs/resnet50_pytorch2/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/taaviv/dl-quantization/t-aslobodkin/different_quantization_schemes/brevitas/resnet50_brevitas.py", line 81, in forward
    x = self.bn1(x)
  File "/data/anaconda/envs/resnet50_pytorch2/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/data/anaconda/envs/resnet50_pytorch2/lib/python3.6/site-packages/torch/nn/modules/batchnorm.py", line 85, in forward
    self._check_input_dim(input)
  File "/data/anaconda/envs/resnet50_pytorch2/lib/python3.6/site-packages/torch/nn/modules/batchnorm.py", line 249, in _check_input_dim
    if input.dim() != 4:
AttributeError: 'QuantTensor' object has no attribute 'dim'

It appears the QuantTensor returned by the QuantConv2d layers, which becomes the input to the BatchNorm2d layers, doesn't have a dim() method.
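For reference, here is a minimal sketch of the failure mode (a stripped-down hypothetical repro using default quantizers, not my full configuration below):

import torch
import torch.nn as nn
import brevitas.nn as qnn

# a quantized conv configured to return a QuantTensor, fed straight into BatchNorm2d
conv = qnn.QuantConv2d(3, 8, kernel_size=3, padding=1, return_quant_tensor=True)
bn = nn.BatchNorm2d(8)

out = conv(torch.randn(1, 3, 32, 32))  # out is a QuantTensor, not a torch.Tensor
bn(out)  # raises AttributeError: 'QuantTensor' object has no attribute 'dim'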

Here is my code:


# %%
import torch 
from torch import optim
import torch.nn.functional as F
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
from torch import nn as nn
import numpy as np
import datetime as dt
import h5py
import math
import pandas as pd
import scipy as sci
import matplotlib.pyplot as plt
import os
from torch.utils.data import random_split
import brevitas.nn as qnn
from brevitas.quant import Int8WeightPerTensorFloat as SignedWeightQuant
from brevitas.quant import ShiftedUint8WeightPerTensorFloat as UnsignedWeightQuant
from brevitas.quant import ShiftedUint8ActPerTensorFloat as ActQuant
from brevitas.quant import Int8Bias as BiasQuant

# %%
#! conda install -c pytorch torchvision
import torchvision
import torchvision.transforms as transforms

# %%
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# quick sanity check that tensors can be moved to the selected device
a = torch.zeros(4, 3).to(device)

# %%
use_bias_quant = True
weight_signed = True
bias_quant   = BiasQuant if use_bias_quant else None
act_quant    = ActQuant
weight_quant = SignedWeightQuant if weight_signed else UnsignedWeightQuant

# %%
class Block(nn.Module):
    def __init__(self, num_layers , in_channels, out_channels, identity_downsample=None, stride=1):
        assert num_layers in [18, 34, 50, 101, 152], "should be a valid architecture"
        super(Block, self).__init__()
        self.num_layers = num_layers
        if self.num_layers > 34:
            self.expansion = 4
        else:
            self.expansion = 1
        # ResNet50, 101, and 152 include an additional layer of 1x1 kernels
        self.conv1 = qnn.QuantConv2d(in_channels, out_channels, input_quant=act_quant, weight_quant=weight_quant, output_quant=act_quant, bias_quant=bias_quant, kernel_size=1, stride=1, padding=0, return_quant_tensor=True).cuda()

        # self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0).cuda()
        self.bn1 = nn.BatchNorm2d(out_channels).cuda()
        if self.num_layers > 34:
            self.conv2 = qnn.QuantConv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, input_quant=act_quant, weight_quant=weight_quant, output_quant=act_quant, bias_quant=bias_quant, return_quant_tensor=True).cuda()
            # self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1).cuda()
        else:
            # for ResNet18 and 34, connect input directly to (3x3) kernel (skip first (1x1))
            self.conv2 = qnn.QuantConv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, input_quant=act_quant, weight_quant=weight_quant, output_quant=act_quant, bias_quant=bias_quant, return_quant_tensor=True).cuda()
            # self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1).cuda()
        self.bn2 = nn.BatchNorm2d(out_channels).cuda()
        self.conv3 = qnn.QuantConv2d(out_channels, out_channels * self.expansion, kernel_size=1, stride=1, padding=0, input_quant=act_quant, weight_quant=weight_quant, output_quant=act_quant, bias_quant=bias_quant, return_quant_tensor=True).cuda()
        # self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, stride=1, padding=0).cuda()
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion).cuda()
        self.relu = nn.ReLU().cuda()
        self.identity_downsample = identity_downsample

    def forward(self, x):
        x = x.to(device)
        identity = x
        if self.num_layers > 34:
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.conv3(x)
        x = self.bn3(x)

        if self.identity_downsample is not None:
            identity = self.identity_downsample(identity)

        x += identity
        x = self.relu(x)
        return x

# %%
class ResNet(nn.Module):
    def __init__(self, num_layers, block, image_channels, num_classes):
        assert num_layers in [18, 34, 50, 101, 152], f'ResNet{num_layers}: Unknown architecture! Number of layers has ' \
                                                     f'to be 18, 34, 50, 101, or 152 '
        super(ResNet, self).__init__()
        if num_layers < 50:
            self.expansion = 1
        else:
            self.expansion = 4
        if num_layers == 18:
            layers = [2, 2, 2, 2]
        elif num_layers == 34 or num_layers == 50:
            layers = [3, 4, 6, 3]
        elif num_layers == 101:
            layers = [3, 4, 23, 3]
        else:
            layers = [3, 8, 36, 3]
        self.in_channels = 64
        self.conv1 = nn.Conv2d(image_channels, 64, kernel_size=7, stride=2, padding=3).cuda()
        self.bn1 = nn.BatchNorm2d(64).cuda()
        self.relu = nn.ReLU().cuda()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1).cuda()

        # ResNetLayers
        self.layer1 = self.make_layers(num_layers, block, layers[0], intermediate_channels=64, stride=1)
        self.layer2 = self.make_layers(num_layers, block, layers[1], intermediate_channels=128, stride=2)
        self.layer3 = self.make_layers(num_layers, block, layers[2], intermediate_channels=256, stride=2)
        self.layer4 = self.make_layers(num_layers, block, layers[3], intermediate_channels=512, stride=2)

        self.dropout1 = nn.Dropout(p=0.1).cuda()
        self.dropout2 = nn.Dropout(p=0.1).cuda()
        self.dropout3 = nn.Dropout(p=0.1).cuda()
        self.dropout4 = nn.Dropout(p=0.1).cuda()
        self.dropout5 = nn.Dropout(p=0.2).cuda()

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1)).cuda()
        self.fc = nn.Linear(512 * self.expansion, num_classes).cuda()

        self.optimizer = optim.Adam(self.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0005, amsgrad=False)
        self.criterion = nn.CrossEntropyLoss()
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min')

        self.loss = None

    def accuracy(self,outputs, labels):
        _, preds = torch.max(outputs, dim=1)
        return torch.tensor(torch.sum(preds == labels).item() / len(preds))

    def training_step(self, batch):
        images, labels = batch 
        images, labels = images.to(device), labels.to(device)
        out = self.forward(images)         # Generate predictions
        loss = self.criterion(out, labels) # Calculate loss
        self.loss = loss
        acc = self.accuracy(out, labels)
        return {'train_loss': loss.detach(), 'train_acc': acc}

    def validation_step(self, batch):
        images, labels = batch 
        images, labels = images.to(device), labels.to(device)
        out = self.forward(images)                  # Generate predictions
        loss = self.criterion(out, labels)   # Calculate loss
        acc = self.accuracy(out, labels)           # Calculate accuracy
        return {'val_loss': loss.detach(), 'val_acc': acc}

    def validation_epoch_end(self, outputs):
        batch_losses = [x['val_loss'] for x in outputs]
        epoch_loss = torch.stack(batch_losses).mean()   # Combine losses
        batch_accs = [x['val_acc'] for x in outputs]
        epoch_acc = torch.stack(batch_accs).mean()      # Combine accuracies
        print(f" val_loss: {epoch_loss.item()}, val_acc: {epoch_acc.item()}")
        return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}

    def evaluate(self, val_loader):
        outputs = [self.validation_step(batch) for batch in val_loader]
        return self.validation_epoch_end(outputs)

    def epoch_end(self, epoch, result_val, result_train):
        print(f"Epoch {epoch}, val_loss: {result_val['val_loss']}, val_acc: {result_val['val_acc']}, train_loss: {result_train['train_loss']}, train_acc: {result_train['train_acc']}")   

    def forward(self, x):
        x = x.to(device)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.dropout1(x)
        x = self.layer2(x)
        x = self.dropout2(x)
        x = self.layer3(x)
        x = self.dropout3(x)
        x = self.layer4(x)
        x = self.dropout4(x)

        x = self.avgpool(x)
        x = x.reshape(x.shape[0], -1)
        x = self.dropout5(x)
        x = self.fc(x)
        return x

    def make_layers(self, num_layers, block, num_residual_blocks, intermediate_channels, stride):
        layers = []

        identity_downsample = nn.Sequential(nn.Conv2d(self.in_channels, intermediate_channels*self.expansion, kernel_size=1, stride=stride),
                                            nn.BatchNorm2d(intermediate_channels*self.expansion))
        layers.append(block(num_layers, self.in_channels, intermediate_channels, identity_downsample, stride))
        self.in_channels = intermediate_channels * self.expansion  # e.g. 64 * 4 = 256 for ResNet50
        for i in range(num_residual_blocks - 1):
            layers.append(block(num_layers, self.in_channels, intermediate_channels))  # e.g. 256 -> 64 -> 256 again
        return nn.Sequential(*layers)

# %%

def ResNet50(img_channels=3, num_classes=10):
    return ResNet(50, Block, img_channels, num_classes)

# %%
def test():
    net = ResNet50(img_channels=3, num_classes=1000).cuda()
    y = net(torch.randn(4, 3, 32, 32).to(device))
    print(y.size())
test()

# %%

transform_train_1 = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_train_2 = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(p=0.3),
    transforms.RandomVerticalFlip(p=0.3),
    transforms.RandomPerspective(distortion_scale=0.6, p=0.25),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

batch_size = 128

trainset_1 = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform_train_1)

trainset_2 = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform_train_2)

trainloader = torch.utils.data.DataLoader(torch.utils.data.ConcatDataset([trainset_1,trainset_2]), batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform_test)

print(len(testset)/2)

testset, valset = random_split(testset, [5000, 5000])

valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size,
                                         shuffle=True, num_workers=2)

testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# %%
def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

# get some random training images
dataiter = iter(trainloader)
images, labels = next(dataiter)

# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(batch_size)))

# %%
model_ResNet50 = ResNet50(img_channels=3, num_classes=10)

# %%
def compute_l1_loss(w):
    return torch.abs(w).sum()

def compute_l2_loss(w):  
    return w.dot(w)

# %%
def fit(epochs, model, train_loader, val_loader):

    history_epoch_training = []
    history_epoch_validation = []

    for epoch in range(epochs):
        # Training Phase 
        batch_num=1
        history_batch_training = []
        for batch in train_loader:
            #print(f"Epoch number: {epoch}, Batch number: {batch_num}")
            loss = model.training_step(batch)

            # Zero the gradients
            model.optimizer.zero_grad()

            # Specify L1 and L2 weights
            l1_weight = 1e-12
            l2_weight = 1e-10

            # Compute L1 and L2 loss component
            parameters = []
            for parameter in model.parameters():
                parameters.append(parameter.view(-1))
            l1 = l1_weight * compute_l1_loss(torch.cat(parameters))
            #l2 = l2_weight * compute_l2_loss(torch.cat(parameters))

            # Add L1 and L2 loss components
            model.loss += l1
            #model.loss += l2

            model.loss.backward()
            model.optimizer.step()

            history_batch_training.append(loss)

            #if batch_num%100==0:
            #  print(f"batch_num: {batch_num}, train_loss= {loss['train_loss']/batch_size}, train_acc = : {loss['train_acc']}")

            batch_num+=1

        # Training History 
        losses = np.mean(np.array([x['train_loss'].cpu() for x in history_batch_training]))
        accuracies = np.mean(np.array([x['train_acc'].cpu() for x in history_batch_training]))
        history_epoch_training.append({'train_loss': losses, 'train_acc': accuracies})

        # Validation phase
        result = model.evaluate(val_loader)
        model.epoch_end(epoch, result, history_epoch_training[epoch])
        history_epoch_validation.append(result)
        model.scheduler.step(result['val_loss'])

    return history_epoch_training, history_epoch_validation

# %%
model_ResNet50.cuda()

# %%
history_training, history_validation = fit(epochs=1, model=model_ResNet50, train_loader=trainloader, val_loader=valloader)

Could you please advise how I can work around this problem? Thank you!

bferrarini commented 2 years ago

Hi,

the batch norm layer calls the dim() method of QuantTensor, which is only implemented in the very latest version on GitHub. Please check line 284 at: https://github.com/Xilinx/brevitas/blob/master/src/brevitas/quant_tensor/__init__.py
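As a quick sanity check of your installed copy (a sketch, assuming the installed package exposes __version__; the hasattr check works either way):

import brevitas
from brevitas.quant_tensor import QuantTensor

print(brevitas.__version__)         # the version pip installed
print(hasattr(QuantTensor, 'dim'))  # False on releases that predate the fix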

I believe you can solve the problem by installing Brevitas from GitHub: pip install git+https://github.com/Xilinx/brevitas.git

or adding the dim() method to your currently installed Brevitas:

def dim(self):
    return self.value.dim()
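For example, a monkey-patch sketch you could run once at startup (this assumes your installed QuantTensor keeps the raw tensor in its .value attribute, as the current one does):

from brevitas.quant_tensor import QuantTensor

def _dim(self):
    # delegate to the underlying torch.Tensor, so that
    # BatchNorm2d._check_input_dim() sees the expected number of dimensions
    return self.value.dim()

QuantTensor.dim = _dim

Upgrading from GitHub is the cleaner option, though; the patch is only a stopgap.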