fangwei123456 / spikingjelly

SpikingJelly is an open-source deep learning framework for Spiking Neural Network (SNN) based on PyTorch.
https://spikingjelly.readthedocs.io

Error when computing a neural network's Jacobian matrix after replacing nn.ReLU() with the IF neuron from the spikingjelly framework #488

Open: YMX-zknu opened this issue 6 months ago

YMX-zknu commented 6 months ago

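I built a test network (code below) and replaced its nn.ReLU() with the IF neuron from the spikingjelly framework. When I then computed the network's Jacobian matrix with torch.func.jacrev, I got the following error: RuntimeError: In order to use an autograd.Function with functorch transforms (vmap, grad, jvp, jacrev, ...), it must override the setup_context staticmethod. With nn.ReLU() there is no error. How can I solve this problem?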

import torch
import torch.nn as nn
from spikingjelly.activation_based import layer,neuron,surrogate
class net(nn.Module):

    def __init__(self):
        super(net, self).__init__()

        self.conv1 = nn.Sequential(         
            layer.Conv2d(1, 16, 5, 1, 2),       
            # nn.ReLU(), 
            neuron.IFNode(surrogate_function=surrogate.Sigmoid(),step_mode='s'),                    
            layer.MaxPool2d(kernel_size=2)) # try nn.AvgPool2d instead

        self.FC = nn.Linear(16 * 14 * 14, 10)

    def forward(self, x):

        x = self.conv1(x)
        x = x.flatten(start_dim=1)     
        output = self.FC(x)

        return output

# initialize model
model = net()
def model_single_input(x):      # forward pass on single input (image)
    return model(x.unsqueeze(dim=0))

# generate batch of input data
batch_size = 100
scale = 0.001    # smaller scale results in larger discrepancy between two methods
X = scale * torch.rand(size=(batch_size, 1, 28, 28,)) # data in MNIST format

# compute jacrev with vmap (raises the RuntimeError above with IFNode; runs with nn.ReLU())
jacrev_vmap = torch.vmap(torch.func.jacrev(model_single_input))(X)
print(jacrev_vmap.shape)
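The error message describes a PyTorch requirement (PyTorch 2.0 and later): a custom torch.autograd.Function used under torch.func transforms must define forward() without a ctx argument, save what backward() needs in a separate setup_context() staticmethod, and, for vmap, either set generate_vmap_rule = True or define its own vmap rule. Presumably the surrogate-gradient Function used inside IFNode is written in the older ctx-in-forward style. The sketch below shows a Heaviside spike with a sigmoid surrogate gradient written in the transform-compatible style. It is not spikingjelly's code: SigmoidHeaviside, heaviside_sigmoid, and ALPHA = 4.0 are illustrative assumptions.

```python
import torch

ALPHA = 4.0  # assumed surrogate steepness (illustrative value)


class SigmoidHeaviside(torch.autograd.Function):
    """Hypothetical Function: Heaviside step forward, sigmoid-derivative backward.

    Uses the torch.func-compatible layout: forward() takes no ctx,
    setup_context() saves tensors for backward(), and generate_vmap_rule
    asks PyTorch to derive the vmap rule automatically.
    """

    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        # 1.0 where x >= 0, else 0.0, in x's dtype
        return (x >= 0).to(x.dtype)

    @staticmethod
    def setup_context(ctx, inputs, output):
        (x,) = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        sg = torch.sigmoid(ALPHA * x)
        # d/dx sigmoid(ALPHA * x) = ALPHA * sg * (1 - sg)
        return grad_output * ALPHA * sg * (1.0 - sg)


def heaviside_sigmoid(x):
    return SigmoidHeaviside.apply(x)


x = torch.randn(5, 3)
# Per-sample Jacobian (3 x 3) of the elementwise spike function, vmapped over 5 samples
jac = torch.vmap(torch.func.jacrev(heaviside_sigmoid))(x)
print(jac.shape)  # torch.Size([5, 3, 3])
```

Because the spike is elementwise, each per-sample Jacobian is diagonal, with ALPHA * sg * (1 - sg) on the diagonal.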
fangwei123456 commented 6 months ago

Transforms such as jvp do not seem to know how to handle custom autograd functions:

https://github.com/pytorch/functorch/issues/207
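If transforms such as jvp cannot handle custom autograd functions, one way around it is to avoid autograd.Function altogether. The sketch below assumes a single-step IF neuron that starts from v = 0 on every call, so one step reduces to firing where x >= v_threshold, and builds the sigmoid surrogate gradient from plain tensor ops with a straight-through trick: the forward value is the Heaviside step, while the gradient is that of sigmoid(ALPHA * v). StatelessIF, spike_with_sigmoid_surrogate, V_THRESHOLD = 1.0, and ALPHA = 4.0 are hypothetical names and assumed values, not spikingjelly's API; plain nn.Conv2d and nn.MaxPool2d stand in for spikingjelly's layer wrappers. It relies on torch.func respecting a torch.no_grad() block inside the transformed function, as the torch.func.grad documentation describes.

```python
import torch
import torch.nn as nn

V_THRESHOLD = 1.0  # assumed firing threshold
ALPHA = 4.0        # assumed surrogate steepness


def spike_with_sigmoid_surrogate(v):
    """Heaviside forward, sigmoid-derivative backward, using plain tensor ops only.

    spike = h + (sg - sg_const): h and sg_const are computed under no_grad,
    so the forward value equals h and the gradient equals that of sg.
    """
    sg = torch.sigmoid(ALPHA * v)
    with torch.no_grad():
        h = (v >= 0).to(v.dtype)
        sg_const = torch.sigmoid(ALPHA * v)
    return h + (sg - sg_const)


class StatelessIF(nn.Module):
    """Hypothetical single-step IF neuron starting from v = 0 (no stored state)."""

    def forward(self, x):
        # One step from v = 0: v = x, spike where v - V_THRESHOLD >= 0
        return spike_with_sigmoid_surrogate(x - V_THRESHOLD)


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 16, 5, 1, 2),
            StatelessIF(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.FC = nn.Linear(16 * 14 * 14, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = x.flatten(start_dim=1)
        return self.FC(x)


torch.manual_seed(0)
model = Net()


def model_single_input(x):
    # x: one image of shape (1, 28, 28); add a batch dimension of size 1
    return model(x.unsqueeze(dim=0))


X = 0.001 * torch.rand(size=(4, 1, 28, 28))
jac = torch.vmap(torch.func.jacrev(model_single_input))(X)
print(jac.shape)  # torch.Size([4, 1, 10, 1, 28, 28])
```

Because the module stores no membrane potential, it only matches an IF neuron's first time step from reset; a multi-step network would need the membrane state passed in and returned explicitly to stay a pure function.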