jeshraghian / snntorch

Deep and online learning with spiking neural networks in Python
https://snntorch.readthedocs.io/en/latest/
MIT License

Custom Surrogate Gradient Function #237

Closed: mehranfaraji closed this 1 year ago

mehranfaraji commented 1 year ago

Adds a CustomSurrogate class and a custom_surrogate function to surrogate.py.

This lets users define their own surrogate gradient function and use it either through custom_surrogate(name_of_custom_surrogate_function) or through CustomSurrogate.apply(data, name_of_custom_surrogate_function).
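A minimal sketch of the two calling styles. So that it runs without snntorch installed, it re-creates CustomSurrogate and custom_surrogate as small stand-ins that follow the behaviour described in this PR (assumption: a Heaviside step in the forward pass, the user function in the backward pass); with snntorch you would import them from snntorch.surrogate instead. The surrogate function straight_through is a hypothetical example.

```python
import torch


class CustomSurrogate(torch.autograd.Function):
    """Stand-in (assumption) mirroring the interface described in the PR."""

    @staticmethod
    def forward(ctx, input_, custom_surrogate_function):
        out = (input_ > 0).float()  # Heaviside step: spike where input_ > 0
        ctx.save_for_backward(input_, out)
        ctx.custom_surrogate_function = custom_surrogate_function
        return out

    @staticmethod
    def backward(ctx, grad_output):
        input_, out = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad = ctx.custom_surrogate_function(input_, grad_input, out)
        # No gradient flows to the function argument itself.
        return grad, None


def custom_surrogate(custom_surrogate_function):
    """Wrap a user function so it can be called like a spike operator."""
    def inner(data):
        return CustomSurrogate.apply(data, custom_surrogate_function)
    return inner


def straight_through(input_, grad_input, out):
    # Hypothetical surrogate: pass the incoming gradient through unchanged.
    return grad_input


x = torch.tensor([-1.0, 0.5, 2.0], requires_grad=True)

# Style 1: wrap the function once, then call the wrapper on the data.
spike_fn = custom_surrogate(straight_through)
spikes = spike_fn(x)
spikes.sum().backward()
grad_wrapped = x.grad.clone()

# Style 2: call the autograd Function directly with the data and the function.
x.grad = None
spikes_direct = CustomSurrogate.apply(x, straight_through)
spikes_direct.sum().backward()
grad_direct = x.grad.clone()
```

Both styles produce the same spikes and the same gradients; the wrapper simply saves you from passing the function on every call.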

The custom surrogate gradient function always receives three arguments, in this order: the input of the forward pass (input_), the incoming gradient (grad_input), and the output of the forward pass (out).
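To make the argument order concrete, here is a sketch of a surrogate function that follows the (input_, grad_input, out) convention, called by hand the way the backward pass would call it. The name sigmoid_surrogate, the sigmoid-derivative shape and the slope value are illustrative assumptions, not part of the PR.

```python
import torch


def sigmoid_surrogate(input_, grad_input, out):
    """Hypothetical surrogate; arguments arrive as (input_, grad_input, out)."""
    slope = 5.0  # hyperparameter defined inside the function
    sig = torch.sigmoid(slope * input_)
    # Scale the incoming gradient by d/dx sigmoid(slope * x).
    # `out` (the spikes) is available but not needed for this surrogate.
    return grad_input * slope * sig * (1 - sig)


input_ = torch.tensor([-1.0, 0.0, 1.0])  # input seen in the forward pass
out = (input_ > 0).float()               # spikes produced by the forward pass
grad_input = torch.ones_like(input_)     # incoming gradient from downstream

grad = sigmoid_surrogate(input_, grad_input, out)
```

The gradient peaks where input_ is zero (the spiking threshold) and decays symmetrically on either side.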

Important note: the hyperparameters of the custom surrogate gradient function have to be defined inside the function itself.
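Because the function is only ever called with those three tensors, there is no slot for extra arguments such as a slope. The sketch below shows the hard-coded approach the note describes, plus one possible workaround (an assumption, not something this PR provides): a hypothetical factory, make_fast_sigmoid, that captures the slope in a closure so the returned function still has the required three-argument signature. The fast-sigmoid-style formula is used only for illustration.

```python
import torch


def fast_sigmoid(input_, grad_input, out):
    # Hyperparameter hard-coded inside the function, as the note requires.
    slope = 25.0
    return grad_input / (slope * torch.abs(input_) + 1.0) ** 2


def make_fast_sigmoid(slope):
    """Hypothetical factory: captures `slope` in a closure and returns a
    function with the required (input_, grad_input, out) signature."""
    def surrogate_fn(input_, grad_input, out):
        return grad_input / (slope * torch.abs(input_) + 1.0) ** 2
    return surrogate_fn


input_ = torch.tensor([0.0, 0.1])
grad_input = torch.ones(2)
out = (input_ > 0).float()

steep = make_fast_sigmoid(25.0)   # behaves like fast_sigmoid above
shallow = make_fast_sigmoid(5.0)  # wider, flatter gradient around zero
```

Either style can be handed to custom_surrogate, since both accept exactly (input_, grad_input, out); the closure just keeps the hyperparameter out of the function body.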

jeshraghian commented 1 year ago

This is really elegant, thanks for this contribution!

I've just pushed a commit that fixes some minor docstrings, and I also added a section on custom surrogate gradients to the documentation.