Hi, is there a guide on writing custom gradient functions in BrainPy (such as those defined in `brainpy.math.surrogate`)?
Thanks for the question. I will add examples and tutorials for you.
Currently, a custom surrogate gradient function can be implemented as follows. For example, if we want to replace the Heaviside function $\mathcal{H}(v)$ with a sigmoid surrogate, we can write:
```python
import brainpy.math as bm
from jax.scipy.special import expit


class MySurrogate(bm.Surrogate):
    def __init__(self, alpha=1.):
        super().__init__()
        self.alpha = alpha

    def surrogate_fun(self, x):
        # The smooth forward surrogate: g(x) = sigmoid(alpha * x).
        return expit(self.alpha * x)

    def surrogate_grad(self, x):
        # Its analytical derivative:
        # g'(x) = alpha * sigmoid(alpha * x) * (1 - sigmoid(alpha * x)).
        sgax = expit(self.alpha * x)
        return (1. - sgax) * sgax * self.alpha
```
Here, `surrogate_fun` defines the forward surrogate function $g(x) = \mathrm{sigmoid}(\alpha x) = \frac{1}{1+e^{-\alpha x}}$, and `surrogate_grad` defines its gradient $g'(x) = \alpha \, (1 - \mathrm{sigmoid}(\alpha x)) \, \mathrm{sigmoid}(\alpha x)$.
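As a usage sketch (assuming the base class makes the instance callable in the same way as BrainPy's built-in surrogates; `v_th` and the loss are hypothetical), the custom surrogate can then be differentiated through a spike nonlinearity:

```python
import jax
import jax.numpy as jnp

surrogate = MySurrogate(alpha=4.)

def loss(v):
    v_th = 1.0                    # hypothetical firing threshold
    spikes = surrogate(v - v_th)  # Heaviside forward, sigmoid gradient backward
    return spikes.sum()

v = jnp.linspace(0., 2., 10)
print(jax.grad(loss)(v))  # nonzero gradients flow through the spike nonlinearity
```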
Hope this helps.
It is important to note that the surrogate gradient functions defined in BrainPy support not only reverse-mode autograd, such as `jax.grad`, `brainpy.math.grad`, and `jax.vjp`, but also forward-mode autograd, including `jax.jvp`, `brainpy.math.fwdjac`, and others. Compared to other libraries, this feature is unique to BrainPy.
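A hedged sketch of what this means in practice, reusing the `MySurrogate` instance from above (the input values are arbitrary); for an element-wise spike function, both modes should recover the same surrogate derivative:

```python
import jax
import jax.numpy as jnp

surrogate = MySurrogate(alpha=4.)
v = jnp.array([-0.5, 0.1, 0.8])

# Reverse mode: pull a unit cotangent back through the spike function.
out, vjp_fun = jax.vjp(surrogate, v)
(rev_grad,) = vjp_fun(jnp.ones_like(out))

# Forward mode: push a unit tangent forward through the spike function.
out2, fwd_grad = jax.jvp(surrogate, (v,), (jnp.ones_like(v),))
```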
Thank you very much for the information! What about using the `@jax.custom_gradient` decorator? I have found some BrainPy examples using this method.
Great question. `@jax.custom_gradient` only supports reverse-mode autograd. For the online learning methods we are going to release, forward-mode autograd is important for supporting the surrogate gradient of spikes.
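For illustration, here is a minimal sketch of that pattern (the function name and `alpha` are illustrative, not BrainPy's actual definitions): a Heaviside spike whose backward pass uses the sigmoid surrogate derivative. Because `@jax.custom_gradient` defines only a VJP rule, reverse mode works but forward mode raises an error:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import expit

alpha = 4.0  # hypothetical sharpness parameter

@jax.custom_gradient
def heaviside_sigmoid(x):
    out = jnp.asarray(x >= 0., dtype=x.dtype)  # exact Heaviside forward
    def grad(dy):
        # Replace the true gradient (zero almost everywhere) with
        # the sigmoid surrogate derivative.
        sg = expit(alpha * x)
        return (dy * alpha * sg * (1. - sg),)
    return out, grad

v = jnp.array([-0.5, 0.1, 0.8])
print(jax.grad(lambda u: heaviside_sigmoid(u).sum())(v))  # reverse mode: works
# jax.jvp(heaviside_sigmoid, (v,), (v,))  # forward mode: raises, no JVP rule
```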
This is interesting. Is forward-mode autograd actually used in a normal neural network training routine?
Actually no, or only in very few cases. However, for biological plausibility and scalability, forward gradients are much better.
For a matmul $y = x W$, forward mode propagates a tangent as $\dot{y} = \dot{x} W$ and does not require the transpose of $W$, whereas reverse mode propagates a cotangent as $\bar{x} = \bar{y} W^\top$ and does. Therefore, for distributed computing and biological plausibility (the brain does not implement the weight transpose), forward gradients are what we need.
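To see the asymmetry concretely (shapes and values here are arbitrary), compare the two modes on a plain matmul:

```python
import jax
import jax.numpy as jnp

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(k1, (3,))
W = jax.random.normal(k2, (3, 5))

f = lambda x: x @ W  # y = x W

# Forward mode (JVP): the tangent is pushed through as dx @ W -- W is used as-is.
y, y_dot = jax.jvp(f, (x,), (jnp.ones_like(x),))

# Reverse mode (VJP): the cotangent is pulled back as dy @ W.T -- the transpose is required.
y2, vjp_fun = jax.vjp(f, x)
(x_bar,) = vjp_fun(jnp.ones_like(y2))
```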