kmheckel / spyx

Spyx: Spiking Neural Networks in JAX
https://spyx.readthedocs.io/en/latest/
MIT License
98 stars, 11 forks

Adjust Neuron Models in spyx.nn to store constant betas as hk.params #18

Closed: kmheckel closed this issue 7 months ago

kmheckel commented 7 months ago

Currently, if the user specifies the inverse time constant (beta) explicitly, that value is not tracked in the network's parameter PyTree. This makes the layer invisible when exporting the network to NIR (Neuromorphic Intermediate Representation) for cross-platform use.

To fix this, each neuron model needs an "else" clause that still calls hk.get_parameter(), but with the init argument set to an initializer that returns the user-specified value.

See the fixed LI neuron for an example of what needs to be done for the other neuron models. The exception is IF, which has no leak term to store as beta and will need a different approach to be visible.