kmheckel / spyx

Spyx: Spiking Neural Networks in JAX
https://spyx.readthedocs.io/en/latest/
MIT License

neuron models for e-prop implementation #28

Open · florian6973 opened this 5 months ago

florian6973 commented 5 months ago

Hi!

I have been working on https://www.nature.com/articles/s41467-020-17236-y, so I added two classes based on the paper's original code, RecurrentLIFLight and LeakyLinear, to provide the same LIF/ALIF neurons and output neurons used in the paper.
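For context, here is a rough sketch of recurrent LIF/ALIF forward dynamics in the style of the e-prop paper's equations. It is an illustration under stated assumptions, not the code of RecurrentLIFLight: the function run_alif and its parameters (alpha, rho, beta, v_th) are hypothetical names, and refractoriness and the surrogate (pseudo-)derivative used for training are omitted. Setting beta = 0 switches the adaptive threshold off and recovers a plain LIF neuron.

```python
import jax
import jax.numpy as jnp

def run_alif(xs, W_in, W_rec, alpha, rho, beta, v_th):
    """Hypothetical forward pass of a recurrent ALIF layer (simplified e-prop-style update).

    xs: (T, n_in) inputs; W_in: (n, n_in); W_rec: (n, n).
    Returns spikes of shape (T, n). No refractory period, no surrogate gradient.
    """
    n = W_in.shape[0]
    W_rec = W_rec * (1.0 - jnp.eye(n))  # assumption: no self-connections

    def step(state, x_t):
        v, a, z = state
        # Leaky membrane, recurrent + input current, reset by subtracting v_th on a spike
        v = alpha * v + W_rec @ z + W_in @ x_t - z * v_th
        # Adaptation trace increases by 1 for each previous spike and decays with rho
        a = rho * a + z
        # Spike if the membrane exceeds the adaptive threshold v_th + beta * a
        z = (v > v_th + beta * a).astype(v.dtype)
        return (v, a, z), z

    zeros = jnp.zeros(n)
    _, zs = jax.lax.scan(step, (zeros, zeros, zeros), xs)
    return zs

# Single neuron driven by a constant input of 0.5 for 5 steps (illustrative values)
xs = 0.5 * jnp.ones((5, 1))
W_in = jnp.array([[1.0]])
W_rec = jnp.array([[0.0]])
zs_lif = run_alif(xs, W_in, W_rec, alpha=0.9, rho=0.95, beta=0.0, v_th=1.0)
zs_alif = run_alif(xs, W_in, W_rec, alpha=0.9, rho=0.95, beta=1.7, v_th=1.0)
print(zs_lif[:, 0], zs_alif[:, 0])
```

With these illustrative values the LIF neuron spikes at steps 3 and 5, while the adaptive threshold suppresses the second spike of the ALIF neuron.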

I included a test notebook shd.ipynb in the test folder.

This is only a first draft. I am open to any feedback and hope it helps.

Best,

florian6973

kmheckel commented 5 months ago

Hi Florian!

This looks awesome! I'll take a closer look in a couple of days, but I'm excited about the contribution!

I gave things a quick skim and it looks really good. It also gave me some ideas for minor fixes elsewhere in the library, such as changing the spyx.heaviside implementation to use jnp.greater instead of jnp.where, which could be a slightly more efficient op (I need to check what the lowering looks like).
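One way to compare the two formulations is to write both step functions side by side and print what each one lowers to. This is only a sketch: heaviside_where and heaviside_greater are hypothetical stand-ins, not spyx.heaviside itself (which presumably also carries a custom surrogate gradient), and both assume the convention H(0) = 0.

```python
import jax
import jax.numpy as jnp

def heaviside_where(x):
    # Step function via a select: 1.0 where x > 0, else 0.0
    return jnp.where(x > 0, 1.0, 0.0)

def heaviside_greater(x):
    # Step function via a comparison followed by a bool -> float32 cast
    return jnp.greater(x, 0).astype(jnp.float32)

x = jnp.array([-1.0, 0.0, 0.5, 2.0])
print(heaviside_where(x), heaviside_greater(x))

# Print the IR each version is lowered to, to compare the emitted ops
print(jax.jit(heaviside_where).lower(x).as_text())
print(jax.jit(heaviside_greater).lower(x).as_text())
```

Both produce the same values; whether the comparison-plus-cast version is actually cheaper depends on the ops in each lowering, which the printed text makes easy to inspect.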

I don't think the LeakyLinear layer will be necessary, since it can be implemented as an hk.Linear layer followed by an snn.LI neuron layer. This is the same approach snnTorch takes, where the (1-tau) factor applied to the input is absorbed into the weight matrix. If a fixed, non-trainable leakage constant is desired, I definitely support adding stop-gradient (jax.lax.stop_gradient) options to the neurons throughout the library to freeze the inverse time constant parameters. Freezing constants associated with the model is something I've been wanting to support, but the simple stop-gradient solution hadn't occurred to me.
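The weight-absorption argument can be checked numerically. The sketch below assumes a paper-style leaky readout y_t = beta * y_{t-1} + (1 - beta) * W x_t versus a linear layer followed by a leaky integrator y_t = beta * y_{t-1} + W' x_t. Whether snn.LI uses exactly this update is an assumption, and a plain matrix product stands in for hk.Linear so the example only needs JAX; the function names are hypothetical. With W' = (1 - beta) W the two agree up to floating-point rounding.

```python
import jax
import jax.numpy as jnp

def leaky_linear(W, beta, xs):
    # Paper-style leaky readout: y_t = beta * y_{t-1} + (1 - beta) * W x_t
    def step(y, x_t):
        y = beta * y + (1.0 - beta) * (W @ x_t)
        return y, y
    _, ys = jax.lax.scan(step, jnp.zeros(W.shape[0]), xs)
    return ys

def linear_then_li(W_abs, beta, xs):
    # Linear layer (plain matmul standing in for hk.Linear) followed by an
    # assumed leaky integrator: y_t = beta * y_{t-1} + W_abs x_t
    def step(y, x_t):
        y = beta * y + W_abs @ x_t
        return y, y
    _, ys = jax.lax.scan(step, jnp.zeros(W_abs.shape[0]), xs)
    return ys

key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
W = jax.random.normal(key_w, (3, 5))
xs = jax.random.normal(key_x, (20, 5))  # 20 time steps, 5 inputs
beta = 0.9

# Absorbing (1 - beta) into the weights reproduces the leaky readout
ys_a = leaky_linear(W, beta, xs)
ys_b = linear_then_li((1.0 - beta) * W, beta, xs)
print(jnp.max(jnp.abs(ys_a - ys_b)))
```

Because the scaling is folded into trainable weights, the Linear + LI form only differs from LeakyLinear in how the weights are parameterized, not in what the layer can represent.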

Once I look things over more thoroughly I'll probably suggest a few minor naming edits just to mesh things with the rest of the library but overall I'm extremely excited to integrate this as a new capability!

If there are any other aspects of the library you have questions about, or if you have feedback or ideas, I'd love to hear them as well!

florian6973 commented 5 months ago

Hi Kade!

Thank you very much for your positive feedback; I'm glad it can help! Yes, I find spyx very promising!

Good point indeed! I added the LeakyLinear class with a specific constant so I could check my implementation against the results of the paper's original code, but you are right: it makes more sense to combine Linear + LI, potentially with an option to freeze weights by stopping the gradient.
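A minimal sketch of the stop-gradient idea, assuming the decay factor beta sits in the parameter pytree next to the weights: wrapping it in jax.lax.stop_gradient leaves the forward pass unchanged but makes its gradient zero, so a plain gradient-descent update would leave it fixed. The li_readout function and the freeze_beta flag are hypothetical, not existing spyx options.

```python
import jax
import jax.numpy as jnp

def li_readout(params, xs, freeze_beta=False):
    # Hypothetical Linear + LI readout: y_t = beta * y_{t-1} + W x_t
    beta = params["beta"]
    if freeze_beta:
        # Same forward value, but no gradient flows into beta
        beta = jax.lax.stop_gradient(beta)

    def step(y, x_t):
        y = beta * y + params["W"] @ x_t
        return y, y

    _, ys = jax.lax.scan(step, jnp.zeros(params["W"].shape[0]), xs)
    return ys

def loss(params, xs, freeze_beta):
    return jnp.sum(li_readout(params, xs, freeze_beta) ** 2)

params = {"W": 0.1 * jnp.ones((2, 3)), "beta": jnp.array(0.9)}
xs = jnp.ones((10, 3))  # 10 time steps, 3 inputs

g_train = jax.grad(lambda p: loss(p, xs, False))(params)
g_frozen = jax.grad(lambda p: loss(p, xs, True))(params)
print(g_train["beta"], g_frozen["beta"])
```

The gradient with respect to W is the same in both cases, since stop_gradient only cuts the path into beta.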

Sure, the small edits are fine. I will also add some documentation (docstrings, maybe a tutorial?), but I wanted to make sure this was relevant first :)

Of course, thanks! I am quite busy at the moment, but I should be able to spend more time on it in April ;)

Best,

kmheckel commented 4 months ago

Hi, just wanted to follow up on this!

florian6973 commented 4 months ago

Hi! Sorry for the delay; I'm working on it this weekend (LeakyLinear refactoring and documentation).

kmheckel commented 4 months ago

No worries, just wanted to check in on things and see if there's any other way I can support!

florian6973 commented 4 months ago

I have been making some progress; I will finish adding tests and cleaning up the related notebook this week. Unfortunately, I am not sure I will have time to implement constant freezing.

kmheckel commented 4 months ago

Awesome, looking forward to it!