Open grance199981 opened 7 months ago
import math
from typing import Callable, Optional

import torch
from spikingjelly.activation_based import neuron
class SpikeTrace_LIF_Neuron(neuron.LIFNode):
    """LIF neuron whose previous membrane potential is detached at every charge step.

    The charge update is the non-input-decaying LIF equation
    ``v = v * (1 - 1/tau) + v_reset / tau + x`` (with the ``v_reset / tau``
    term dropped when ``v_reset`` is ``None`` or ``0``). The previous ``v`` is
    ``detach()``-ed, so gradients do not flow back through the temporal
    recursion of the membrane potential.

    The default ``tau = e / (e - 1)`` makes the per-step decay factor
    ``1 - 1/tau`` equal to ``exp(-1)``.

    Note: ``decay_input`` is forwarded to the base class but is not consulted
    by :meth:`neuronal_charge` below; ``x`` is always added undecayed.
    """

    def __init__(self, tau: float = math.exp(1.0) / (math.exp(1.0) - 1),
                 decay_input: bool = False, v_threshold: float = 1.,
                 v_reset: Optional[float] = None,
                 surrogate_function: Optional[Callable] = None,
                 detach_reset: bool = False, step_mode='s', backend='torch',
                 store_v_seq: bool = False):
        """Forward all neuron settings to ``neuron.LIFNode``.

        Args:
            tau: membrane time constant; defaults to ``e / (e - 1)``.
            decay_input: forwarded to the base class (not used by the
                custom charge below).
            v_threshold: firing threshold.
            v_reset: reset potential; ``None`` selects soft reset in the base
                class.
            surrogate_function: surrogate gradient function. ``None`` means
                "use the base class's default surrogate".
            detach_reset, step_mode, backend, store_v_seq: forwarded unchanged.
        """
        # Must be ``__init__``: a method named ``init`` is never called on
        # construction, so all of these arguments would be silently ignored.
        # Passing ``surrogate_function=None`` through would store ``None`` as
        # the surrogate (presumably called when the neuron fires — TODO
        # confirm), so omit it and let LIFNode's own default apply.
        extra = {} if surrogate_function is None else {
            'surrogate_function': surrogate_function}
        super().__init__(tau=tau, decay_input=decay_input,
                         v_threshold=v_threshold, v_reset=v_reset,
                         detach_reset=detach_reset, step_mode=step_mode,
                         backend=backend, store_v_seq=store_v_seq, **extra)

    def neuronal_charge(self, x: torch.Tensor):
        """Integrate input ``x`` into ``self.v``, detaching the previous ``v``.

        Args:
            x: input current for the current time step.
        """
        decay = 1. - 1. / self.tau
        if self.v_reset is None or self.v_reset == 0:
            if isinstance(self.v, float):
                # ``v`` is still the initial scalar, so there is no history
                # to decay; take the input directly.
                self.v = x
            else:
                self.v = self.v.detach() * decay + x
        else:
            if isinstance(self.v, float):
                # Initial scalar ``v``; the expression is kept in this form
                # (rather than simplified) to preserve the original numerics.
                self.v = self.v_reset * decay + self.v_reset / self.tau + x
            else:
                # Same as the LIF update ``v - (v - v_reset) / tau + x``, but
                # with the previous ``v`` cut from the autograd graph.
                self.v = self.v.detach() * decay + self.v_reset / self.tau + x
这可能是原生 `LIFNode` 中的 JIT 函数导致的。
建议把你自行实现的 LIF 神经元的基类换成：https://github.com/fangwei123456/spikingjelly/blob/c21e21cf3679cf5edc344e949eb160f13db2b55b/spikingjelly/activation_based/neuron.py#L70
from spikingjelly.activation_based import neuron
class SpikeTrace_LIF_Neuron(neuron.LIFNode):
    """``neuron.LIFNode`` with non-input-decaying defaults and ``tau = e / (e - 1)``.

    The default ``tau`` makes the per-step decay factor ``1 - 1/tau`` equal
    to ``exp(-1)``.
    """

    def __init__(self, tau: float = math.exp(1.0) / (math.exp(1.0) - 1),
                 decay_input: bool = False, v_threshold: float = 1.,
                 v_reset: Optional[float] = None,
                 surrogate_function: Optional[Callable] = None,
                 detach_reset: bool = False, step_mode='s', backend='torch',
                 store_v_seq: bool = False):
        """Forward all neuron settings to ``neuron.LIFNode``.

        ``surrogate_function=None`` means "use the base class's default
        surrogate"; every other argument is passed through unchanged.
        """
        # Must be ``__init__``: a method named ``init`` is never called on
        # construction, so all of these arguments would be silently ignored.
        # Passing ``surrogate_function=None`` through would store ``None`` as
        # the surrogate (presumably called when the neuron fires — TODO
        # confirm), so omit it and let LIFNode's own default apply.
        extra = {} if surrogate_function is None else {
            'surrogate_function': surrogate_function}
        super().__init__(tau=tau, decay_input=decay_input,
                         v_threshold=v_threshold, v_reset=v_reset,
                         detach_reset=detach_reset, step_mode=step_mode,
                         backend=backend, store_v_seq=store_v_seq, **extra)