kmheckel / spyx

Spyx: Spiking Neural Networks in JAX
https://spyx.readthedocs.io/en/latest/
MIT License

Add SPSN #5

Open: kmheckel opened this issue 1 year ago

kmheckel commented 1 year ago

Add the Stochastic Parallelizable Spiking Neuron (SPSN) model.

Paper: https://arxiv.org/abs/2306.12666

Torch implementation: https://github.com/NECOTIS/Stochastic-Parallelizable-Spiking-Neuron-SPSN/blob/main/neurons/spsn.py
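As the linked paper frames it, SPSN decouples spike generation from the linear integration so the neuron can run in parallel over time. Assuming the membrane potential is a leaky integrator without reset and that spikes are sampled from Bernoulli(sigmoid(u)) (the rule suggested by the `_SigmoidBernoulli` name later in this thread), a minimal JAX sketch might look like this. The function names are hypothetical and not Spyx API. `jax.lax.associative_scan` is just one way to parallelize the recurrence, and the Torch reference above may use a different method.

```python
import jax
import jax.numpy as jnp


def _combine(earlier, later):
    # Each element (a, b) encodes the affine map u -> a * u + b.
    # Applying "earlier" then "later": u -> a_l * (a_e * u + b_e) + b_l.
    a_e, b_e = earlier
    a_l, b_l = later
    return a_l * a_e, a_l * b_e + b_l


def parallel_leaky_integrate(x, beta):
    """No-reset leaky integrator computed in parallel over time (hypothetical).

    x:    input currents, shape (T, ...), time on axis 0
    beta: decay factor broadcastable to x[0], e.g. shape (n_neurons,)
    Returns u with u[t] = beta * u[t-1] + x[t] and u[-1] = 0.
    """
    a = jnp.broadcast_to(beta, x.shape).astype(x.dtype)
    # Prefix composition of the affine maps; its b-component, applied to
    # u[-1] = 0, is exactly u[t].
    _, u = jax.lax.associative_scan(_combine, (a, x), axis=0)
    return u


def sequential_leaky_integrate(x, beta):
    # Reference: the same recurrence as an explicit time loop.
    def step(u, x_t):
        u = beta * u + x_t
        return u, u

    _, u = jax.lax.scan(step, jnp.zeros_like(x[0]), x)
    return u


def stochastic_spikes(u, key):
    # Assumed SPSN-style spiking: independent Bernoulli(sigmoid(u)) draws.
    # u does not depend on past spikes, so this step is parallel over time too.
    return jax.random.bernoulli(key, jax.nn.sigmoid(u)).astype(u.dtype)


# Usage: 64 time steps, batch of 8, 16 neurons.
x = jax.random.normal(jax.random.PRNGKey(0), (64, 8, 16))
beta = jnp.full((16,), 0.9)
u_par = parallel_leaky_integrate(x, beta)
u_seq = sequential_leaky_integrate(x, beta)
print(jnp.allclose(u_par, u_seq, atol=1e-4))  # should agree up to float32 rounding
spikes = stochastic_spikes(u_par, jax.random.PRNGKey(1))
```

Because u[t] no longer depends on earlier spikes, the scan has logarithmic depth in T instead of T sequential steps. The price is that the neuron has no reset.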

kmheckel commented 10 months ago

Spiking State Space Machine paper: https://arxiv.org/abs/2401.00955

kmheckel commented 9 months ago

A prototype has been implemented in spyx.experimental. The dimensions of the internal calculation still need to be verified: right now there are broadcasting issues if the output dimension isn't a multiple or fraction of the input dimension.
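The thread does not pin down the cause of the broadcasting issue, so the following is only a hypothetical shape-safe layout, not the `spyx.experimental` code. It projects the input to the output width before any temporal mixing, then applies the leaky decay as a causal matrix over time. This is a swapped-in formulation, equivalent to the no-reset recurrence u[t] = beta * u[t-1] + h[t]. Every internal array then ends in the output dimension, so the input and output widths never need to divide each other. `decay_kernel` and `spsn_potential` are illustrative names.

```python
import jax
import jax.numpy as jnp


def decay_kernel(T, beta):
    """Causal decay kernel K[t, s, o] = beta[o] ** (t - s) if s <= t, else 0."""
    t = jnp.arange(T)
    lag = t[:, None] - t[None, :]               # (T, T), lag = t - s
    causal = (lag >= 0)[..., None]              # (T, T, 1)
    safe_lag = jnp.maximum(lag, 0)[..., None]   # clamp so no negative powers
    return jnp.where(causal, beta[None, None, :] ** safe_lag, 0.0)  # (T, T, out)


def spsn_potential(x, w, beta):
    """Hypothetical shape-safe no-reset potential.

    x: (T, B, in), w: (in, out), beta: (out,)  ->  u: (T, B, out)
    """
    # Project to the output width first, so nothing below ever mixes the
    # input and output dimensions.
    h = jnp.einsum("tbi,io->tbo", x, w)          # (T, B, out)
    k = decay_kernel(x.shape[0], beta)           # (T, T, out)
    # u[t] = sum_{s <= t} beta ** (t - s) * h[s], i.e. u[t] = beta * u[t-1] + h[t]
    return jnp.einsum("tso,sbo->tbo", k, h)      # (T, B, out)


# Usage: 7 input channels, 5 output neurons (neither divides the other).
x = jax.random.normal(jax.random.PRNGKey(0), (16, 3, 7))    # (T, B, in)
w = 0.1 * jax.random.normal(jax.random.PRNGKey(1), (7, 5))  # (in, out)
beta = jnp.linspace(0.5, 0.95, 5)                           # (out,)
u = spsn_potential(x, w, beta)
print(u.shape)  # (16, 3, 5)
```

The kernel costs O(T^2 * out) memory, which is manageable for short sequences but grows quickly with T.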

Training accuracy on SHD is comparable to that of regularly trained recurrent SNNs in Spyx. Performance is comparable to the recurrent model at the small scales tested locally with short sequences (length 64). The testing notebook is at research/SPSN.

I'm considering creating a stochastic Axon class and refactoring _SigmoidBernoulli into it.
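The thread doesn't show `_SigmoidBernoulli` itself, so here is a standalone, hypothetical sigmoid-Bernoulli spike function of the kind a stochastic Axon class could wrap. It samples spikes from Bernoulli(sigmoid(u)) and, as an assumed gradient rule, uses a straight-through estimator so the backward pass sees the derivative of the firing probability. The name `sigmoid_bernoulli` is illustrative, and this is not Spyx's axon API.

```python
import jax
import jax.numpy as jnp


def sigmoid_bernoulli(u, key):
    """Hypothetical stochastic spike function (not Spyx's _SigmoidBernoulli).

    Forward:  spikes drawn from Bernoulli(sigmoid(u)), values 0 or 1
              (up to float rounding from the trick below).
    Backward: straight-through estimator (an assumption here); the gradient
              of a spike is taken to be d sigmoid(u) / du.
    """
    p = jax.nn.sigmoid(u)
    # Sample from a gradient-free copy of p; the draw itself has no gradient.
    s = jax.random.bernoulli(key, jax.lax.stop_gradient(p)).astype(u.dtype)
    # Forward value is p + (s - p) = s; stop_gradient hides (s - p) from
    # autodiff, so only dp/du flows backward.
    return p + jax.lax.stop_gradient(s - p)


# Usage
u = jnp.linspace(-3.0, 3.0, 7)
key = jax.random.PRNGKey(0)
spikes = sigmoid_bernoulli(u, key)
grads = jax.grad(lambda v: sigmoid_bernoulli(v, key).sum())(u)
```

Keeping the function stateless, with the PRNG key passed in explicitly, fits JAX's functional random-number model. A class wrapper would mainly need to manage key splitting.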

kmheckel commented 9 months ago

Found this:

Parallel Spiking Unit: https://arxiv.org/abs/2402.00449