annahambi opened this issue 2 years ago
Hi @annahambi
Thank you for your interest. I totally understand your idea of trying the standard ResNet-20.
Could you please try our new framework? We just updated the codebase, and I changed the definition of the ResNet for the CIFAR datasets so that it uses only one layer in the stem. For your implementation, I also suggest using 16 channels in the first Conv2d.
Thanks again!
Hi @yhhhli
Thanks for the quick reply! As you suggested, I have pulled the latest version of the SNN_Calibration repository and implemented the original ResNet as shown below. I have also used 16 channels in the first Conv2d, as you suggested.
Training the ANN with
python -m SNN_Calibration.main_train_cifar --dataset CIFAR10 --arch orgres20 --dataset 'CIFAR10' --usebn
results in: Test Accuracy of the model on the 10000 test images: 93.260
The problem with the SNN calibration still persists: running
python -m SNN_Calibration.main_cal_cifar --dataset 'CIFAR10' --arch orgres20 --T 32 --calib advanced --usebn \
--dataset 'CIFAR10' --model 'raw/CIFAR10/orgres20_wBN_wd5e4_state_dict_best.pth'
results in: Test Accuracy of the model on the 10000 test images: 3.060
which is worse than random guessing (10% for the 10 CIFAR-10 classes).
Do you have any other ideas about what might be causing the problem in this particular conversion?
Best, Anna
'''
ResNet20 on CIFAR10 with the correct number of parameters (0.27M) as in the original publication [1].
References:
[1] K. He, X. Zhang, S. Ren, and J. Sun. Deep residual learning for image recognition. In CVPR, 2016.
[2] K. He, X. Zhang, S. Ren, and J. Sun. Identity mappings in deep residual networks. In ECCV, 2016.
'''
import math

import torch
import torch.nn as nn

# Relative imports: this file must be imported as part of the SNN_Calibration package
from ...utils import StraightThrough
from ...spiking_layer import SpikeModule, Union
from .resnet import SpikeBasicBlock


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # BN and ReLU are module-level globals assigned in Org_ResNet_Cifar.__init__
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = BN(planes)
        self.relu1 = ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = BN(planes)
        self.downsample = downsample
        self.stride = stride
        self.relu2 = ReLU(inplace=True)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        out = self.conv2(out)
        out = self.bn2(out)
        # 1x1 projection shortcut when the spatial size or width changes
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu2(out)
        return out


class Org_ResNet_Cifar(nn.Module):
    def __init__(self, block, layers, num_classes=10, use_bn=True):
        super(Org_ResNet_Cifar, self).__init__()
        global BN
        BN = nn.BatchNorm2d if use_bn else StraightThrough
        global ReLU
        ReLU = nn.ReLU
        self.inplanes = 16
        # One-layer stem with 16 channels
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = BN(16)
        self.relu = ReLU(inplace=True)
        # Three stages of widths 16, 32, 64
        self.layer1 = self._make_layer(block, 16, layers[0])
        self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 64, layers[2], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                n = m.weight.size(1)
                m.weight.data.normal_(0, 1.0 / float(n))
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                BN(planes * block.expansion)
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


def org_resnet20(**kwargs):
    model = Org_ResNet_Cifar(BasicBlock, [3, 3, 3], **kwargs)
    return model


# Maps the BasicBlock class defined in THIS file to the repository's SpikeBasicBlock
res_specials = {BasicBlock: SpikeBasicBlock}

if __name__ == '__main__':
    net = org_resnet20()
    net.eval()
    x = torch.randn(1, 3, 32, 32)
    out = net(x)
    print(out.shape)  # torch.Size([1, 10])
Unfortunately, I cannot locate the problem right now.
When debugging, I usually first check that the code itself works correctly. I would suggest evaluating the ANN module first, and then evaluating the SNN module in ANN mode, to see whether it reaches the original accuracy. If both are fine, evaluate the accuracy at a high number of time steps, such as T=512, without calibration.
Could you please report the results of these experiments so that we can locate the source of this issue? Thank you so much.
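The reason for trying a large T such as 512 without calibration: with rate coding, a spiking neuron's firing rate only approximates the ANN activation, and for a single neuron that approximation error shrinks as 1/T. A minimal sketch, assuming one integrate-and-fire (IF) neuron with a constant input, threshold 1.0, and reset-by-subtraction (an illustration of the general principle, not code taken from SNN_Calibration; `if_firing_rate` is a hypothetical helper):

```python
def if_firing_rate(a, T, threshold=1.0):
    """Firing rate (scaled by the threshold) of an IF neuron driven by a
    constant input `a` for T time steps, using reset-by-subtraction."""
    v = 0.0      # membrane potential
    spikes = 0
    for _ in range(T):
        v += a                  # integrate the input
        if v >= threshold:
            spikes += 1
            v -= threshold      # soft reset keeps the residual charge
    return spikes * threshold / T


# Worst-case gap between rate and activation over inputs in [0, 1]
for T in (32, 512):
    worst = max(abs(if_firing_rate(i / 100, T) - i / 100) for i in range(101))
    print(f"T={T}: worst-case error {worst:.4f}, bound 1/T = {1 / T:.4f}")
```

For inputs in [0, threshold] the residual potential stays in [0, threshold), so the error is below threshold/T: about 0.031 at T=32 versus about 0.002 at T=512. In a deep network these errors compound across layers, but if accuracy is still near chance at T=512, a structural conversion problem seems more likely than too few time steps.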
Thank you for the hints on debugging!
1) Evaluating the ANN module
I have loaded the state_dict from the .pth file that was created by the training code. Indeed, when I validate the model on the images in test_loader, I obtain the >90% accuracy that I also noted from the training run.
2) Evaluating the SNN module in ANN mode
For clarification: when you say "evaluating the SNN module under ANN mode", do you mean calling snn.set_spike_state(use_spike=False)? Or does something else need to be done? There is also search_fold_and_remove_bn(ann), and I am not sure whether I need to execute it. The code below gives only 10.2% accuracy in the evaluation.
# ann, state_dict, device, test_loader and criterion are defined earlier in the script
sim_length = 32
ann.load_state_dict(state_dict, strict=True)
# Wrap the ANN; `specials` maps block classes to their spiking counterparts
snn = SpikeModel(model=ann, sim_length=sim_length, specials=res_specials)
# Intended to evaluate the converted model in ANN (non-spiking) mode
snn.set_spike_state(use_spike=False)
correct = 0
total = 0
# start testing
snn.eval()
snn.to(device)
with torch.no_grad():
    for batch_idx, (inputs, targets) in enumerate(test_loader):
        inputs = inputs.to(device)
        targets = targets.to(device)
        outputs = snn(inputs)
        loss = criterion(outputs, targets)
        _, predicted = outputs.cpu().max(1)
        total += float(targets.size(0))
        correct += float(predicted.eq(targets.cpu()).sum().item())
        # Print running accuracy every 100 batches
        if batch_idx % 100 == 0:
            acc = 100. * float(correct) / float(total)
            print(batch_idx, len(test_loader), ' Acc: %.5f' % acc)
print('Final Acc: %.5f' % (100. * correct / total))
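On search_fold_and_remove_bn(ann): judging only by its name, it folds each BatchNorm layer into the preceding convolution, a common step before ANN-to-SNN conversion; whether the repository's function does exactly this is an assumption here. In eval mode, BatchNorm is a per-channel affine map, so folding is exact: each output channel's weights are scaled by gamma / sqrt(var + eps), and the new bias is (b - mean) * gamma / sqrt(var + eps) + beta. A minimal pure-Python sketch (the helpers are hypothetical; a convolution scales each output channel's flattened kernel the same way as a row of this dense layer):

```python
import math


def fold_bn(weight, bias, gamma, beta, mean, var, eps=1e-5):
    """Fold an eval-mode BatchNorm into the preceding layer, per output channel."""
    folded_w, folded_b = [], []
    for w_c, b_c, g, bt, m, v in zip(weight, bias, gamma, beta, mean, var):
        scale = g / math.sqrt(v + eps)
        folded_w.append([w * scale for w in w_c])
        folded_b.append((b_c - m) * scale + bt)
    return folded_w, folded_b


def linear(weight, bias, x):
    """y_c = sum_i W[c][i] * x[i] + b[c]"""
    return [sum(w * xi for w, xi in zip(w_c, x)) + b_c for w_c, b_c in zip(weight, bias)]


def batchnorm_eval(y, gamma, beta, mean, var, eps=1e-5):
    """BatchNorm with frozen running statistics, as in model.eval()."""
    return [g * (yc - m) / math.sqrt(v + eps) + bt
            for yc, g, bt, m, v in zip(y, gamma, beta, mean, var)]


# Two output channels, three inputs, bias-free layer (bias=False as in the ResNet above)
weight = [[0.5, -1.0, 2.0], [1.5, 0.25, -0.75]]
bias = [0.0, 0.0]
gamma, beta = [1.2, 0.8], [0.1, -0.3]
mean, var = [0.4, -0.2], [2.0, 0.5]
x = [1.0, 2.0, -1.0]

reference = batchnorm_eval(linear(weight, bias, x), gamma, beta, mean, var)
fw, fb = fold_bn(weight, bias, gamma, beta, mean, var)
folded = linear(fw, fb, x)
print(all(math.isclose(r, f, rel_tol=1e-12, abs_tol=1e-12) for r, f in zip(reference, folded)))  # True
```

Because the folded layer computes the same function, folding by itself should not change ANN-mode accuracy; a large drop in ANN mode points elsewhere.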
Hi @yhhhli, it would be great to get your feedback on the above :)
Hi Anna,
I noticed your code and results. It is very likely that you didn't use the SpikeResModule defined here. We add it by defining the mapping dictionary (here) and passing it to SpikeModel (see this example).
To solve this, I suggest using our BasicBlock, SpikeBasicBlock and res_specials definitions to construct your original ResNet-20.
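Why the mapping dictionary matters can be illustrated without the repository: if conversion swaps modules by looking up type(module) in a dict, the key must be the very class object used to build the model, and two classes that are both named BasicBlock but defined in different files are different keys. A simplified, hypothetical sketch (Module, SpikeWrapper and replace_specials are stand-ins, not the SNN_Calibration API):

```python
class Module:
    """Minimal stand-in for nn.Module: just holds named children."""
    def __init__(self, **children):
        self.children = dict(children)


def make_block_class():
    # Each call creates a NEW class object named BasicBlock, mimicking
    # two different files that each define their own BasicBlock.
    class BasicBlock(Module):
        pass
    return BasicBlock


class SpikeWrapper(Module):
    """Hypothetical stand-in for a spiking replacement such as SpikeBasicBlock."""
    def __init__(self, inner):
        super().__init__()
        self.inner = inner


def replace_specials(module, specials):
    """Recursively replace children whose exact type is a key in `specials`."""
    for name, child in list(module.children.items()):
        wrapper_cls = specials.get(type(child))
        if wrapper_cls is not None:
            module.children[name] = wrapper_cls(child)
        else:
            replace_specials(child, specials)
    return module


RepoBlock = make_block_class()   # e.g. the repository's BasicBlock
UserBlock = make_block_class()   # e.g. a BasicBlock defined in a new file
net = Module(layer1=Module(block0=UserBlock()))
replace_specials(net, {RepoBlock: SpikeWrapper})
print(type(net.children["layer1"].children["block0"]).__name__)  # BasicBlock (not replaced)
```

So if the calibration script passes a res_specials keyed on one BasicBlock while the model was built from another, the blocks would silently stay unconverted, which is consistent with the suggestion to build the network from the repository's own definitions.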
Dear Yuhang,
I have noticed that you are using a ResNet20 for CIFAR10 with 11.3 million parameters. In the original ResNet publication by He et al. [1], ResNet20 for CIFAR10 is defined with 0.27 million parameters. I know that it is somewhat "conventional" to use the ResNet20 implementation you are using, but I am really interested in the one with the smaller number of parameters :P
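The 0.27M figure can be checked with a little arithmetic. A minimal sketch that counts learnable parameters analytically for the org_resnet20 definition in this thread, assuming bias-free convolutions, BatchNorm with affine weight and bias (running statistics are buffers, not parameters), 1x1 projection shortcuts where the width changes, and a final Linear layer with bias; the helper names are hypothetical and PyTorch is not needed:

```python
def conv_params(c_in, c_out, k):
    """Weights of a bias-free k x k convolution."""
    return c_in * c_out * k * k


def bn_params(c):
    """Affine weight and bias of BatchNorm2d."""
    return 2 * c


def basic_block_params(c_in, c_out):
    """Two 3x3 conv + BN pairs, plus a 1x1 conv + BN shortcut if the width changes."""
    n = conv_params(c_in, c_out, 3) + bn_params(c_out)
    n += conv_params(c_out, c_out, 3) + bn_params(c_out)
    if c_in != c_out:  # in this network the width changes exactly where stride == 2
        n += conv_params(c_in, c_out, 1) + bn_params(c_out)
    return n


def resnet_cifar_params(blocks_per_stage=(3, 3, 3), widths=(16, 32, 64), num_classes=10):
    total = conv_params(3, widths[0], 3) + bn_params(widths[0])  # one-layer stem
    c_in = widths[0]
    for n_blocks, width in zip(blocks_per_stage, widths):
        for _ in range(n_blocks):
            total += basic_block_params(c_in, width)
            c_in = width
    total += c_in * num_classes + num_classes  # final nn.Linear(64, 10)
    return total


print(resnet_cifar_params())  # 272474
```

That is about 0.27M. With PyTorch available, sum(p.numel() for p in org_resnet20().parameters()) should report the same total, which is a quick way to confirm that the model being converted is the small network.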
I have defined the "original" ResNet20 for CIFAR10 with 0.27M parameters as shown below. I added the file under models in your repository and ran first the ANN training and then the SNN calibration on it. The ANN training works well and reaches 93.5% accuracy, but for some reason SNN calibration does not work on the network below and gives only 20% accuracy. Please help me get SNN calibration working on this :) It would be much appreciated to understand the issue here.
[1] He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep Residual Learning for Image Recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp. 770–778. https://doi.org/10.1109/CVPR.2016.90