fangwei123456 / spikingjelly

SpikingJelly is an open-source deep learning framework for Spiking Neural Network (SNN) based on PyTorch.
https://spikingjelly.readthedocs.io
Other
1.24k stars 235 forks source link

如何实现抑制和自适应阈值? #466

Open Clutcheee opened 8 months ago

Clutcheee commented 8 months ago

我正尝试在MNIST数据集上仅使用STDP训练单层全连接SNN,最终通过投票来对输出脉冲进行解码和分类。 我打印了out_fr,并调试了很久参数,但训练结束后神经元对不同类别的响应区别都不大,不能完成分类任务。 现在我想复现论文《Unsupervised learning of digit recognition using spike-timing-dependent plasticity》,但我不知道要如何实现抑制和自适应阈值(神经元每次放电后,放电阈值提高)。我看以前的问题回复中有提到:SJ框架中不区分兴奋型和抑制型,具体看突触的符号就行了。 请问要如何限制某些突触的符号始终为正/负?要实现自适应阈值需要自定义神经元吗?

fangwei123456 commented 7 months ago

我在最新的框架中增加了一个简化的基类神经元,便于用户实现不同的新神经元,下面是自适应神经元的一个例子。这个例子中,每次释放脉冲都会使神经元的阈值增加。

import torch
from spikingjelly.activation_based import neuron

class AdaptThresholdIFNode(neuron.SimpleBaseNode):
    def __inti__(self, *args, **kwargs):
        super().__inti__(*args, **kwargs)
        self.v_threshold_base = self.v_threshold
        # 阈值变成了动态变化的隐状态,因此也需要成为“记忆”
        del self.v_threshold
        self.register_memory('v_threshold', 0.)

    def neuronal_charge(self, x: torch.Tensor):
        self.v = self.v + x

    def neuronal_fire(self):
        spike = super().neuronal_fire()
        self.v_threshold = self.v_threshold + spike
        return spike

from matplotlib import pyplot as plt
net = AdaptThresholdIFNode()
T = 128
x = torch.ones([T]) * 0.1
v = []
v_th = []
for t in range(T):
    net(x[t])
    v.append(net.v)
    v_th.append(net.v_threshold)

plt.plot(torch.arange(T), v, label='v')
plt.plot(torch.arange(T), v_th, label='v_th')
plt.legend()
plt.show()
fangwei123456 commented 7 months ago

image

fangwei123456 commented 7 months ago

在SJ框架论文补充材料的fig12有个更复杂的例子 https://arxiv.org/abs/2310.16620

fangwei123456 commented 7 months ago

如何限制某些突触的符号始终为正/负?

例如给linear层限制为正,只需要给权重加个abs

import torch
import torch.nn as nn
from torch import Tensor

class PositiveLinear(nn.Linear):
    def forward(self, input: Tensor) -> Tensor:
        self.weight = torch.abs(self.weight)
        return super().forward(input)
fangwei123456 commented 7 months ago

STDP不收敛是正常的,无监督的学习方法很难训练出结果,能不能调出来看运气。