fangwei123456 / spikingjelly

SpikingJelly is an open-source deep learning framework for Spiking Neural Network (SNN) based on PyTorch.
https://spikingjelly.readthedocs.io

Performance degradation of VGG16 ImageNet conversion model #289

Status: Open. Snow-Crash opened this issue 2 years ago.

Snow-Crash commented 2 years ago

Dear developers,

I am trying to use the converter to convert a VGG16 pretrained on ImageNet. The original model's accuracy is around 71%, but after conversion the accuracy is 0.1%. Is there anything wrong with my code, and is this performance expected? I know conversion may not work well on deep models. Is it possible to retrain the converted model to recover accuracy? The SpikingJelly version I'm using is 0.0.0.0.12.

Here is the code to reproduce.

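import torch
from torchvision.models import vgg16, VGG16_Weights
from spikingjelly.clock_driven import ann2snn

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# test_data_loader: an ImageNet validation DataLoader, defined elsewhere
# weights must be an enum member (e.g. IMAGENET1K_V1), not the VGG16_Weights class itself
net = vgg16(weights=VGG16_Weights.IMAGENET1K_V1).to(device)
model_converter = ann2snn.Converter(dataloader=test_data_loader, mode='max')
model_converter.device = device
snn_model = model_converter(net)
# val function is the same as https://github.com/fangwei123456/spikingjelly/blob/0.0.0.0.12/spikingjelly/clock_driven/ann2snn/examples/cnn_mnist.py
val(snn_model, device, test_data_loader, T=4)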
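The linked val function is not reproduced in this thread. The following is a rough sketch in plain Python of how a per-time-step accuracy list (like the arrays posted later in this thread) can be computed, assuming the network is run once per time step on the same input and its outputs are summed across steps before the argmax is taken. `accuracy_per_step` is a hypothetical helper, not part of SpikingJelly, and the toy outputs are made up.

```python
from typing import List, Sequence


def accuracy_per_step(step_outputs: Sequence[Sequence[Sequence[float]]],
                      labels: Sequence[int]) -> List[float]:
    """Accuracy after each time step, predicting from outputs summed over steps 0..t.

    step_outputs[t][i][c] is the network output for sample i, class c, at step t.
    """
    num_classes = len(step_outputs[0][0])
    running = [[0.0] * num_classes for _ in labels]  # per-sample accumulated outputs
    accuracies = []
    for outputs_t in step_outputs:
        correct = 0
        for i, label in enumerate(labels):
            for c in range(num_classes):
                running[i][c] += outputs_t[i][c]
            # argmax over the accumulated outputs; ties go to the lowest class index
            pred = running[i].index(max(running[i]))
            correct += int(pred == label)
        accuracies.append(correct / len(labels))
    return accuracies


# Toy example: 2 samples, 2 classes, T = 3 (made-up numbers).
outputs = [
    [[1.0, 0.0], [1.0, 0.0]],  # t = 0
    [[0.0, 2.0], [0.0, 0.5]],  # t = 1
    [[0.0, 1.0], [0.0, 0.2]],  # t = 2
]
print(accuracy_per_step(outputs, labels=[1, 0]))  # [0.5, 1.0, 1.0]
```

Because the outputs are accumulated, element t reflects all steps up to t, which is why such a list usually changes smoothly rather than jumping around from step to step.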

Here is the printed model:

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Sequential(
      (0): VoltageScaler(0.056344)
      (1): IFNode(
        v_threshold=1.0, v_reset=None, detach_reset=False
        (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
      )
      (2): VoltageScaler(17.748156)
    )
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): Sequential(
      (0): VoltageScaler(0.028887)
      (1): IFNode(
        v_threshold=1.0, v_reset=None, detach_reset=False
        (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
      )
      (2): VoltageScaler(34.617970)
    )
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): Sequential(
      (0): VoltageScaler(0.019909)
      (1): IFNode(
        v_threshold=1.0, v_reset=None, detach_reset=False
        (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
      )
      (2): VoltageScaler(50.229336)
    )
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): Sequential(
      (0): VoltageScaler(0.013216)
      (1): IFNode(
        v_threshold=1.0, v_reset=None, detach_reset=False
        (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
      )
      (2): VoltageScaler(75.663567)
    )
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): Sequential(
      (0): VoltageScaler(0.009770)
      (1): IFNode(
        v_threshold=1.0, v_reset=None, detach_reset=False
        (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
      )
      (2): VoltageScaler(102.357101)
    )
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): Sequential(
      (0): VoltageScaler(0.006533)
      (1): IFNode(
        v_threshold=1.0, v_reset=None, detach_reset=False
        (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
      )
      (2): VoltageScaler(153.070221)
    )
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): Sequential(
      (0): VoltageScaler(0.004854)
      (1): IFNode(
        v_threshold=1.0, v_reset=None, detach_reset=False
        (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
      )
      (2): VoltageScaler(206.007767)
    )
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): Sequential(
      (0): VoltageScaler(0.004345)
      (1): IFNode(
        v_threshold=1.0, v_reset=None, detach_reset=False
        (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
      )
      (2): VoltageScaler(230.124420)
    )
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): Sequential(
      (0): VoltageScaler(0.008149)
      (1): IFNode(
        v_threshold=1.0, v_reset=None, detach_reset=False
        (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
      )
      (2): VoltageScaler(122.713150)
    )
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): Sequential(
      (0): VoltageScaler(0.009516)
      (1): IFNode(
        v_threshold=1.0, v_reset=None, detach_reset=False
        (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
      )
      (2): VoltageScaler(105.091423)
    )
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): Sequential(
      (0): VoltageScaler(0.009928)
      (1): IFNode(
        v_threshold=1.0, v_reset=None, detach_reset=False
        (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
      )
      (2): VoltageScaler(100.728149)
    )
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): Sequential(
      (0): VoltageScaler(0.012477)
      (1): IFNode(
        v_threshold=1.0, v_reset=None, detach_reset=False
        (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
      )
      (2): VoltageScaler(80.145638)
    )
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): Sequential(
      (0): VoltageScaler(0.012046)
      (1): IFNode(
        v_threshold=1.0, v_reset=None, detach_reset=False
        (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
      )
      (2): VoltageScaler(83.013229)
    )
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): Sequential(
      (0): VoltageScaler(0.040641)
      (1): IFNode(
        v_threshold=1.0, v_reset=None, detach_reset=False
        (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
      )
      (2): VoltageScaler(24.605427)
    )
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): Sequential(
      (0): VoltageScaler(0.065874)
      (1): IFNode(
        v_threshold=1.0, v_reset=None, detach_reset=False
        (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
      )
      (2): VoltageScaler(15.180535)
    )
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)
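In the printout above, each IFNode sits between two VoltageScaler modules whose factors are (up to rounding) reciprocals of each other, e.g. 0.056344 × 17.748156 ≈ 1. One reading of this is that the converter divides a layer's input by a per-layer activation statistic λ (about 17.75 for features[1]), feeds the result to an IF neuron with threshold 1.0, and multiplies the spikes back by λ. The sketch below is a minimal, single-neuron illustration of that interpretation in plain Python, assuming reset by subtraction (which is what v_reset=None means for IFNode). `scaled_if_neuron` is a hypothetical helper, not SpikingJelly's implementation.

```python
def scaled_if_neuron(inputs, lam, v_threshold=1.0):
    """One neuron of VoltageScaler(1/lam) -> IFNode(v_reset=None) -> VoltageScaler(lam).

    inputs: pre-activation value fed to the neuron at each time step.
    Returns the scaled spike output (0.0 or lam) at each step.
    """
    v = 0.0
    outputs = []
    for x in inputs:
        v += x / lam                 # first scaler: normalise by the activation statistic lam
        spike = 1.0 if v >= v_threshold else 0.0
        v -= spike * v_threshold     # v_reset=None: reset by subtracting the threshold
        outputs.append(lam * spike)  # second scaler: map spikes back to the ANN's scale
    return outputs


# Scaler pair of features[1] in the printout: the two factors are reciprocals.
lam = 17.748156
print(0.056344 * lam)  # close to 1; the printed factors are rounded

# A constant input of lam / 2 is reconstructed, on average over T steps, as lam / 2.
T = 50
out = scaled_if_neuron([lam / 2] * T, lam)
print(sum(out) / T, lam / 2)
```

Negative inputs never push the membrane potential up to the threshold, so the neuron stays silent, which is how the spiking layer stands in for ReLU.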
fangwei123456 commented 2 years ago

@Lyu6PosHao

Lyu6PosHao commented 2 years ago

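Dear friend, having looked at your code, I think the main reason for the low accuracy is that T=4 is too small. In the converted SNN, the firing rate of each IF neuron replaces the activation value of the corresponding ReLU. With a small T, the IF neurons fire only rarely: over T steps a neuron can fire at most T times, so its rate can only take the values 0, 1/T, ..., 1, and neurons in deeper layers may not have fired at all yet. Therefore, my suggestions are: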

  1. Use a larger T, e.g. T=50.
  2. If the accuracy is still unsatisfactory, try another mode such as mode='99%' instead of 'max'.
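A single IF neuron is enough to see why T=4 is too short. The sketch below is plain Python with a hypothetical helper `if_firing_rate` (not SpikingJelly's IFNode). It feeds a constant normalised input a in [0, 1] to an IF neuron with threshold 1.0 and reset by subtraction, matching the v_threshold=1.0, v_reset=None settings in the printout. The firing rate can only be a multiple of 1/T, so its error stays below 1/T. The sketch does not model the extra delay that builds up across VGG16's 15 spiking layers.

```python
def if_firing_rate(a, T, v_threshold=1.0):
    """Firing rate of one IF neuron (reset by subtraction) driven by a constant input a for T steps."""
    v = 0.0
    spikes = 0
    for _ in range(T):
        v += a
        if v >= v_threshold:
            spikes += 1
            v -= v_threshold
    return spikes / T


a = 0.625  # normalised ReLU activation in [0, 1]; exactly representable in binary
for T in (4, 50):
    rate = if_firing_rate(a, T)
    print(f"T={T:2d}: rate={rate:.3f}, |error|={abs(rate - a):.3f}, bound 1/T={1 / T:.3f}")
# T= 4: rate=0.500, |error|=0.125, bound 1/T=0.250
# T=50: rate=0.620, |error|=0.005, bound 1/T=0.020
```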
Snow-Crash commented 2 years ago

Thanks for the quick reply! I'm posting my results here for reference, in case anyone else runs into similar problems.

The results below are returned by the val function; each element is the accuracy at one time step.

Mode: max, T = 50

[0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.00403226,
       0.03427419, 0.0766129 , 0.12600806, 0.17540323, 0.20262097,
       0.23387097, 0.2671371 , 0.29737903, 0.31854839, 0.34274194,
       0.36391129, 0.37197581, 0.38810484, 0.41431452, 0.42842742,
       0.43850806, 0.45362903, 0.46370968, 0.47883065, 0.49596774,
       0.50907258, 0.52318548, 0.53225806, 0.54334677, 0.55141129,
       0.5625    , 0.56854839, 0.57358871, 0.58165323, 0.58669355]

Mode: 99%, T = 50

array([0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.00403226, 0.2953629 ,
       0.51008065, 0.55040323, 0.57762097, 0.58366935, 0.58064516,
       0.59274194, 0.59475806, 0.58770161, 0.58971774, 0.58770161,
       0.58971774, 0.58770161, 0.58971774, 0.58971774, 0.58971774,
       0.58770161, 0.58064516, 0.57762097, 0.57560484, 0.57358871,
       0.56552419, 0.56451613, 0.56149194, 0.56048387, 0.55141129,
       0.55241935, 0.55141129, 0.54939516, 0.54637097, 0.54233871,
       0.54435484, 0.54032258, 0.53830645, 0.53830645, 0.53427419,
       0.53427419, 0.53024194, 0.52822581, 0.52620968, 0.52520161])

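As @Lyu6PosHao suggested, T has to be large enough. This is different from directly training a spiking ResNet on ImageNet, where T=4 is sufficient. With mode='max', accuracy generally increases as T grows: it is 0 for the first 19 steps and reaches 58.7% at T=50. With mode='99%', accuracy rises much sooner and peaks at about 59.5% around T=17, but a larger T does not help beyond that, and accuracy slowly drops to about 52.5% at T=50.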
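Part of the difference between the two modes can be illustrated with a toy layer. This sketch assumes, as the mode names suggest, that mode='max' sets the per-layer scale λ to the maximum recorded activation and mode='99%' sets it to the 99th percentile. The percentile here is computed with a simple nearest-rank method, which may differ from SpikingJelly's actual estimator. A smaller λ makes neurons start firing sooner, but activations above λ are clipped to λ. The activations are made up, and the helper names `percentile_nearest_rank` and `if_neuron` are hypothetical.

```python
import math


def percentile_nearest_rank(values, p):
    """p-th percentile (integer p in 1..100) by the nearest-rank method."""
    s = sorted(values)
    k = math.ceil(p * len(s) / 100)
    return s[k - 1]


def if_neuron(a, lam, T, v_threshold=1.0):
    """Feed a constant a / lam to an IF neuron (reset by subtraction) for T steps.

    Returns (first spike step, 1-based, or None; lam * firing rate).
    """
    v, spikes, first = 0.0, 0, None
    for t in range(1, T + 1):
        v += a / lam
        if v >= v_threshold:
            spikes += 1
            v -= v_threshold
            if first is None:
                first = t
    return first, lam * spikes / T


# Hypothetical recorded ReLU outputs of one layer: 1..99 plus a single large outlier.
activations = list(range(1, 100)) + [1024]
lam_max = max(activations)                         # mode 'max'
lam_99 = percentile_nearest_rank(activations, 99)  # mode '99%' (assumed meaning)

T = 50
for name, lam in (("max", lam_max), ("99%", lam_99)):
    for a in (64, 1024):
        first, approx = if_neuron(a, lam, T)
        print(f"{name:>3}: lam={lam}, a={a}: first spike at step {first}, SNN value {approx:.2f}")
```

With λ from the max, a typical activation of 64 needs 16 steps before its first spike. With the 99th percentile it fires at step 2, but the outlier 1024 is reconstructed as only 99. The sketch illustrates this latency/clipping trade-off only; it does not by itself explain why accuracy in 99% mode declines for larger T.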