fangwei123456 / spikingjelly

SpikingJelly is an open-source deep learning framework for Spiking Neural Networks (SNNs) based on PyTorch.
https://spikingjelly.readthedocs.io

Custom LIF function is never entered during inference #519

Open grance199981 opened 7 months ago

grance199981 commented 7 months ago

import math
from typing import Callable

import torch
from spikingjelly.activation_based import neuron, surrogate


class SpikeTrace_LIF_Neuron(neuron.LIFNode):
    def __init__(self, tau: float = math.exp(1.0) / (math.exp(1.0) - 1), decay_input: bool = False,
                 v_threshold: float = 1., v_reset: float = None,
                 surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False,
                 step_mode='s', backend='torch', store_v_seq: bool = False):
        # Default tau = e / (e - 1), so the per-step decay factor 1 - 1/tau equals 1/e.
        # The surrogate function must not be None: neuronal_fire() calls it on every step.
        super().__init__(tau, decay_input, v_threshold, v_reset, surrogate_function,
                         detach_reset, step_mode, backend, store_v_seq)

    def neuronal_charge(self, x: torch.Tensor):
        decay = 1. - 1. / self.tau
        if self.v_reset is None or self.v_reset == 0:
            # Soft reset (or reset to 0): v <- v * (1 - 1/tau) + x
            if isinstance(self.v, float):
                self.v = x
            else:
                # detach(): no gradient flows back through the previous membrane potential
                self.v = self.v.detach() * decay + x
        else:
            # Hard reset to v_reset: v <- v * (1 - 1/tau) + v_reset / tau + x
            if isinstance(self.v, float):
                self.v = self.v_reset * decay + self.v_reset / self.tau + x
            else:
                self.v = self.v.detach() * decay + self.v_reset / self.tau + x
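With the default `tau = e / (e - 1)`, the decay factor `1 - 1/tau` simplifies to `1/e`, so each step keeps roughly 36.8% of the previous membrane potential. The plain-Python sketch below restates the two branches of `neuronal_charge` as a standalone function so the arithmetic can be checked without PyTorch. The name `charge_step` is hypothetical and is not part of SpikingJelly.

```python
import math
from typing import Optional

# Default tau from the issue: e / (e - 1), chosen so that 1 - 1/tau == 1/e.
DEFAULT_TAU = math.exp(1.0) / (math.exp(1.0) - 1)


def charge_step(v: float, x: float, tau: float = DEFAULT_TAU,
                v_reset: Optional[float] = None) -> float:
    """One membrane-charge update mirroring the issue's neuronal_charge (hypothetical helper)."""
    decay = 1.0 - 1.0 / tau
    if v_reset is None or v_reset == 0:
        # Soft reset (or reset to 0): leak toward 0, then add the input.
        return v * decay + x
    # Hard reset to a nonzero v_reset: leak toward v_reset, then add the input.
    return v * decay + v_reset / tau + x


if __name__ == "__main__":
    print("decay factor:", 1.0 - 1.0 / DEFAULT_TAU, "vs 1/e:", math.exp(-1.0))
    print("one step from v=1.0 with x=0.5:", charge_step(1.0, 0.5))
```

In the hard-reset branch `v_reset` is a fixed point: starting from `v = v_reset` with zero input, the potential stays at `v_reset`. This is why the user's float branch (`v_reset * (1 - 1/tau) + v_reset / tau + x`) reduces to `v_reset + x`.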
fangwei123456 commented 7 months ago

This is probably caused by the JIT functions used by the built-in LIF neuron.

I suggest changing the base class of your own LIF implementation to the one defined here: https://github.com/fangwei123456/spikingjelly/blob/c21e21cf3679cf5edc344e949eb160f13db2b55b/spikingjelly/activation_based/neuron.py#L70
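The suspected mechanism can be illustrated with a simplified, framework-free model. In the sketch below, `GenericNode`, `FusedLIF`, `TracingCharge` and `count_charge_calls` are hypothetical stand-ins, not SpikingJelly classes. The generic base always routes each step through the overridable `neuronal_charge` hook. The fused subclass, by contrast, computes the update inline when `training` is `False` (the role a JIT-compiled eval path would play), so a subclass override is silently skipped during inference. Deriving the custom neuron from the generic base keeps the hook on every path.

```python
class GenericNode:
    """Hypothetical base: every step goes through the overridable hooks."""

    def __init__(self, tau: float = 2.0, v_threshold: float = 1.0):
        self.tau = tau
        self.v_threshold = v_threshold
        self.v = 0.0
        self.training = True

    def neuronal_charge(self, x: float) -> None:
        raise NotImplementedError

    def step(self, x: float) -> int:
        self.neuronal_charge(x)
        spike = int(self.v >= self.v_threshold)
        if spike:
            self.v = 0.0  # hard reset to 0, for simplicity
        return spike


class FusedLIF(GenericNode):
    """Hypothetical LIF with an inlined fast path used only in eval mode."""

    def neuronal_charge(self, x: float) -> None:
        self.v = self.v + (x - self.v) / self.tau

    def step(self, x: float) -> int:
        if self.training:
            return super().step(x)
        # Eval fast path: the charge is computed inline, so any subclass
        # override of neuronal_charge is never called here.
        self.v = self.v + (x - self.v) / self.tau
        spike = int(self.v >= self.v_threshold)
        if spike:
            self.v = 0.0
        return spike


class TracingCharge:
    """Hypothetical custom charge v <- v/2 + x that counts how often it runs."""

    def neuronal_charge(self, x: float) -> None:
        self.calls = getattr(self, "calls", 0) + 1
        self.v = self.v * 0.5 + x


class CustomOnFused(TracingCharge, FusedLIF):
    pass


class CustomOnGeneric(TracingCharge, GenericNode):
    pass


def count_charge_calls(node_cls, training: bool, steps: int = 5) -> int:
    """Run a few steps and report how many times the custom charge ran."""
    node = node_cls()
    node.training = training
    node.calls = 0
    for _ in range(steps):
        node.step(0.3)
    return node.calls


if __name__ == "__main__":
    for cls in (CustomOnFused, CustomOnGeneric):
        print(cls.__name__,
              "train:", count_charge_calls(cls, training=True),
              "eval:", count_charge_calls(cls, training=False))
```

Running it prints `CustomOnFused train: 5 eval: 0` and `CustomOnGeneric train: 5 eval: 5`. In SpikingJelly terms, the analogous change is what the reply suggests: inherit from the class at the linked line instead of `neuron.LIFNode`, and keep the custom `neuronal_charge` override.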