rkim35 / spikeRNN

MIT License

Functional Spiking RNNs

Overview

This repository provides the code for the framework presented in this paper:

Kim R., Li Y., & Sejnowski T.J. Simple Framework for Constructing Functional Spiking Recurrent Neural Networks. Proceedings of the National Academy of Sciences. 116(45): 22811–22820 (2019).

A preprint is also available.

Requirements

Continuous Rate RNN

The code for constructing and training continuous-variable rate recurrent neural network (RNN) models is implemented in Python (tested with Python 3.6.9) and requires TensorFlow (tested with TensorFlow 1.5.0 and 1.10.0).

Spiking RNN

The code for constructing spiking RNN models is implemented in MATLAB (tested with R2016a and R2016b). It implements leaky integrate-and-fire (LIF) networks and is a modified version of the code developed by Nicola et al. (2016).

Usage

A rate RNN model is trained first, and the trained model is then mapped to a LIF spiking RNN. The code for training rate models is located in rate/, while the code for mapping and constructing LIF models is in spiking/.

Training Continuous Rate RNN

The main training script (rate/main.py) takes the following input arguments (all of them appear in the example command below):

--gpu, --gpu_frac: GPU device index and fraction of GPU memory to allocate
--n_trials: maximum number of training trials
--mode: script mode (train in the example below)
--N: number of units in the network
--P_inh: proportion of inhibitory units
--som_N: number of units subject to additional connectivity constraints (0 disables them)
--apply_dale: whether to enforce Dale's principle
--gain: gain term for initializing the recurrent weights
--task: task to train on (e.g. go-nogo)
--act: activation function (e.g. sigmoid)
--loss_fn: loss function (e.g. l2)
--decay_taus: minimum and maximum synaptic decay time constants
--output_dir: base directory where the trained model is saved

The following example trains a rate model to perform the Go-NoGo task. The network contains 200 units, 20% of which are inhibitory. Training runs for at most 5000 trials (n_trials) and stops early once the termination criteria are met. No additional connectivity constraints are used (i.e. som_N is set to 0). The trained model is saved as a MATLAB-formatted file (.mat) in the output directory (../models/go-nogo/P_rec_0.20_Taus_4.0_20.0).

python main.py --gpu 0 --gpu_frac 0.20 \
--n_trials 5000 --mode train \
--N 200 --P_inh 0.20 --som_N 0 --apply_dale True \
--gain 1.5 --task go-nogo --act sigmoid --loss_fn l2 \
--decay_taus 4 20 --output_dir ../

The name of the output .mat file conforms to the following convention:

Task_<Task Name>_N_<N>_Taus_<min_tau>_<max_tau>_Act_<act>_<YYYY_MM_DD_TIME>.mat
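For the example above, the saved model would be named Task_go-nogo_N_200_Taus_4.0_20.0_Act_sigmoid_<YYYY_MM_DD_TIME>.mat, with the timestamp filled in at save time.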

Mapping and Constructing LIF RNN

Trained rate RNNs are used to construct LIF RNNs. The mapping and the LIF simulations are performed in MATLAB. Given a trained rate model, the first step is to perform a grid search to determine the optimal scaling factor (lambda); this is done by lambda_grid_search.m. Once the optimal scaling factor is determined, a LIF RNN can be constructed with the function LIF_network_fnc.m. All the required functions/scripts are located in spiking/.
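The sketch below outlines what the grid-search step does conceptually; it is not the repository's actual code. The model path and lambda range are placeholders, and run_lif_trial is a hypothetical stand-in for simulating the mapped LIF network and scoring its task performance (the real interface is defined in lambda_grid_search.m and LIF_network_fnc.m).

% Conceptual sketch of the lambda grid search (illustrative only).
% run_lif_trial is a hypothetical helper that maps the trained rate
% weights to a LIF network scaled by lambda, simulates it, and returns
% a task performance score; see lambda_grid_search.m for the real code.
model_path = '../models/go-nogo/P_rec_0.20_Taus_4.0_20.0/trained_model.mat'; % placeholder
lambdas = 20:5:75;                 % candidate scaling factors (range assumed)
perf = zeros(size(lambdas));
for i = 1:numel(lambdas)
    perf(i) = run_lif_trial(model_path, lambdas(i)); % simulate and score
end
[~, idx] = max(perf);
opt_lambda = lambdas(idx);         % scaling factor to pass to LIF_network_fnc.m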

An example script for evaluating a Go-NoGo LIF network (eval_go_nogo.m) is also included. The script constructs a LIF RNN trained to perform the Go-NoGo task and plots network responses. The script can be modified to evaluate models trained to perform other tasks.
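With spiking/ on the MATLAB path (or as the current directory), the script runs by name:

eval_go_nogo   % constructs the Go-NoGo LIF RNN and plots its responses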

Citation

If you use this repo for your research, please cite our work:

@article{Kim_2019,
    Author = {Kim, Robert and Li, Yinghao and Sejnowski, Terrence J.},
    Doi = {10.1073/pnas.1905926116},
    Journal = {Proceedings of the National Academy of Sciences},
    Number = {45},
    Pages = {22811--22820},
    Publisher = {National Academy of Sciences},
    Title = {Simple framework for constructing functional spiking recurrent neural networks},
    Volume = {116},
    Year = {2019}}