pkuxmq / OTTT-SNN

[NeurIPS 2022] Online Training Through Time for Spiking Neural Networks

OTTT-SNN

This is the PyTorch implementation of the paper Online Training Through Time for Spiking Neural Networks (NeurIPS 2022). [arxiv] [openreview]

Update 2023/12: Some OTTT modules have been integrated into the latest version of spikingjelly, and the new neuron-model code supports multi-GPU training. We provide the reference code from the spikingjelly repository in spikingjelly_codes/reference_codes/. There, neuron.py, layer.py, and functional.py (located in spikingjelly/activation_based/ in their repository) contain OTTT modules, and spiking_vggws_ottt.py (located in spikingjelly/activation_based/model/ in their repository) shows how to define a model with these modules. We also provide an example training script in spikingjelly_codes/train_ottt_cifar.py for reference.
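The README does not spell out what the OTTT neuron modules compute. As a rough illustration of the core idea (keeping a running trace of presynaptic spikes during the forward pass so that a weight gradient can be formed at every time step, without storing the history for backpropagation through time), here is a plain-Python sketch for one fully connected spiking layer. It assumes an LIF neuron with decay `lam`, threshold `v_th`, and soft reset, a presynaptic trace that decays with the same `lam`, and it does not differentiate through the reset. The function names `lif_layer_with_trace` and `instantaneous_weight_grad` are hypothetical; this is not the spikingjelly API or the repository's code.

```python
from typing import List, Tuple

# Hypothetical sketch, not the spikingjelly implementation.
# Assumptions: LIF neuron with decay `lam`, threshold `v_th`, soft reset,
# and a presynaptic trace a_hat[t] = lam * a_hat[t-1] + s_in[t].


def lif_layer_with_trace(
    weights: List[List[float]],   # shape [n_out][n_in]
    spikes_in: List[List[int]],   # shape [T][n_in], 0/1 input spikes
    lam: float = 0.5,
    v_th: float = 1.0,
) -> Tuple[List[List[int]], List[List[float]]]:
    """Run the layer over T steps; return output spikes and per-step traces."""
    n_out, n_in = len(weights), len(weights[0])
    u = [0.0] * n_out        # membrane potentials
    a_hat = [0.0] * n_in     # presynaptic activity traces
    spikes_out, traces = [], []
    for s_in in spikes_in:
        # Update the presynaptic trace with the current input spikes.
        a_hat = [lam * a + s for a, s in zip(a_hat, s_in)]
        # Leaky integration of the weighted input.
        u = [lam * ui + sum(w * s for w, s in zip(row, s_in))
             for ui, row in zip(u, weights)]
        s_out = [1 if ui >= v_th else 0 for ui in u]
        # Soft reset: subtract the threshold where a spike was emitted.
        u = [ui - v_th * so for ui, so in zip(u, s_out)]
        spikes_out.append(s_out)
        traces.append(list(a_hat))
    return spikes_out, traces


def instantaneous_weight_grad(dl_du: List[float],
                              a_hat: List[float]) -> List[List[float]]:
    """Per-time-step weight gradient: outer product of dL_t/du_t and a_hat_t."""
    return [[g * a for a in a_hat] for g in dl_du]
```

Because each gradient only needs the current trace and the current error signal `dl_du` (which would come from a surrogate gradient in practice), memory does not grow with the number of time steps.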

Dependencies and Installation

Training

For OTTT$_A$, run the following:

python train_cifar.py -data_dir path_to_data_dir -dataset cifar10 -out_dir log_checkpoint_name -gpu-id 0

# For VGG-F model
python train_cifar.py -data_dir path_to_data_dir -dataset cifar100 -out_dir log_checkpoint_name -gpu-id 0 -model online_spiking_vgg11f_ws

python train_cifar10dvs.py -data_dir path_to_data_dir -out_dir log_checkpoint_name -gpu-id 0

python train_imagenet.py -data_dir path_to_data_dir -out_dir log_checkpoint_name -gpu-id 0

For OTTT$_O$, add the -online_update argument:

python train_cifar.py -data_dir path_to_data_dir -dataset cifar10 -out_dir log_checkpoint_name -gpu-id 0 -online_update
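As we read the paper, the two variants differ in when the per-time-step gradients are applied: OTTT$_A$ accumulates them over all time steps and updates the weights once, while OTTT$_O$ updates the weights at every time step. The toy loop below sketches only that scheduling difference. `train_one_sample` and `grad_fn` are hypothetical stand-ins (the gradient function replaces the real instantaneous OTTT gradient), and this is not the repository's training code.

```python
from typing import Callable, List

# Hypothetical sketch of the two update schedules; not code from this repo.
# `grad_fn(w, t)` stands in for the instantaneous gradient at time step t.


def train_one_sample(
    w: List[float],
    grad_fn: Callable[[List[float], int], List[float]],
    T: int,
    lr: float,
    online_update: bool,
) -> List[float]:
    w = list(w)
    acc = [0.0] * len(w)
    for t in range(T):
        g = grad_fn(w, t)
        if online_update:
            # OTTT_O: apply each instantaneous gradient immediately.
            w = [wi - lr * gi for wi, gi in zip(w, g)]
        else:
            # OTTT_A: accumulate gradients over time steps.
            acc = [ai + gi for ai, gi in zip(acc, g)]
    if not online_update:
        # OTTT_A: a single update after the last time step.
        w = [wi - lr * ai for wi, ai in zip(w, acc)]
    return w
```

For example, with `grad_fn = lambda w, t: list(w)` (the gradient of $\frac{1}{2}w^2$), `T=2`, and `lr=0.5`, starting from `[1.0]` the online schedule gives `[0.25]` and the accumulated schedule gives `[0.0]`, because the online variant sees the already-updated weights at the second step.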

The default hyperparameters in the code are the same as in the paper.

Note: The code in this repository currently supports only single-GPU training.

Testing

We provide example code to compute firing-rate statistics during evaluation. Run the following:

python get_rate_cifar.py -data_dir path_to_data_dir -dataset cifar10 -gpu-id 0 -resume path_to_checkpoint

python get_rate_imagenet.py -data_dir path_to_data_dir -gpu-id 0 -resume path_to_checkpoint
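The README does not state exactly which statistic get_rate_cifar.py and get_rate_imagenet.py report. A common definition of a layer's average firing rate is the total number of spikes divided by (time steps × neurons); the sketch below computes that from recorded 0/1 spike trains. The function `firing_rates` and the `{layer_name: [T][n_neurons]}` record format are assumptions for illustration only.

```python
from typing import Dict, List

# Hypothetical helper; not taken from get_rate_cifar.py or get_rate_imagenet.py.


def firing_rates(spike_records: Dict[str, List[List[int]]]) -> Dict[str, float]:
    """Average firing rate per layer: total spikes / (time steps * neurons)."""
    rates = {}
    for name, record in spike_records.items():
        steps = len(record)
        neurons = len(record[0]) if steps else 0
        total = sum(sum(step) for step in record)
        rates[name] = total / (steps * neurons) if steps and neurons else 0.0
    return rates
```

For example, a layer with 4 neurons that emits 4 spikes over 2 time steps has a rate of 0.5.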

Some pretrained models can be downloaded from Google Drive or Baidu Drive (extraction code: gppq).

Acknowledgement

Some code for the neuron model and data preprocessing is adapted from the spikingjelly repository, and some utility code is from the pytorch-classification repository.

Contact

If you have any questions, please contact mingqing_xiao@pku.edu.cn.