fangwei123456 / spikingjelly

SpikingJelly is an open-source deep learning framework for Spiking Neural Network (SNN) based on PyTorch.
https://spikingjelly.readthedocs.io
Other
1.25k stars 236 forks source link

如何实现TTFS的神经元 #403

Open FishSeeker opened 1 year ago

FishSeeker commented 1 year ago

我尝试使用latencyencoder来encode input,但是在神经元的选择上,sj里并没有现成的适用于ttfs的neuron类型,我尝试了issue里面的方法来进行实现并没有成功。能否问一下如何实现ttfs的神经元

fangwei123456 commented 1 year ago

具体怎么实现需要看你如何定义这个神经元的动态,然后对框架中现有的神经元做出相应的修改(参考神经元的教程,里面有写如何定义新神经元)

下面这个老版本的讨论也有一些帮助: https://github.com/fangwei123456/spikingjelly/discussions/45

Spice-monkey commented 1 year ago

您好,我是想验证延时编码的数据的训练效果,我可能遇到了和您类似的问题,就是针对这类数据应该有TTFS这种神经元来接受数据,但是不知道如何设计这种神经元

fangwei123456 commented 1 year ago

下面这个只释放一次脉冲的IF神经元应该就可以

@FishSeeker @Spice-monkey

import torch
from spikingjelly.activation_based import neuron

class TTFSIFNode(neuron.BaseNode):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.register_memory('fire_mask', 0.)

    def neuronal_charge(self, x: torch.Tensor):
        self.v = self.v + x

    def neuronal_fire(self):
        self.spike = self.surrogate_function(self.v - self.v_threshold) * (1. - self.fire_mask)
        self.fire_mask = self.spike + self.fire_mask
Spice-monkey commented 1 year ago

好的,十分感谢,我去试试,还有一个疑问,我感觉time to first spike 编码和latency 编码很类似,是我的错觉吗?或者是我理解有问题吗,可以指教吗

fangwei123456 commented 1 year ago

这2原理确实是一样的

Spice-monkey commented 1 year ago

好的,十分感谢!但是我看到有的论文指出,对于越强的刺激,触发时间越快,但是也有的说越强的刺激触发时间越晚,A Survey of Encoding Techniques for Signal Processing in Spiking Neural Networks https://link.springer.com/article/10.1007/s11063-021-10562-2这篇论文提到两种方式并不一样,令我有些疑惑

fangwei123456 commented 1 year ago

SNN里面用延迟编码一般都是输入越大,发放脉冲越早

Spice-monkey commented 1 year ago

嗯嗯,我看到大多数也是这样的,主要是我们想设计一种感知器件,但是测试发现似乎是越大的输入发放脉冲越晚,不知道比如说简单的取个倒数再归一化能不能达到编码的需求

Spice-monkey commented 1 year ago

延时编码,频率编码,还有增量编码,这三种方式的神经元设计是不是相差很大?或者以您的经验说三者最后的表现有哪些优劣之分呢?

fangwei123456 commented 1 year ago

增量编码不了解 ANN2SNN用频率编码,性能是目前最好的。延迟编码的神经元确实是需要单独设计的

Spice-monkey commented 1 year ago

好的!感谢您的回复,我试着用ttfs来实现一下编码到识别 ,有问题再与您交流

Spice-monkey commented 1 year ago

您好,我观察到您使用mask来确保已经释放的神经元不再释放,但是这种神经元是接受延时编码的格式数据吗?我理解的输入数据格式是每个像素的灰度对应不同的触发时间,越亮的触发越早,那么是不是一个时间步所有神经元都触发且仅被触发一次?但是我感觉您的示例代码构建的神经元似乎还是要经过多个时间步的刺激,才能确保所有神经元都触发一次?我感觉这种神经元的输出是不是应该是触发的时间而不是单纯的01数据?或者说二者都记录,但是用于推理的其实是触发时间?还请您指教

Spice-monkey commented 1 year ago

具体怎么实现需要看你如何定义这个神经元的动态,然后对框架中现有的神经元做出相应的修改(参考神经元的教程,里面有写如何定义新神经元)

下面这个老版本的讨论也有一些帮助: #45

我学习了您在该项目下的回答,但是似乎是需要多个时间步来让所有神经元触发?或者说这个时间步的长度是怎么确定的?如何确定在某个时间步总长下所有神经元都会被触发?

fangwei123456 commented 1 year ago

TTFS都是只能释放一次脉冲的。除非你定义了一种新的编码方式,用多个脉冲来表示某个值

但是似乎是需要多个时间步来让所有神经元触发?或者说这个时间步的长度是怎么确定的?如何确定在某个时间步总长下所有神经元都会被触发?

TTFS需要跑完整个T步。不能保证所有神经元都能触发,会出现dead neuron的问题。有些论文会做一些修改,比如强制不释放的神经元在最后一个时刻释放脉冲

FishSeeker commented 11 months ago

@fangwei123456 你好,我尝试使用了你在这个问题底下提供的ttfs的代码,但是它提示有问题,我不太确定是因为什么

Traceback (most recent call last):
  File "ttfs.py", line 284, in <module>
    main()
  File "ttfs.py", line 179, in main
    out_fr += net(encoded_img)
  File "/home/miao/anaconda3/envs/jelly/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "ttfs.py", line 48, in forward
    return self.layer(x)
  File "/home/miao/anaconda3/envs/jelly/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/miao/anaconda3/envs/jelly/lib/python3.8/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/home/miao/anaconda3/envs/jelly/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/miao/anaconda3/envs/jelly/lib/python3.8/site-packages/spikingjelly/activation_based/base.py", line 268, in forward
    return self.single_step_forward(*args, **kwargs)
  File "/home/miao/anaconda3/envs/jelly/lib/python3.8/site-packages/spikingjelly/activation_based/neuron.py", line 243, in single_step_forward
    self.neuronal_reset(spike)
  File "/home/miao/anaconda3/envs/jelly/lib/python3.8/site-packages/spikingjelly/activation_based/neuron.py", line 207, in neuronal_reset
    self.v = self.jit_hard_reset(self.v, spike_d, self.v_reset)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "<string>", line 11, in jit_hard_reset
  return b == a
def sub(a : float, b : Tensor) -> Tensor:
  return torch.neg(b) + a
         ~~~~~~~~~ <--- HERE
def div(a : float, b : Tensor) -> Tensor:
  return torch.reciprocal(b) * a
RuntimeError: Expected a proper Tensor but got None (or an undefined Tensor in C++) for argument #0 'self'
fangwei123456 commented 11 months ago

前面的代码忘了返回spike,修复一下:

import torch
from spikingjelly.activation_based import neuron, functional, surrogate

class TTFSIFNode(neuron.BaseNode):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.register_memory('fire_mask', 0.)

    def neuronal_charge(self, x: torch.Tensor):
        self.v = self.v + x

    def neuronal_fire(self):
        self.spike = self.surrogate_function(self.v - self.v_threshold) * (1. - self.fire_mask)
        self.fire_mask = self.spike + self.fire_mask
        return self.spike

net = TTFSIFNode(v_threshold = 1., v_reset = 0., surrogate_function = surrogate.Sigmoid())

x_seq = torch.rand([16, 8])

functional.multi_step_forward(x_seq, net)
FishSeeker commented 11 months ago

前面的代码忘了返回spike,修复一下:

import torch
from spikingjelly.activation_based import neuron, functional, surrogate

class TTFSIFNode(neuron.BaseNode):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.register_memory('fire_mask', 0.)

    def neuronal_charge(self, x: torch.Tensor):
        self.v = self.v + x

    def neuronal_fire(self):
        self.spike = self.surrogate_function(self.v - self.v_threshold) * (1. - self.fire_mask)
        self.fire_mask = self.spike + self.fire_mask
        return self.spike

net = TTFSIFNode(v_threshold = 1., v_reset = 0., surrogate_function = surrogate.Sigmoid())

x_seq = torch.rand([16, 8])

functional.multi_step_forward(x_seq, net)

非常感谢,可以work了!