ali-rasteh / spikingjelly-develop


ReadoutNode with v_threshold=inf #1

Open fangwei123456 opened 3 years ago

fangwei123456 commented 3 years ago

I think we can skip `neuronal_fire` when `v_threshold=inf`. A finite membrane potential can never reach an infinite threshold, so the neuron never spikes:

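```python
from spikingjelly.clock_driven import neuron, surrogate
import torch
import math


class ReadoutNode(neuron.BaseNode):
    # Leaky integrator used as a readout layer: it integrates its input like a
    # LIF neuron but, with the default v_threshold=inf, it never fires.
    def __init__(self, tau=100.0, v_threshold=math.inf, v_reset=0.0,
                 surrogate_function=surrogate.Sigmoid(), detach_reset=False,
                 monitor_state=False):
        super().__init__(v_threshold, v_reset, surrogate_function, detach_reset, monitor_state)
        self.tau = tau

    def extra_repr(self):
        return f'v_threshold={self.v_threshold}, v_reset={self.v_reset}, tau={self.tau}'

    def neuronal_charge(self, dv: torch.Tensor):
        # Leaky integration: v <- v + (dv - (v - v_reset)) / tau
        if self.v_reset is None:
            self.v += (dv - self.v) / self.tau
        else:
            self.v += (dv - (self.v - self.v_reset)) / self.tau

    def forward(self, dv: torch.Tensor):
        self.neuronal_charge(dv)
        # With an infinite threshold there are no spikes to compute and nothing
        # to reset. neuronal_reset() uses the spikes produced by neuronal_fire(),
        # so it is skipped together with it.
        if not math.isinf(self.v_threshold):
            self.neuronal_fire()
            self.neuronal_reset()
        # The readout is the membrane potential, not spikes
        return self.v
```
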
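A quick pure-Python sketch of why skipping is safe. The helpers below are hypothetical stand-ins, using plain floats instead of tensors, that mirror the Heaviside forward pass of a surrogate spiking function (assuming the `x >= 0` convention) and the hard/soft reset rules as we understand them in SpikingJelly.

```python
import math

V_THRESHOLD = math.inf  # the ReadoutNode default


def heaviside(x: float) -> float:
    # Hypothetical stand-in for the surrogate's forward pass: 1.0 if x >= 0 else 0.0
    return 1.0 if x >= 0.0 else 0.0


def hard_reset(v: float, spike: float, v_reset: float = 0.0) -> float:
    # Hard reset: a spiking neuron jumps to v_reset, otherwise v is kept
    return (1.0 - spike) * v + spike * v_reset


def soft_reset(v: float, spike: float, v_threshold: float) -> float:
    # Soft reset (v_reset=None): subtract the threshold on a spike
    return v - spike * v_threshold


for v in (-5.0, 0.0, 3.2, 1e30):
    spike = heaviside(v - V_THRESHOLD)  # v - inf == -inf, so never a spike
    assert spike == 0.0
    assert hard_reset(v, spike) == v    # hard reset is a no-op here

# Soft reset evaluates 0.0 * inf, which is NaN in IEEE 754 float arithmetic
print(soft_reset(3.2, 0.0, V_THRESHOLD))
```

With a hard reset (`v_reset=0.0`) a zero spike leaves `v` untouched, so calling the reset is merely wasted work. With a soft reset (`v_reset=None`) the product `0 * inf` is NaN and would poison `v`, which is a good reason to skip the reset step along with firing whenever the threshold is infinite.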
fangwei123456 commented 3 years ago

There is also a similar layer in SpikingJelly, `LowPassSynapse` in `spikingjelly.clock_driven.layer`: https://spikingjelly.readthedocs.io/zh_CN/latest/spikingjelly.clock_driven.layer.html#lowpasssynapse-init-en
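For comparison, here is a plain-float sketch showing that, with `v_reset=0.0` and firing skipped, the ReadoutNode charge step `v <- v + (x - v) / tau` is the same recurrence as a first-order low-pass filter, i.e. an exponential moving average with weight `1 / tau`. The function names are hypothetical, and this does not reproduce `LowPassSynapse`'s exact update rule, which is not checked here.

```python
from typing import List


def readout_charge(xs: List[float], tau: float = 100.0, v_reset: float = 0.0) -> List[float]:
    # Hypothetical re-implementation of ReadoutNode.neuronal_charge on floats,
    # with neuronal_fire/neuronal_reset skipped (v_threshold=inf)
    v = v_reset
    out = []
    for x in xs:
        v += (x - (v - v_reset)) / tau
        out.append(v)
    return out


def ema(xs: List[float], tau: float = 100.0) -> List[float]:
    # First-order low-pass filter: y <- (1 - 1/tau) * y + (1/tau) * x
    alpha = 1.0 / tau
    y = 0.0
    out = []
    for x in xs:
        y = (1.0 - alpha) * y + alpha * x
        out.append(y)
    return out


# Step input: on for 5 steps, then off for 5 steps
xs = [1.0] * 5 + [0.0] * 5
a = readout_charge(xs, tau=4.0)
b = ema(xs, tau=4.0)
# Both recurrences agree up to floating-point rounding
assert all(abs(p - q) < 1e-12 for p, q in zip(a, b))
```

The output rises toward the input while it is on and decays once it is off, which is the smoothing behaviour one would expect from a readout layer over time steps.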

ali-rasteh commented 3 years ago

@fangwei123456 Correct! I hadn't noticed that layer in SpikingJelly. Thanks a lot, Wei!