brainpy / BrainPy

Brain Dynamics Programming in Python
https://brainpy.readthedocs.io/
GNU General Public License v3.0

Questions about writing custom gradient functions in BrainPy #637

Closed CloudyDory closed 4 months ago

CloudyDory commented 4 months ago

Hi, is there a guide on writing custom gradient functions in BrainPy (such as those defined in brainpy.math.surrogate)?

chaoming0625 commented 4 months ago

Thanks for the question. I will add examples and tutorials for you.

chaoming0625 commented 4 months ago

Currently, a custom surrogate gradient function can be implemented in the following way.

For example, if we want to replace the Heaviside function $\mathcal{H}(v)$ with the surrogate function of a sigmoid, we can write:


import brainpy.math as bm
from jax.scipy.special import expit

class MySurrogate(bm.Surrogate):
  def __init__(self, alpha=1.):
    super().__init__()
    self.alpha = alpha

  def surrogate_fun(self, x):
    # forward surrogate function: g(x) = sigmoid(alpha * x)
    return expit(self.alpha * x)

  def surrogate_grad(self, x):
    # gradient of the surrogate: g'(x) = alpha * sigmoid(alpha*x) * (1 - sigmoid(alpha*x))
    sgax = expit(x * self.alpha)
    dx = (1. - sgax) * sgax * self.alpha
    return dx

where surrogate_fun defines the forward surrogate function $g(x) = \mathrm{sigmoid}(\alpha x) = \frac{1}{1+e^{-\alpha x}}$, and surrogate_grad defines the gradient of the surrogate function $g'(x) = \alpha \, (1 - \mathrm{sigmoid}(\alpha x)) \, \mathrm{sigmoid}(\alpha x)$.
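
For completeness, here is a minimal usage sketch (not from the original thread). It assumes that, like BrainPy's built-in surrogates, an instance of a bm.Surrogate subclass is callable: it returns the Heaviside step in the forward pass and routes gradients through surrogate_grad.

surrogate = MySurrogate(alpha=2.)

def spike_count(v):
  # Heaviside in the forward pass; sigmoid surrogate used for the gradient
  return surrogate(v).sum()

v = bm.random.randn(100)          # e.g. membrane potentials minus threshold
grads = bm.grad(spike_count)(v)   # reverse-mode gradient through the spikes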

chaoming0625 commented 4 months ago

Hope this message can help you.

chaoming0625 commented 4 months ago

It is important to note that the surrogate gradient functions defined in BrainPy support not only reverse-mode autograd, such as jax.grad, brainpy.math.grad, and jax.vjp, but also forward-mode autograd, including jax.jvp, brainpy.math.fwdjac, and others. Compared to other libraries, this feature is unique to BrainPy.
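
As an illustration of the forward-mode support, here is a sketch (not from the thread; it assumes the surrogate instance also accepts plain JAX arrays):

import jax
import jax.numpy as jnp

surrogate = MySurrogate(alpha=2.)
v = jnp.linspace(-1., 1., 5)
tangent = jnp.ones_like(v)

# forward-mode autodiff: pushes a tangent through the spike function in one pass
spikes, dspikes = jax.jvp(lambda u: surrogate(u), (v,), (tangent,))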

CloudyDory commented 4 months ago

Thank you very much for the information! What about using the @jax.custom_gradient decorator? I have found some BrainPy examples that use this method.

chaoming0625 commented 4 months ago

Great question. @jax.custom_gradient only supports reverse-mode autograd. For the online learning methods we are going to release, forward-mode autograd is important for supporting the surrogate gradient of spikes.
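
For comparison, here is a rough sketch (not from the thread) of the @jax.custom_gradient style the question refers to. It works under jax.grad and jax.vjp, but jax.jvp fails on it because only a VJP rule is defined:

import jax
import jax.numpy as jnp
from jax.scipy.special import expit

alpha = 1.0  # smoothing factor of the sigmoid surrogate

@jax.custom_gradient
def spike(x):
  out = (x >= 0.).astype(x.dtype)   # Heaviside forward pass
  def grad(dy):                     # custom VJP: sigmoid surrogate gradient
    sg = expit(alpha * x)
    return dy * alpha * sg * (1. - sg)
  return out, grad

v = jnp.linspace(-1., 1., 5)
jax.grad(lambda u: spike(u).sum())(v)          # reverse mode: works
# jax.jvp(spike, (v,), (jnp.ones_like(v),))    # forward mode: raises an error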

CloudyDory commented 4 months ago

It is important to note that the surrogate gradient functions defined in BrainPy support not only reverse-mode autograd, such as jax.grad, brainpy.math.grad, and jax.vjp, but also forward-mode autograd, including jax.jvp, brainpy.math.fwdjac, and others. Compared to other libraries, this feature is unique to BrainPy.

This is interesting. Is forward-mode autograd actually used during a normal neural network training routine?

chaoming0625 commented 4 months ago

Not in most cases, or only in very few examples. However, for biological plausibility and scalability, forward-mode gradients are much better.

For a matmul $xW = y$, forward-mode autograd does not require the transpose of $W$ when propagating derivatives of $y$ with respect to $x$, whereas reverse-mode autograd does. Therefore, for distributed computing and biological plausibility (the brain does not implement weight transposition), forward gradients are what we need.
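
A small JAX sketch (added here for illustration) of that point: the JVP of $y = xW$ uses $W$ directly, while the VJP needs $W^\top$.

import jax
import jax.numpy as jnp

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(k1, (3,))
W = jax.random.normal(k2, (3, 4))
f = lambda x: x @ W

# forward mode: the tangent is pushed through W itself, dy = dx @ W
dx = jnp.ones(3)
_, dy = jax.jvp(f, (x,), (dx,))
print(jnp.allclose(dy, dx @ W))              # True

# reverse mode: the cotangent is pulled back through W.T, dx = g @ W.T
_, vjp_fn = jax.vjp(f, x)
g = jnp.ones(4)
print(jnp.allclose(vjp_fn(g)[0], g @ W.T))   # True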