This repository is the official PyTorch implementation of the paper: Training High-Performance Low-Latency Spiking Neural Networks by Differentiation on Spike Representation (CVPR 2022).
pip install numpy matplotlib progress
Run the commands below to train on CIFAR-10, CIFAR-100, and DVS-CIFAR10. The hyperparameters in the code match those used in the paper.
python -u cifar/ --path ./data --dataset cifar10 --model preresnet.resnet18_lif --name [checkpoint_name]
python -u cifar/ --path ./data --dataset cifar10 --model preresnet.resnet18_if --Vth 6 --alpha 0.5 --Vth_bound 0.01 --name [checkpoint_name]
python -u cifar/ --path ./data --dataset cifar100 --model preresnet.resnet18_lif --name [checkpoint_name]
python -u cifar/ --path ./data --dataset cifar100 --model preresnet.resnet18_if --Vth 6 --alpha 0.5 --Vth_bound 0.01 --name [checkpoint_name]
python -u cifar/ --path ./data/CIFAR10DVS --dataset CIFAR10DVS --model vgg.vgg11_lif --lr=0.05 --epochs=300 --name [checkpoint_name]
python -u cifar/ --path ./data/CIFAR10DVS --dataset CIFAR10DVS --model vgg.vgg11_if --Vth 6 --alpha=0.5 --Vth_bound 0.01 --lr=0.05 --epochs=300 --name [checkpoint_name]
With tuned hyperparameters, DSR also achieves good results at low latency (e.g., T=5). The commands below train with T=5 on CIFAR-10 and reach an accuracy of about 94.45% (Fig. 3 in the paper). For other datasets, reduce lr and tune Vth for better performance.
python -u cifar/ --path ./data --dataset cifar10 --model preresnet.resnet18_lif --timesteps 5 --lr 0.05 --Vth 0.6 --alpha 0.5 --Vth_bound 0.001 --delta_t 0.1
python -u cifar/ --path ./data --dataset cifar10 --model preresnet.resnet18_if --timesteps 5 --lr 0.05 --Vth 6 --alpha 0.5 --Vth_bound 0.01
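To give some intuition for why small T makes threshold tuning matter, here is a minimal sketch of a generic integrate-and-fire (IF) neuron in plain Python. It is not the repository's neuron implementation: the zero initial membrane potential, the reset-by-subtraction rule, and the function name `simulate_if_neuron` are assumptions made for illustration, and flags such as `--alpha`, `--Vth_bound`, and `--delta_t` are not modeled.

```python
def simulate_if_neuron(input_current, v_th, timesteps):
    """Simulate one generic IF neuron driven by a constant input.

    Returns the spike train and the averaged output v_th * (spike count) / T.
    """
    v = 0.0  # membrane potential; starting at zero is an assumption
    spikes = []
    for _ in range(timesteps):
        v += input_current      # integrate the input
        if v >= v_th:           # fire once the threshold is reached
            spikes.append(1)
            v -= v_th           # reset by subtraction (assumed reset rule)
        else:
            spikes.append(0)
    rate_output = v_th * sum(spikes) / timesteps
    return spikes, rate_output


spikes, out = simulate_if_neuron(input_current=0.5, v_th=1.0, timesteps=5)
print(spikes, out)  # [0, 1, 0, 1, 0] 0.4
```

With T=5 the averaged output of this toy neuron can only take the values 0, Vth/5, 2·Vth/5, and so on up to Vth, so the choice of Vth directly sets how coarsely the output is quantized.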
For the CIFAR-10, CIFAR-100, and DVS-CIFAR10 tasks, multiple GPUs can also be used. An example command is shown below.
python -u -m torch.distributed.launch --nproc_per_node [number_of_gpus] cifar/ --path ./data --dataset cifar10 --model preresnet.resnet18_lif --name [checkpoint_name]
For the ImageNet classification task, we conduct hybrid training.
First, we train an ANN.
python imagenet/ --arch preresnet_ann.resnet18 --data ./data/imagenet --name model_ann --optimizer SGD --wd 1e-4 --batch-size 256 --lr 0.1
Then, we calculate the maximum post-activation values of the trained ANN and use them to initialize the spike thresholds.
python imagenet/ --arch preresnet_cal_Vth.resnet18 --data ./data/imagenet --pre_train model_ann.pth --calculate_Vth resnet18_Vth
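The idea behind this calibration step can be sketched in plain Python: run calibration inputs through a ReLU network and record the largest post-activation value seen at each layer. This is only an illustration under stated assumptions, not the repository's code: the toy fully connected network, the one-threshold-per-layer granularity, and the names `calibrate_thresholds`, `fc1`, and `fc2` are all hypothetical.

```python
def relu(x):
    """Element-wise ReLU on a list of floats."""
    return [max(0.0, v) for v in x]


def linear(x, weights, bias):
    """Dense layer: one output per weight row."""
    return [sum(w * xi for w, xi in zip(row, x)) + b
            for row, b in zip(weights, bias)]


def calibrate_thresholds(layers, inputs):
    """Track the maximum post-activation of each layer over all inputs.

    `layers` is a list of (name, weights, bias) tuples applied in order.
    """
    max_act = {name: 0.0 for name, _, _ in layers}
    for x in inputs:
        for name, weights, bias in layers:
            x = relu(linear(x, weights, bias))
            max_act[name] = max(max_act[name], max(x))
    return max_act


# Toy two-layer network and two calibration inputs (hypothetical values).
layers = [
    ("fc1", [[1.0, -1.0], [0.5, 0.5]], [0.0, 0.0]),
    ("fc2", [[2.0, 1.0]], [0.0]),
]
inputs = [[1.0, 0.0], [0.0, 2.0]]
print(calibrate_thresholds(layers, inputs))  # {'fc1': 1.0, 'fc2': 2.5}
```

The resulting per-layer maxima would then serve as initial values for the spike thresholds when the SNN is trained.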
Next, we train the SNN.
python imagenet/ --dist-url tcp:// --dist-backend nccl --multiprocessing-distributed --world-size 1 --rank 0 --arch preresnet_snn.resnet18_if --data ./data/imagenet --pre_train model_ann.pth --load_Vth resnet18_Vth.dict
The pretrained ANN model and calculated thresholds can be downloaded from here and here. Please put them in the path ./checkpoint/imagenet.
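Before launching SNN training, it can help to confirm that the files are where the training command expects them. The sketch below is a small standalone helper, not part of the repository. The file names `model_ann.pth` and `resnet18_Vth.dict` are taken from the commands above; whether the downloaded files carry exactly these names is an assumption, and `missing_checkpoints` is a hypothetical helper name.

```python
from pathlib import Path


def missing_checkpoints(checkpoint_dir, filenames):
    """Return the expected files that are not present in checkpoint_dir."""
    root = Path(checkpoint_dir)
    return [name for name in filenames if not (root / name).is_file()]


# File names assumed from the training commands; adjust if yours differ.
expected = ["model_ann.pth", "resnet18_Vth.dict"]
missing = missing_checkpoints("./checkpoint/imagenet", expected)
if missing:
    print("Missing files:", ", ".join(missing))
else:
    print("All checkpoint files found.")
```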
The DVS-CIFAR10 data preprocessing code is based on the spikingjelly repo. Some utility code is taken from the pytorch-classification repo.