fangwei123456 / spikingjelly

SpikingJelly is an open-source deep learning framework for Spiking Neural Network (SNN) based on PyTorch.
https://spikingjelly.readthedocs.io

How to use our own/custom surrogate gradient function instead of an existing surrogate gradient function #268

Closed gwgknudayanga closed 2 years ago

gwgknudayanga commented 2 years ago

Hi,

  1. How can we define our own surrogate gradient function and use it in the backpropagation algorithm when training spiking neural networks?
  2. In addition to learning the weights of the spiking neural network, if I want to learn some of the parameters of that surrogate gradient function with respect to the loss function, how can that be done?

Thank you.

Thanks and Rgds, Kashita

fangwei123456 commented 2 years ago

Hi, surrogate_function can be any callable function or module. Here is an example:

import torch
import torch.nn as nn
from spikingjelly.activation_based import neuron, surrogate

class ste(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, alpha):
        if x.requires_grad:
            ctx.save_for_backward(x)
            ctx.alpha = alpha
        return surrogate.heaviside(x)

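    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        # Pass the gradient through only where |x| <= alpha (rectangular window).
        mask = (x.abs() <= ctx.alpha).to(x)
        # alpha is a plain float here, so no gradient is returned for it.
        return grad_output * mask, None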

class STE(nn.Module):
    def __init__(self, alpha: float = 1.):
        super(STE, self).__init__()
        self.alpha = alpha

    def forward(self, x):
        return ste.apply(x, self.alpha)

net = neuron.IFNode(surrogate_function=STE())

x = torch.rand([4, 8], requires_grad=True)

net(x).sum().backward()
print(x.grad)
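
To see what this backward pass does in isolation, here is a minimal sketch that applies the same ste function to a hand-picked input without any neuron around it. It assumes only PyTorch; the comparison (x >= 0) is a local stand-in for surrogate.heaviside, and the requires_grad check is dropped for brevity.

```python
import torch


class ste(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.save_for_backward(x)
        ctx.alpha = alpha
        return (x >= 0).to(x)  # local stand-in for surrogate.heaviside

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        mask = (x.abs() <= ctx.alpha).to(x)
        return grad_output * mask, None


x = torch.tensor([-2.0, -0.5, 0.5, 2.0], requires_grad=True)
y = ste.apply(x, 1.0)
y.sum().backward()
print(y)       # spikes: 0 for negative inputs, 1 otherwise
print(x.grad)  # 1 inside the window |x| <= alpha, 0 outside
```

The forward output is a hard step, while the gradient is nonzero only for the two entries with |x| <= 1.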

fangwei123456 commented 2 years ago

if i want to learn some of the parameters of that surrogate gradient function with respect to the loss function

Then you need to define how the gradient of the parameter is calculated and register the parameter as an nn.Parameter. Here is an example:

import torch
import torch.nn as nn
from spikingjelly.activation_based import neuron, surrogate

class ste(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, alpha):
        if x.requires_grad:
            ctx.save_for_backward(x)
            ctx.alpha = alpha
        return surrogate.heaviside(x)

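    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        mask = (x.abs() <= ctx.alpha).to(x)
        # The second return value is the gradient for alpha. x.abs().mean() is
        # only an illustrative value; put your own formula here.
        return grad_output * mask, x.abs().mean()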

class STE(nn.Module):
    def __init__(self, alpha: float = 1.):
        super(STE, self).__init__()
        self.alpha = nn.Parameter(torch.as_tensor(alpha))

    def forward(self, x):
        return ste.apply(x, self.alpha)

net = neuron.IFNode(surrogate_function=STE())

x = torch.rand([4, 8], requires_grad=True)

net(x).sum().backward()
print(x.grad)

print(net.surrogate_function.alpha.grad)
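
In the example above, the second value returned by backward is whatever you decide the gradient of alpha should be. As one possible choice (an assumption for illustration, not something SpikingJelly prescribes), the sketch below treats the spike as if it were sigmoid(alpha * x) in the backward pass, so both gradients follow from the chain rule. The names sigmoid_spike and SigmoidSpike are hypothetical, and (x >= 0) is a local stand-in for surrogate.heaviside so that the snippet needs only PyTorch.

```python
import torch
import torch.nn as nn


class sigmoid_spike(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, alpha):
        # Save both tensors so that backward can use them.
        ctx.save_for_backward(x, alpha)
        # The forward pass is still a hard step.
        return (x >= 0).to(x)

    @staticmethod
    def backward(ctx, grad_output):
        x, alpha = ctx.saved_tensors
        s = torch.sigmoid(alpha * x)
        ds = s * (1. - s)  # derivative of the sigmoid, evaluated at alpha * x
        grad_x = grad_output * alpha * ds          # d sigmoid(alpha * x) / d x
        grad_alpha = (grad_output * x * ds).sum()  # d sigmoid(alpha * x) / d alpha
        return grad_x, grad_alpha


class SigmoidSpike(nn.Module):
    def __init__(self, alpha: float = 4.):
        super().__init__()
        self.alpha = nn.Parameter(torch.as_tensor(alpha))

    def forward(self, x):
        return sigmoid_spike.apply(x, self.alpha)


sg = SigmoidSpike()
x = torch.randn([4, 8], requires_grad=True)
sg(x).sum().backward()
print(x.grad.shape, sg.alpha.grad)
```

Because alpha is a scalar parameter shared by all elements, its gradient is summed over the whole input. Such a module can be passed to a neuron the same way as STE above, for example neuron.IFNode(surrogate_function=SigmoidSpike()).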

gwgknudayanga commented 2 years ago

Thanks a lot @fangwei123456 !

gwgknudayanga commented 2 years ago

Hi,

When I move the input tensor 'x' and the network 'net' to the GPU with cuda(), the gradient of x is always None. How can I run this code on the GPU correctly? Thanks,

import torch
import torch.nn as nn
from spikingjelly.activation_based import neuron, surrogate

class ste(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, alpha):
        if x.requires_grad:
            ctx.save_for_backward(x)
            ctx.alpha = alpha
        return surrogate.heaviside(x)

    @staticmethod
    def backward(ctx, grad_output):
        print(grad_output)
        x = ctx.saved_tensors[0]
        print(x)
        mask = (x.abs() <= ctx.alpha).to(x)
        return grad_output * mask, x.abs().mean()

class STE(nn.Module):
    def __init__(self, alpha: float = 1.):
        super(STE, self).__init__()
        self.alpha = nn.Parameter(torch.as_tensor(alpha))

    def forward(self, x):
        return ste.apply(x, self.alpha)

net = neuron.IFNode(surrogate_function=STE())
net = net.cuda()

x = torch.rand([4, 8], requires_grad=True)
x = x.cuda()
net(x).sum().backward()
print(x.grad)

fangwei123456 commented 2 years ago

Duplicated in https://github.com/fangwei123456/spikingjelly/issues/280
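
A likely cause, independent of the custom surrogate function: x.cuda() returns a new tensor that is the result of an operation, so it is no longer a leaf of the autograd graph, and PyTorch only fills in .grad for leaf tensors by default. The gradient is accumulated on the original CPU tensor, which the code above no longer holds a reference to. Below is a minimal sketch of two common workarounds. It uses only PyTorch, a simple multiplication stands in for the network, and the device falls back to the CPU when no GPU is available.

```python
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Option 1: create the tensor directly on the target device, so it is a leaf there.
x = torch.rand([4, 8], device=device, requires_grad=True)
(x * 2.).sum().backward()
print(x.is_leaf, x.grad is not None)

# Option 2: move the tensor first, then ask autograd to keep the gradient
# of the moved (non-leaf) tensor.
x_cpu = torch.rand([4, 8], requires_grad=True)
x_dev = x_cpu.to(device)
x_dev.retain_grad()  # no-op if x_dev is still a leaf, e.g. when device is the CPU
(x_dev * 2.).sum().backward()
print(x_dev.grad is not None)
```

Option 1 is usually the simpler choice for inputs. Option 2 is useful when the tensor has to be created on the CPU and moved afterwards.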

gwgknudayanga commented 2 years ago

Hi,

Is it possible to access the parameters of a neuron from within its surrogate function? For example, in the following code, I want to know the current voltage and the current time constant of the ParametricLIF neuron 'mylif' inside the surrogate function. How can those parameters be queried within its surrogate_function? Thanks!

mylif = ParametricLIFNode(surrogate_function=STE())

class ste(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, alpha):
        if x.requires_grad:
            ctx.save_for_backward(x)
            ctx.alpha = alpha
        return surrogate.heaviside(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        mask = (x.abs() <= ctx.alpha).to(x)
        return grad_output * mask, x.abs().mean()

class STE(nn.Module):
    def __init__(self, alpha: float = 1.):
        super(STE, self).__init__()
        self.alpha = nn.Parameter(torch.as_tensor(alpha))

    def forward(self, x):
        return ste.apply(x, self.alpha)

fangwei123456 commented 2 years ago

Is it possible to access the parameters relevant to the neuron within its surrogate function?

You can return the data and print or save the returned data.
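
One way to read or save neuron state from inside the surrogate function is to give the surrogate module a reference to the neuron that owns it. The sketch below is only an illustration under stated assumptions: NeuronAwareSTE, its attach method, and ToyLIF are hypothetical and not part of SpikingJelly, and the attribute names v and tau belong to the toy neuron, so check the neuron class you actually use for the real names. For brevity the rectangular-window gradient is produced with a detach trick instead of a custom autograd Function.

```python
import weakref
import torch
import torch.nn as nn


class NeuronAwareSTE(nn.Module):
    """Hypothetical surrogate module that can look at the neuron that owns it."""

    def __init__(self, alpha: float = 1.):
        super().__init__()
        self.alpha = alpha
        self._neuron_ref = None
        self.last_seen = None

    def attach(self, neuron_module: nn.Module):
        # A weak reference avoids registering the neuron as a submodule of its
        # own surrogate function, which would create a module cycle.
        self._neuron_ref = weakref.ref(neuron_module)

    def forward(self, x):
        owner = self._neuron_ref() if self._neuron_ref is not None else None
        if owner is not None:
            # Read (or print, or save) whatever state is needed at this point.
            self.last_seen = {'v': owner.v.detach().clone(), 'tau': float(owner.tau)}
        spike = (x >= 0).to(x)                          # hard step in the forward pass
        window = torch.clamp(x, -self.alpha, self.alpha)
        # The value equals spike; the gradient is that of the clamp (1 where |x| <= alpha).
        return spike + (window - window.detach())


class ToyLIF(nn.Module):
    """Toy stand-in for a LIF neuron; `v` and `tau` are illustrative names."""

    def __init__(self, surrogate_function: nn.Module, tau: float = 2., v_threshold: float = 1.):
        super().__init__()
        self.tau = tau
        self.v_threshold = v_threshold
        self.v = 0.
        self.surrogate_function = surrogate_function
        if hasattr(surrogate_function, 'attach'):
            surrogate_function.attach(self)

    def forward(self, x):
        self.v = self.v + (x - self.v) / self.tau    # leaky charge
        spike = self.surrogate_function(self.v - self.v_threshold)
        self.v = self.v * (1. - spike.detach())      # hard reset to 0 after a spike
        return spike


sg = NeuronAwareSTE()
lif = ToyLIF(surrogate_function=sg)
x = torch.rand([4, 8], requires_grad=True)
lif(x).sum().backward()
print(sg.last_seen['tau'], sg.last_seen['v'].shape)
```

With a real neuron the same idea would be to call sg.attach(mylif) after constructing mylif, and then read the attributes that neuron class actually exposes. The saved values in last_seen can be printed or stored after each forward call.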