yhhhli / SNN_Calibration

Pytorch Implementation of Spiking Neural Networks Calibration, ICML 2021
MIT License

SNN Calibration of Original ResNet20 #6

Open annahambi opened 2 years ago

annahambi commented 2 years ago

Dear Yuhang,

I have noticed that you are using a ResNet20 for CIFAR10 with 11.3 million parameters. In the original ResNet publication by He et al. [1], the definition of ResNet20 for CIFAR10 results in 0.27 million parameters. I know it is somewhat "conventional" to use the ResNet20 implementation you are using; the problem is that I am really interested in the one with the smaller number of parameters :P

I have defined the "original" ResNet20 for CIFAR10 with 0.27 M parameters as shown below. I have added the file under models in your repository and first ran the ANN training and then the SNN calibration on it:

python -m SNN_Calibration.CIFAR.main_train --dataset CIFAR10 --arch orgres20 --dpath 'datasets/CIFAR10/' --usebn
python -m SNN_Calibration.CIFAR.main_calibration --dataset CIFAR10 --arch orgres20 --T 16 --usebn --calib advanced --dpath 'datasets/CIFAR10/'

The ANN training works well and results in 93.5% accuracy. But for some reason the SNN calibration doesn't work on the network below and results in only 20% accuracy. Please help me get the SNN calibration working on this :) I would really appreciate understanding the issue here.

[1] He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep Residual Learning for Image Recognition. Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 770–778. https://doi.org/10.1109/CVPR.2016.90

'''
ResNet20 on CIFAR10 with the correct number of parameters (0.27 M) as in the original publication [1].

References:
[1] K. He, X. Zhang, S. Ren, and J. Sun. Deep residual learning for image recognition. In CVPR, 2016.
[2] K. He, X. Zhang, S. Ren, and J. Sun. Identity mappings in deep residual networks. In ECCV, 2016.
'''
import torch
import torch.nn as nn
import math
# @anna: I fixed the following relative imports
from ...CIFAR.models.utils import AvgPoolConv, StraightThrough
from ...CIFAR.models.spiking_layer import SpikeModel, SpikeModule, Union
import torch.nn.functional as F
from .resnet import SpikeBasicBlock

def conv3x3(in_planes, out_planes, stride=1):
    " 3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = BN(planes)
        self.relu1 = ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = BN(planes)
        self.downsample = downsample
        self.stride = stride
        self.relu2 = ReLU(inplace=True)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu2(out)

        return out

class Org_ResNet_Cifar_Modified(nn.Module):

    def __init__(self, block, layers, num_classes=10, use_bn=True):
        super(Org_ResNet_Cifar_Modified, self).__init__()

        global BN
        BN = nn.BatchNorm2d if use_bn else StraightThrough
        global ReLU
        ReLU = nn.ReLU

        self.inplanes = 64
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            BN(64),
            ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
            BN(64),
            ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False),
            BN(64),
            ReLU(inplace=True),
        )
        self.layer1 = self._make_layer(block, 16, layers[0], stride=1)
        self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 64, layers[2], stride=2)
        #self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = AvgPoolConv(kernel_size=4, stride=1, input_channel=64)
        self.fc_save = nn.Linear(64, num_classes)
        #self.fc = nn.Linear(64, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d) and not isinstance(m, AvgPoolConv):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                n = m.weight.size(1)
                m.weight.data.normal_(0, 1.0 / float(n))
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                BN(planes * block.expansion)
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        #x = F.avg_pool2d(x, x.size()[3])
        #x = self.layer4(x)
        #print(x.size())
        x = self.avgpool(x)
        #print(x.shape)
        x = x.view(x.size(0), -1)
        x = self.fc_save(x)

        return x

def org_resnet20(**kwargs):
    model = Org_ResNet_Cifar_Modified(BasicBlock, [3, 3, 3], **kwargs)
    return model

res_specials = {BasicBlock: SpikeBasicBlock}
yhhhli commented 2 years ago

Hi @annahambi

Thank you for your interest. I totally understand your idea of trying the standard ResNet-20.

Could you please try our new framework? We just updated the codebase, and I changed the definition of ResNet for the CIFAR dataset so that it only uses one layer in the stem. For your implementation, I also suggest using 16 channels in the first Conv2d.

Thanks again!

annahambi commented 2 years ago

Hi @yhhhli

Thanks for the quick reply! As you suggested, I have pulled the latest version of the SNN_Calibration repository and implemented the original ResNet as shown below, using 16 channels in the first Conv2d. The ANN training

python -m SNN_Calibration.main_train_cifar --dataset 'CIFAR10' --arch orgres20 --usebn

results in Test Accuracy of the model on the 10000 test images: 93.260.

The problem with the SNN calibration still persists, and

python -m SNN_Calibration.main_cal_cifar --dataset 'CIFAR10' --arch orgres20 --T 32 --calib advanced --usebn \
    --model 'raw/CIFAR10/orgres20_wBN_wd5e4_state_dict_best.pth'

results in "Test Accuracy of the model on the 10000 test images: 3.060", which is worse than a random guess. Do you have any other ideas about what might be the problem in this particular conversion?

Best, Anna

'''
ResNet20 on CIFAR10 with the correct number of parameters (0.27 M) as in the original publication [1].

References:
[1] K. He, X. Zhang, S. Ren, and J. Sun. Deep residual learning for image recognition. In CVPR, 2016.
[2] K. He, X. Zhang, S. Ren, and J. Sun. Identity mappings in deep residual networks. In ECCV, 2016.
'''
import torch
import torch.nn as nn
import math
from ...utils import StraightThrough
from ...spiking_layer import SpikeModule, Union
from .resnet import SpikeBasicBlock

def conv3x3(in_planes, out_planes, stride=1):
    " 3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = BN(planes)
        self.relu1 = ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = BN(planes)
        self.downsample = downsample
        self.stride = stride
        self.relu2 = ReLU(inplace=True)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu2(out)

        return out

class Org_ResNet_Cifar(nn.Module):

    def __init__(self, block, layers, num_classes=10, use_bn=True):
        super(Org_ResNet_Cifar, self).__init__()

        global BN
        BN = nn.BatchNorm2d if use_bn else StraightThrough
        global ReLU
        ReLU = nn.ReLU
        self.inplanes = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = BN(16)
        self.relu = ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 16, layers[0])
        self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 64, layers[2], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        #self.fc = nn.Linear(256 * block.expansion, num_classes)
        self.fc = nn.Linear(64 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                n = m.weight.size(1)
                m.weight.data.normal_(0, 1.0 / float(n))
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                BN(planes * block.expansion)
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x

def org_resnet20(**kwargs):
    model = Org_ResNet_Cifar(BasicBlock, [3, 3, 3], **kwargs)
    return model

res_specials = {BasicBlock: SpikeBasicBlock}

if __name__ == '__main__':
    net = org_resnet20()
    net.eval()
    x = torch.randn(1, 3, 32, 32)
    net(x)
yhhhli commented 2 years ago

Unfortunately, I cannot locate the problem right now.

When I debug, I first check whether the code works at all. I would suggest evaluating the ANN module, and then evaluating the SNN module in ANN mode, to see whether it reaches the original accuracy. If both are good, then evaluate the accuracy at a high number of time steps, e.g. T=512, without calibration.

Could you please report these experiments so that we can locate the problem? Thank you so much.
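
For concreteness, a minimal sketch of that debugging sequence might look like the following. `evaluate` is a hypothetical test-loop helper (e.g. the loop Anna posts below), and the import path may differ between the old and the updated codebase; `SpikeModel`, `set_spike_state`, and `res_specials` are used exactly as in the snippets elsewhere in this thread.

import torch
# Import path is an assumption; adjust to the codebase version you are using.
from SNN_Calibration.CIFAR.models.spiking_layer import SpikeModel

ann.load_state_dict(state_dict, strict=True)

# 1) Sanity-check the plain ANN.
acc_ann = evaluate(ann, test_loader)          # hypothetical test-loop helper

# 2) Evaluate the SNN module in ANN mode -- should reproduce acc_ann.
snn = SpikeModel(model=ann, sim_length=512, specials=res_specials)
snn.set_spike_state(use_spike=False)
acc_ann_mode = evaluate(snn, test_loader)

# 3) Spiking mode with a large number of time steps (T=512), no calibration --
#    should come close to acc_ann if the conversion itself is sound.
snn.set_spike_state(use_spike=True)
acc_t512 = evaluate(snn, test_loader)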

annahambi commented 2 years ago

Thank you for the hints on debugging!

1) Evaluating the ANN module
I have loaded the state_dict from the .pth file created by the code. Indeed, when I validate the model on the images in test_loader, I obtain the >90% accuracy for the ANN that I also noted down from the training run.

2) Evaluating the SNN module in ANN mode
For clarification: when you say evaluating the SNN module in ANN mode, do you mean using snn.set_spike_state(use_spike=False)? Or does something else need to be done? There is also search_fold_and_remove_bn(ann), and I am not sure whether I need to execute it. The code below gives only 10.2% accuracy in the evaluation.

sim_length = 32

ann.load_state_dict(state_dict, strict=True)
snn = SpikeModel(model=ann, sim_length=sim_length, specials=res_specials)
snn.set_spike_state(use_spike=False)

correct = 0
total = 0
# start testing
snn.eval()
snn.to(device)
with torch.no_grad():
    for batch_idx, (inputs, targets) in enumerate(test_loader):
        inputs = inputs.to(device)
        targets = targets.to(device)
        outputs = snn(inputs)
        loss = criterion(outputs, targets)
        _, predicted = outputs.cpu().max(1)
        total += float(targets.size(0))
        correct += float(predicted.eq(targets.cpu()).sum().item())
        if batch_idx % 100 == 0:
            acc = 100. * float(correct) / float(total)
            print(batch_idx, len(test_loader), ' Acc: %.5f' % acc)
annahambi commented 2 years ago

Hi @yhhhli, it would be great to get your feedback on the above :)

yhhhli commented 2 years ago

Hi Anna,

I looked at your code and results. It is very likely that you didn't use the SpikeResModule defined here. We add it by defining the mapping dictionary (here) and passing it to SpikeModel (see this example).

To solve it, I suggest using our defined BasicBlock, SpikeBasicBlock, and res_specials to construct your original ResNet-20.
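
For illustration, a minimal sketch of that suggestion follows. The import paths are assumptions based on the snippets earlier in this thread; the essential point is that the specials dictionary must be keyed on the same BasicBlock class that the model is actually built from, so that every residual block gets replaced by SpikeBasicBlock during conversion.

# Sketch: build the original ResNet-20 from the repo's own BasicBlock so that
# the specials mapping {BasicBlock: SpikeBasicBlock} matches the blocks inside
# the model. Import paths below are assumptions.
from SNN_Calibration.CIFAR.models.resnet import BasicBlock, SpikeBasicBlock
from SNN_Calibration.CIFAR.models.spiking_layer import SpikeModel

res_specials = {BasicBlock: SpikeBasicBlock}

def org_resnet20(**kwargs):
    # Org_ResNet_Cifar is the class Anna defines earlier in this thread.
    return Org_ResNet_Cifar(BasicBlock, [3, 3, 3], **kwargs)

ann = org_resnet20(use_bn=True)
snn = SpikeModel(model=ann, sim_length=32, specials=res_specials)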