nitin-rathi / hybrid-snn-conversion

Training spiking networks with hybrid ANN-SNN conversion and spike-based backpropagation
https://openreview.net/forum?id=B1xSperKvH

About snn_vgg11_cifar100 #10

yult0821 opened this issue 3 years ago

yult0821 commented 3 years ago

Hi, first of all, thank you for making this repository public. I have been working with your code recently and have a few questions.

  1. I tried to use "snn_vgg11_cifar100.pth" for inference. However, the test accuracy is very low (0.1592 with T=125). The log file says "Loaded SNN model does not have thresholds" (when model = nn.DataParallel(model) is executed after the if-block). Is this the reason for the low accuracy? Does it mean that I need to retrain the SNN model to obtain higher accuracy? If model = nn.DataParallel(model) is executed before the if-block, the test accuracy is extremely low (0.0106 with T=125), and the log file shows "Loaded weight features.0.weight not present in current model". For comparison, I also used "snn_vgg16_cifar10.pth" for inference on CIFAR10 and obtained results similar to those reported in your paper.

  2. I also tried to train an SNN model starting from the supplied "ann_vgg11_cifar100.pth". When model = nn.DataParallel(model) is executed before the if-block, training stops almost immediately with "Quitting as the training is not progressing", and the log file shows "Error: Loaded weight classifier.6.weight not present in current model". When model = nn.DataParallel(model) is executed after the if-block, an error is raised ("torch.nn.modules.module.ModuleAttributeError: 'VGG_SNN_STDB' object has no attribute 'module'"), even though the log file shows "Success: Loaded classifier.6.weight from ./trained_models/ann/ann_vgg11_cifar100.pth".

Could you offer some hints or suggestions regarding the above observations? Also, could you share download links for the trained ImageNet models used in your paper? Any help would be appreciated. Thanks again!

yult0821 commented 3 years ago

I found that the main cause of the above observations is how nn.DataParallel(model) wraps the model. However, I still get poor SNN training performance even when starting from a suitable pretrained ANN model, as described in #3. Also, could you share download links for the trained ImageNet models used in your paper? Any help would be appreciated. Thanks very much!
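The DataParallel behaviour can be illustrated with a short sketch. `nn.DataParallel` stores the wrapped network as its `.module` attribute, so a state_dict saved from a wrapped model has every key prefixed with `module.` (e.g. `module.features.0.weight`), while an unwrapped model has no `.module` attribute at all. The helpers below are hypothetical (not part of this repository) and only manipulate plain key/value mappings, so they work on any state_dict without importing torch. The commented usage lines assume the weights are stored under a `"state_dict"` key, which may not match this repository's checkpoint format.

```python
from collections import OrderedDict
from typing import Any, Mapping

# Prefix that nn.DataParallel adds to every parameter name
# (it stores the wrapped network as `self.module`).
DP_PREFIX = "module."


def strip_data_parallel_prefix(state_dict: Mapping[str, Any]) -> "OrderedDict[str, Any]":
    """Copy of `state_dict` with a leading 'module.' removed from each key."""
    return OrderedDict(
        (key[len(DP_PREFIX):] if key.startswith(DP_PREFIX) else key, value)
        for key, value in state_dict.items()
    )


def add_data_parallel_prefix(state_dict: Mapping[str, Any]) -> "OrderedDict[str, Any]":
    """Copy of `state_dict` with 'module.' prepended to keys that lack it."""
    return OrderedDict(
        (key if key.startswith(DP_PREFIX) else DP_PREFIX + key, value)
        for key, value in state_dict.items()
    )


def unwrap(model: Any) -> Any:
    """Return the inner network if `model` is wrapped (has `.module`), else `model`.

    Avoids the "object has no attribute 'module'" error that occurs when
    code accesses `model.module` on a model that was never wrapped.
    """
    return getattr(model, "module", model)


# Hypothetical usage with PyTorch (checkpoint layout is an assumption):
# checkpoint = torch.load(path, map_location="cpu")
# unwrap(model).load_state_dict(strip_data_parallel_prefix(checkpoint["state_dict"]))
```

Loading the stripped keys into `unwrap(model)` works whether or not the model has been wrapped yet, so for loading purposes the position of the `nn.DataParallel` call relative to the if-block would no longer matter. One caveat: `getattr(model, "module", model)` would also return an ordinary submodule that happens to be named `module`.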

nitin-rathi commented 3 years ago

Thank you for reporting the issue with the pre-trained VGG11. The classifier of the pre-trained VGG11 has a different configuration, which leads to a weight-mismatch error when loading the model. I have modified the classifier specifically for VGG11 on CIFAR100 ('vgg_spiking.py') to match the pre-trained model. Now, if you run with 'ann_vgg11_cifar100.pth', it will first find the threshold for each layer and then run the SNN training. Let me know if you still have any issues. Thank you!

mountains-high commented 2 years ago

Hi, thank you for providing the solution. I tried it, but got this error message:

RuntimeError: Error(s) in loading state_dict for VGG_SNN_STDB:
    size mismatch for classifier.0.weight: copying a param with shape torch.Size([1024, 8192]) from checkpoint, the shape in current model is torch.Size([4096, 2048]).
    size mismatch for classifier.3.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([4096, 4096]).
    size mismatch for classifier.6.weight: copying a param with shape torch.Size([100, 1024]) from checkpoint, the shape in current model is torch.Size([10, 4096]).

This happened when I tried to load the pretrained ann_vgg11_cifar100 model. Could you help me solve this problem? Thank you~
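The size-mismatch error can be read directly from the shapes. PyTorch stores an `nn.Linear` weight as `(out_features, in_features)`, so the checkpoint's classifier maps 8192 → 1024 → 1024 → 100, while the current model's classifier maps 2048 → 4096 → 4096 → 10. The 10-way output suggests the model was built with 10 labels (CIFAR10-style) rather than 100, and the 8192 vs 2048 input size suggests the flattened feature map differs; both are inferences from the shapes, not a confirmed diagnosis. The hypothetical helper below (not part of this repository) compares two `key → shape` mappings before calling `load_state_dict`; with real tensors the mappings could be built as `{k: tuple(v.shape) for k, v in sd.items()}`.

```python
from typing import Dict, List, Mapping, Sequence


def diff_state_dict_shapes(
    checkpoint: Mapping[str, Sequence[int]],
    model: Mapping[str, Sequence[int]],
) -> Dict[str, List[str]]:
    """Compare parameter names and shapes of a checkpoint against a model.

    Both arguments map parameter names to shapes, e.g. built from a
    state_dict as {k: tuple(v.shape) for k, v in sd.items()}.
    """
    report: Dict[str, List[str]] = {
        "missing_in_model": [],       # keys only in the checkpoint
        "missing_in_checkpoint": [],  # keys only in the model
        "shape_mismatch": [],         # same key, different shape
    }
    for key, shape in checkpoint.items():
        if key not in model:
            report["missing_in_model"].append(key)
        elif tuple(shape) != tuple(model[key]):
            report["shape_mismatch"].append(
                f"{key}: checkpoint {tuple(shape)} vs model {tuple(model[key])}"
            )
    for key in model:
        if key not in checkpoint:
            report["missing_in_checkpoint"].append(key)
    return report


# Shapes copied from the error message above.
ckpt_shapes = {
    "classifier.0.weight": (1024, 8192),
    "classifier.3.weight": (1024, 1024),
    "classifier.6.weight": (100, 1024),
}
model_shapes = {
    "classifier.0.weight": (4096, 2048),
    "classifier.3.weight": (4096, 4096),
    "classifier.6.weight": (10, 4096),
}
for line in diff_state_dict_shapes(ckpt_shapes, model_shapes)["shape_mismatch"]:
    print(line)
```

Running this prints one line per mismatched layer, for example `classifier.0.weight: checkpoint (1024, 8192) vs model (4096, 2048)`, which makes it easier to see whether the model needs a different number of labels, a different classifier width, or both.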