fangwei123456 / spikingjelly

SpikingJelly is an open-source deep learning framework for Spiking Neural Network (SNN) based on PyTorch.
https://spikingjelly.readthedocs.io

SNN accuracy is lower than ANN accuracy #259

Open sauravtii opened 1 year ago

sauravtii commented 1 year ago

I am trying out this code (https://github.com/fangwei123456/spikingjelly/blob/master/spikingjelly/activation_based/ann2snn/examples/cnn_mnist.py). The only difference is that I changed the CNN architecture to VGG-16 and I am using the CIFAR-10 dataset. I initially trained it for 50 epochs. The ANN accuracy was 94%, while the SNN accuracy was 71% with MaxNorm and 76% with RobustNorm. I also tried training for 100 epochs: the ANN accuracy was good, but the SNN accuracy stayed about the same as before. Can anyone tell me how I can get good accuracy for the SNN as well?
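To make the two normalization schemes concrete: in data-based normalization each layer gets a scale computed from its ANN activations on training data, and the converted neuron fires at its maximum rate once the activation reaches that scale. The sketch below assumes MaxNorm uses the largest observed activation and RobustNorm a high percentile (the 99.9th here) so that a few outliers do not inflate the scale. The helper names and the activation sample are hypothetical, and SpikingJelly's actual implementation may differ in its details.

```python
import random


def max_norm_scale(activations):
    """MaxNorm-style scale: the largest observed activation."""
    return max(activations)


def robust_norm_scale(activations, q=0.999):
    """RobustNorm-style scale: the q-quantile of observed activations.

    Uses a simple lower-index quantile on the sorted values; real
    implementations may interpolate differently.
    """
    ordered = sorted(activations)
    index = min(len(ordered) - 1, int(q * (len(ordered) - 1)))
    return ordered[index]


# Hypothetical post-ReLU activations of one layer:
# 999 values in [0, 1) plus a single large outlier.
rng = random.Random(0)
acts = [rng.uniform(0.0, 1.0) for _ in range(999)] + [50.0]

print("MaxNorm scale:   ", max_norm_scale(acts))
print("RobustNorm scale:", robust_norm_scale(acts))
```

With MaxNorm the single outlier sets the scale to 50, so typical activations (below 1) drive the neurons at under 2% of their maximum rate and need many time-steps to be resolved; the percentile-based scale avoids that at the cost of clipping the outlier.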

DingJianhao commented 1 year ago

Hi there, you can increase the number of time-steps to something larger, such as 500. MaxNorm and RobustNorm are very basic methods, though still useful. A small number of time-steps can reduce accuracy. Good luck!
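Why more time-steps help: after conversion, an activation is carried by a firing rate, and an integrate-and-fire (IF) neuron simulated for T steps can only produce rates that are multiples of 1/T. The sketch below assumes a normalized input x in [0, 1] (the activation already divided by the layer's scale), a threshold of 1, and a soft reset (the threshold is subtracted after each spike). Under those assumptions the rate error stays below 1/T. `if_rate` is a hypothetical helper, not a SpikingJelly function.

```python
def if_rate(x, T, threshold=1.0):
    """Firing rate of an IF neuron driven by a constant input x for T steps.

    The membrane potential starts at 0 and is reset by subtraction,
    so at most one spike is emitted per step.
    """
    v = 0.0
    spikes = 0
    for _ in range(T):
        v += x
        if v >= threshold:
            spikes += 1
            v -= threshold
    return spikes / T


x = 0.37  # hypothetical normalized activation
for T in (10, 50, 100, 500):
    rate = if_rate(x, T)
    print(f"T={T:4d}  rate={rate:.4f}  |error|={abs(rate - x):.4f}  bound 1/T={1 / T:.4f}")
```

Inputs above 1 saturate at a rate of 1, so the choice of per-layer scale matters as much as T.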

sauravtii commented 1 year ago

Hi, I tried 500 time-steps, but the accuracy is still not good. The output is below.

Epoch: 99
100%|██████████| 782/782 [00:07<00:00, 107.48it/s]
Validating Accuracy: 0.972

100%|██████████| 200/200 [00:01<00:00, 127.64it/s]
ANN Validating Accuracy: 0.8420
---------------------------------------------
Converting using MaxNorm
100%|██████████| 782/782 [00:07<00:00, 107.32it/s]
Simulating...
100%|██████████| 200/200 [05:29<00:00,  1.65s/it]
SNN accuracy (simulation 500 time-steps): 0.6975
---------------------------------------------
Converting using RobustNorm
100%|██████████| 782/782 [01:50<00:00,  7.10it/s]
Simulating...
100%|██████████| 200/200 [05:27<00:00,  1.64s/it]
SNN accuracy (simulation 500 time-steps): 0.6452
---------------------------------------------
Converting using 1/2 max(activation) as scales...
100%|██████████| 782/782 [00:07<00:00, 106.99it/s]
Simulating...
100%|██████████| 200/200 [05:28<00:00,  1.64s/it]
SNN accuracy (simulation 500 time-steps): 0.7526
---------------------------------------------
Converting using 1/3 max(activation) as scales
100%|██████████| 782/782 [00:07<00:00, 106.80it/s]
Simulating...
100%|██████████| 200/200 [05:28<00:00,  1.64s/it]
SNN accuracy (simulation 500 time-steps): 0.7735
---------------------------------------------
Converting using 1/4 max(activation) as scales
100%|██████████| 782/782 [00:07<00:00, 107.23it/s]
Simulating...
100%|██████████| 200/200 [05:28<00:00,  1.64s/it]
SNN accuracy (simulation 500 time-steps): 0.7336
---------------------------------------------
Converting using 1/5 max(activation) as scales
100%|██████████| 782/782 [00:07<00:00, 107.44it/s]
Simulating...
100%|██████████| 200/200 [05:51<00:00,  1.76s/it]
SNN accuracy (simulation 500 time-steps): 0.6390
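The "1/k max(activation)" runs shrink each layer's scale. A rough way to see the trade-off, assuming a single IF neuron with constant input, soft reset, and a threshold equal to the scale s: the converted output approximates min(a, s) with a resolution of s/T. A smaller s resolves small activations more finely at a given T but clips every activation above s. The activation sample and helper names below are hypothetical, and the code only illustrates the trade-off; it does not reproduce the numbers in the log.

```python
import random


def if_rate(x, T):
    """Rate of an IF neuron (threshold 1, soft reset) under a constant input x."""
    v, spikes = 0.0, 0
    for _ in range(T):
        v += x
        if v >= 1.0:
            spikes += 1
            v -= 1.0
    return spikes / T


def converted_output(a, scale, T):
    """Approximate ANN activation a as scale * firing rate for input a / scale."""
    return scale * if_rate(a / scale, T)


def mean_abs_error(acts, scale, T):
    return sum(abs(converted_output(a, scale, T) - a) for a in acts) / len(acts)


rng = random.Random(0)
# Hypothetical layer activations: many small values and a long tail of large ones.
acts = [rng.expovariate(1.0) for _ in range(2000)]
a_max = max(acts)

T = 32
for k in (1, 2, 3, 4, 5):
    scale = a_max / k
    print(f"scale = max/{k}:  mean |error| = {mean_abs_error(acts, scale, T):.4f}")
```

The sweep only shows the two error sources (coarse rate resolution versus clipping) moving in opposite directions; which k works best depends on the activation statistics and on T.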

DingJianhao commented 1 year ago

Well, with VGG-16, which pooling method are you using?

sauravtii commented 1 year ago

Hi, I am using max pooling

fangwei123456 commented 1 year ago

Try using average pooling instead.

hanzh0816 commented 1 year ago

Could you please share your code? I'm having the same problem.

sauravtii commented 1 year ago


My VGG-16 architecture:

import torch.nn as nn


class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.network = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, 1),     # conv-1
            nn.BatchNorm2d(64),
            nn.ReLU(),

            nn.Conv2d(64, 64, 3, 1, 1),    # conv-2
            nn.BatchNorm2d(64),
            nn.ReLU(),
            # nn.MaxPool2d(2, 2),  # max pooling replaced by average pooling
            nn.AvgPool2d(2, 2),            # 32x32 -> 16x16

            nn.Conv2d(64, 128, 3, 1, 1),   # conv-3
            nn.BatchNorm2d(128),
            nn.ReLU(),

            nn.Conv2d(128, 128, 3, 1, 1),   # conv-4
            nn.BatchNorm2d(128),
            nn.ReLU(),
            # nn.MaxPool2d(2, 2),
            nn.AvgPool2d(2, 2),            # 16x16 -> 8x8

            nn.Conv2d(128, 256, 3, 1, 1),   # conv-5
            nn.BatchNorm2d(256),
            nn.ReLU(),

            nn.Conv2d(256, 256, 3, 1, 1),   # conv-6
            nn.BatchNorm2d(256),
            nn.ReLU(),

            nn.Conv2d(256, 256, 3, 1, 1),   # conv-7
            nn.BatchNorm2d(256),
            nn.ReLU(),
            # nn.MaxPool2d(2, 2),
            nn.AvgPool2d(2, 2),            # 8x8 -> 4x4

            nn.Conv2d(256, 512, 3, 1, 1),   # conv-8
            nn.BatchNorm2d(512),
            nn.ReLU(),

            nn.Conv2d(512, 512, 3, 1, 1),   # conv-9
            nn.BatchNorm2d(512),
            nn.ReLU(),

            nn.Conv2d(512, 512, 3, 1, 1),   # conv-10
            nn.BatchNorm2d(512),
            nn.ReLU(),
            # nn.MaxPool2d(2, 2),
            nn.AvgPool2d(2, 2),            # 4x4 -> 2x2

            nn.Conv2d(512, 512, 3, 1, 1),   # conv-11
            nn.BatchNorm2d(512),
            nn.ReLU(),

            nn.Conv2d(512, 512, 3, 1, 1),   # conv-12
            nn.BatchNorm2d(512),
            nn.ReLU(),

            nn.Conv2d(512, 512, 3, 1, 1),   # conv-13
            nn.BatchNorm2d(512),
            nn.ReLU(),
            # nn.MaxPool2d(2, 2),
            nn.AvgPool2d(2, 2),            # 2x2 -> 1x1

            nn.Flatten(),

            # Note: there is no ReLU between the Linear layers, so these
            # three layers compose to a single affine map.
            nn.Linear(1*1*512, 4096),

            nn.Linear(4096, 4096),

            nn.Linear(4096, 10)

        )

    def forward(self, x):
        x = self.network(x)
        return x
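A quick check of why the first fully connected layer takes 1*1*512 inputs: each 3×3 convolution with stride 1 and padding 1 keeps the spatial size, and each 2×2 pooling with stride 2 halves it, so a 32×32 CIFAR-10 image shrinks to 1×1 after the five pooling layers. The sketch applies the standard output-size formula (dilation 1, floor mode); the layer string mirrors the model above, and the helper is hypothetical.

```python
def out_size(size, kernel, stride, padding=0):
    """Output-size formula for Conv2d / AvgPool2d with dilation 1 and floor mode."""
    return (size + 2 * padding - kernel) // stride + 1


# Layer sequence of the VGG-16 feature extractor above:
# "C" = Conv2d(k=3, s=1, p=1), "P" = AvgPool2d(2, 2).
layers = "CC P CC P CCC P CCC P CCC P".replace(" ", "")

size = 32  # CIFAR-10 images are 32x32
for layer in layers:
    if layer == "C":
        size = out_size(size, kernel=3, stride=1, padding=1)
    else:
        size = out_size(size, kernel=2, stride=2)
print("final spatial size:", size)  # 1, so Flatten yields 1*1*512 = 512 features
```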

DingJianhao commented 1 year ago

Great, switching to average pooling may solve your problem! @sauravtii

sauravtii commented 1 year ago

Thanks @DingJianhao. The accuracy did improve. But is there any way I can get more than 90% accuracy? My output is below:

Epoch: 149
100%|██████████| 391/391 [00:06<00:00, 60.81it/s]
Validating Accuracy: 0.979

100%|██████████| 200/200 [00:01<00:00, 134.99it/s]
ANN Validating Accuracy: 0.8489
---------------------------------------------
Converting using MaxNorm
100%|██████████| 391/391 [00:06<00:00, 64.59it/s]
Simulating...
100%|██████████| 200/200 [05:22<00:00,  1.61s/it]
SNN accuracy (simulation 500 time-steps): 0.8379
---------------------------------------------
Converting using RobustNorm
100%|██████████| 391/391 [02:05<00:00,  3.12it/s]
Simulating...
100%|██████████| 200/200 [05:23<00:00,  1.62s/it]
SNN accuracy (simulation 500 time-steps): 0.8060
---------------------------------------------
Converting using 1/2 max(activation) as scales...
100%|██████████| 391/391 [00:06<00:00, 64.59it/s]
Simulating...
100%|██████████| 200/200 [05:22<00:00,  1.61s/it]
SNN accuracy (simulation 500 time-steps): 0.8433
---------------------------------------------
Converting using 1/3 max(activation) as scales
100%|██████████| 391/391 [00:06<00:00, 65.04it/s]
Simulating...
100%|██████████| 200/200 [05:24<00:00,  1.62s/it]
SNN accuracy (simulation 500 time-steps): 0.8300
---------------------------------------------
Converting using 1/4 max(activation) as scales
100%|██████████| 391/391 [00:05<00:00, 65.21it/s]
Simulating...
100%|██████████| 200/200 [05:23<00:00,  1.62s/it]
SNN accuracy (simulation 500 time-steps): 0.7332
---------------------------------------------
Converting using 1/5 max(activation) as scales
100%|██████████| 391/391 [00:06<00:00, 64.90it/s]
Simulating...
100%|██████████| 200/200 [05:25<00:00,  1.63s/it]
SNN accuracy (simulation 500 time-steps): 0.4973

DingJianhao commented 1 year ago

Try newer algorithms. I recommend reading "Optimized Potential Initialization for Low-latency Spiking Neural Networks".
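The core idea of that paper, as summarized here, is to start each IF neuron's membrane potential at half the threshold instead of zero. For a single neuron this turns the flooring error of rate coding into a rounding error and roughly halves the worst-case error at a given number of time-steps. The sketch below illustrates only that effect, assuming a constant input, a threshold of 1, and a soft reset; `if_rate` and its `v_init` parameter are hypothetical, and the paper's full method and analysis go beyond this toy.

```python
def if_rate(x, T, v_init=0.0):
    """Firing rate of an IF neuron (threshold 1, soft reset) with constant input x."""
    v, spikes = v_init, 0
    for _ in range(T):
        v += x
        if v >= 1.0:
            spikes += 1
            v -= 1.0
    return spikes / T


T = 8
grid = [i / 100 for i in range(101)]  # normalized activations in [0, 1]
for v_init in (0.0, 0.5):
    worst = max(abs(if_rate(x, T, v_init) - x) for x in grid)
    print(f"v_init={v_init}: worst-case |rate - x| = {worst:.4f}  (1/T = {1 / T:.4f})")
```

For example, with T = 8 an input of 0.87 gives a rate of 0.75 when starting from zero but 0.875 when starting from half the threshold.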

sauravtii commented 1 year ago

Sure, I will read that. Thank you for your help! If possible, could you explain why the accuracy improved with average pooling and why it did not improve with max pooling? I tried searching for an answer but couldn't find a satisfactory one.
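One commonly given explanation: average pooling is linear, so averaging spikes over the pooling window and then over time gives exactly the average of the input firing rates, which matches what the ANN computed. Max pooling is not linear: taking the maximum of the spike trains at each time step can yield a rate well above the maximum of the input rates, so the SNN overestimates the pooled activations and the error propagates through the later layers. The toy example below uses two hypothetical spike trains in one pooling window that both fire at rate 0.5 but never at the same time step.

```python
T = 10
# Two hypothetical spike trains in one 2-element pooling window,
# both firing at rate 0.5 but on alternating time steps.
a = [1 if t % 2 == 0 else 0 for t in range(T)]
b = [1 - s for s in a]


def rate(train):
    """Rate decoding: mean value of a spike (or pooled) train over time."""
    return sum(train) / len(train)


# Pooling applied at every time step, then rate-decoded.
avg_pooled = [(x + y) / 2 for x, y in zip(a, b)]
max_pooled = [max(x, y) for x, y in zip(a, b)]

print("ANN avg-pool of rates:", (rate(a) + rate(b)) / 2, " SNN avg-pool rate:", rate(avg_pooled))
print("ANN max-pool of rates:", max(rate(a), rate(b)), " SNN max-pool rate:", rate(max_pooled))
```

Here average pooling gives 0.5 in both cases, while per-step max pooling reports 1.0 against an ANN value of 0.5. Training the ANN with AvgPool2d, as in the model above, avoids this mismatch.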