kmheckel / spyx

Spyx: Spiking Neural Networks in JAX
https://spyx.readthedocs.io/en/latest/
MIT License

Add SPSN #5

Status: Open. kmheckel opened this issue 11 months ago.

kmheckel commented 11 months ago

Add the Stochastic Parallelizable Spiking Neuron (SPSN) model.

Paper: https://arxiv.org/abs/2306.12666#:~:text=In%20this%20paper%2C%20we%20propose,run%20in%20parallel%20over%20time.

Torch implementation: https://github.com/NECOTIS/Stochastic-Parallelizable-Spiking-Neuron-SPSN/blob/main/neurons/spsn.py
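To make the "parallelizable" part concrete: assuming the SPSN drops the membrane reset, so that the membrane potential is a plain leaky integral of the input, the whole sequence can be computed with a parallel prefix scan instead of a step-by-step loop, and spikes can then be sampled from Bernoulli(sigmoid(u)). The sketch below illustrates only those two assumptions in JAX. The function names (`leaky_integrate_parallel`, `leaky_integrate_sequential`, `spsn_like_layer`) are hypothetical. They are not part of Spyx or the Torch reference implementation.

```python
import jax
import jax.numpy as jnp


def leaky_integrate_parallel(x, beta):
    """Reset-free leaky integration u[t] = beta * u[t-1] + x[t], with u[-1] = 0.

    x: inputs of shape (T, ...), time on axis 0.
    beta: leak factor, broadcastable to x.shape[1:].
    Uses an associative scan, so the parallel depth is O(log T) instead of O(T).
    """
    # Each time step is the affine map u -> a * u + b with a = beta, b = x[t].
    a = jnp.broadcast_to(beta, x.shape)
    b = x

    def compose(left, right):
        # Apply the earlier map `left`, then `right`:
        # u -> a_r * (a_l * u + b_l) + b_r = (a_l * a_r) * u + (a_r * b_l + b_r)
        a_l, b_l = left
        a_r, b_r = right
        return a_l * a_r, a_r * b_l + b_r

    # Starting from u[-1] = 0, the offset of each prefix map is u[t].
    _, u = jax.lax.associative_scan(compose, (a, b), axis=0)
    return u


def leaky_integrate_sequential(x, beta):
    """Step-by-step reference version of the same recurrence."""
    def step(u, x_t):
        u = beta * u + x_t
        return u, u

    u0 = jnp.zeros(x.shape[1:], x.dtype)
    _, us = jax.lax.scan(step, u0, x)
    return us


def spsn_like_layer(key, x, beta):
    """Hypothetical SPSN-style layer: parallel integration, then stochastic spikes."""
    u = leaky_integrate_parallel(x, beta)
    p = jax.nn.sigmoid(u)  # spike probability per neuron and time step
    spikes = jax.random.bernoulli(key, p).astype(x.dtype)
    return spikes, u
```

For example, `leaky_integrate_parallel(jnp.ones((3, 1)), 0.5)[:, 0]` gives `[1.0, 1.5, 1.75]`, the same as the sequential loop. Without a reset, no time step depends on a previous spike, and that independence is what allows the scan.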

kmheckel commented 7 months ago

Spiking State Space Machine paper: https://arxiv.org/abs/2401.00955

kmheckel commented 7 months ago

A prototype has been implemented in spyx.experimental. The dimensions of the internal calculation still need to be verified: right now there are broadcasting issues when the output dimension isn't a multiple or fraction of the input dimension.
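As a reference point while checking dimensions, one shape-safe layout projects the input into the output space before any per-neuron operation. Every per-neuron parameter then has shape (n_out,) and never has to broadcast against n_in. The sketch below is a hypothetical layout with made-up names. It is not a description of the prototype in spyx.experimental. It uses n_in=5 and n_out=3, which are neither multiples nor fractions of each other.

```python
import jax
import jax.numpy as jnp


def project_then_integrate(x, w, beta):
    """Hypothetical SPSN-style layout with explicit shape checks.

    x:    (T, B, n_in)   input sequence
    w:    (n_in, n_out)  input weights
    beta: (n_out,)       per-output-neuron leak
    Returns the membrane potential u with shape (T, B, n_out).
    """
    T, B, n_in = x.shape
    if w.shape[0] != n_in:
        raise ValueError(f"weight rows {w.shape[0]} != input dim {n_in}")
    n_out = w.shape[1]
    if beta.shape != (n_out,):
        raise ValueError(f"leak shape {beta.shape} != ({n_out},)")

    # Project first, so all per-neuron math happens in the output space and
    # beta (n_out,) broadcasts cleanly against (T, B, n_out).
    i = jnp.einsum("tbi,io->tbo", x, w)

    def step(u, i_t):
        u = beta * u + i_t  # (B, n_out) * (n_out,) broadcasts on the last axis
        return u, u

    _, u = jax.lax.scan(step, jnp.zeros((B, n_out), x.dtype), i)
    return u
```

With `x` of shape (10, 2, 5), `w` of shape (5, 3) and `beta` of shape (3,), the result has shape (10, 2, 3). Passing a leak of shape (5,) fails early with a clear error instead of a broadcasting error deep inside the computation.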

Training accuracy on SHD is comparable to that of the standard recurrent SNNs trained in Spyx, at the small scales tested locally with short sequences (length 64). The test notebook is in research/SPSN.

Considering refactoring _SigmoidBernoulli into a stochastic Axon class.
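A stochastic axon of that kind could look roughly like the following minimal sketch. Based only on the name _SigmoidBernoulli, it assumes two things: the forward pass samples spikes from Bernoulli(sigmoid(u)), and the backward pass uses the derivative of the sigmoid as a straight-through surrogate. The function `sigmoid_bernoulli` is hypothetical. It is a plain function and not part of Spyx's axn module.

```python
import jax
import jax.numpy as jnp


def sigmoid_bernoulli(u, key):
    """Hypothetical stochastic spike function.

    Forward: s ~ Bernoulli(sigmoid(u)), values exactly 0.0 or 1.0.
    Backward: gradient w.r.t. u is sigmoid'(u) (straight-through surrogate).
    """
    p = jax.nn.sigmoid(u)
    # Sampling itself is non-differentiable; stop_gradient makes that explicit.
    s = jax.random.bernoulli(key, jax.lax.stop_gradient(p)).astype(u.dtype)
    # (p - stop_gradient(p)) is exactly 0 in the forward pass, so the output
    # equals s, but its derivative w.r.t. u is sigmoid'(u) = p * (1 - p).
    return s + (p - jax.lax.stop_gradient(p))
```

Adding the zero-valued term `p - stop_gradient(p)` keeps the forward spikes exactly binary. The more common form `p + stop_gradient(s - p)` can be off by a rounding error. A custom VJP would work as well, but this form needs no extra machinery.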

kmheckel commented 6 months ago

Found this:

Parallel Spiking Unit: https://arxiv.org/abs/2402.00449