fangwei123456 / spikingjelly

SpikingJelly is an open-source deep learning framework for Spiking Neural Network (SNN) based on PyTorch.
https://spikingjelly.readthedocs.io
Other
1.24k stars 235 forks source link

为每一个神经元训练一个自定义的参数 #501

Open USTCYYX opened 4 months ago

USTCYYX commented 4 months ago

我希望为每一个神经元训练一个自定义的参数a,用来调节输入电压的大小,具体的写法是这样的

class NIFNode(neuron.BaseNode):
    def __init__(self,  v_threshold: float = 1.,
                 v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(),
                 detach_reset: bool = False, step_mode='s', backend='torch', store_v_seq: bool = False):

        super().__init__(v_threshold, v_reset, surrogate_function, detach_reset, step_mode, backend, store_v_seq)              
        init_a = 0.       
        self.a = nn.Parameter(torch.tensor(init_a, dtype=torch.float))

    def neuronal_charge(self, x: torch.Tensor):
        self.v = self.v + x

    def single_step_forward(self, x: torch.Tensor):
        self.v_float_to_tensor(x)
        x=x*torch.sigmoid(self.a)
        self.neuronal_charge(x)     
        spike = self.neuronal_fire()
        self.neuronal_reset(spike)

        return spike 

事实上,这个写法参照了PLIF的源码。 但是,在训练过程中,我用如下代码输出训练的参数

for param in net.state_dict():
    print(param, net.state_dict()[param].size())

发现训练的a没有size

conv_fc.7.a torch.Size([])

我不确定这样训练有没有问题。如果真的训练了,那么应该可以看到a的size,并且这个size和神经元的size一样。 希望您能告诉我我这么训练有没有问题,以及应该怎么样看到a的训练后的参数。

fangwei123456 commented 4 months ago

nn.Parameter必须在初始化时就指定shape,不能像神经元的v那样根据输入的shape去动态生成了。

没有size是正常的,因为shape=[1]和shape=[]是有区别的:

import torch

x = torch.as_tensor(1.)
y = torch.as_tensor([1.])

print(x, x.shape)
print(y, y.shape)

输出是

tensor(1.) torch.Size([])
tensor([1.]) torch.Size([1])
fangwei123456 commented 4 months ago

如果要为每一个神经元指定一个参数,通常要把参数初始化成

self.a = nn.Parameter(a_init)

其中a_init是参数初始值,shape应该与这一层神经元的shape相同。例如卷积层后的神经元,如果卷积层输出shape=[N, C, H, W]则这里的参数shape应该是[C, H, W].

USTCYYX commented 4 months ago

感谢您的回复! 似乎并没有一个很好的方法自适应的为每一个神经元训练一个参数,只能提前指定size。我的想法是这样:

NIFNode(surrogate_function=surrogate.ATan(), a_size=[C,H,W]),
init_a = torch.zeros(a_size)       
self.a = nn.Parameter(init_a)
fangwei123456 commented 4 months ago

自适应是不行的,只能提前制定好