fangwei123456 / spikingjelly

SpikingJelly is an open-source deep learning framework for Spiking Neural Networks (SNNs) based on PyTorch.
https://spikingjelly.readthedocs.io

Numeric type error in the STDP learner's step function #522

Open LumenScopeAI opened 3 months ago

LumenScopeAI commented 3 months ago

Numeric type error in the STDP learner's step function

def step(self, on_grad: bool = True, scale: float = 1.):
        length = self.in_spike_monitor.records.__len__()
        delta_w = None

        if self.step_mode == 's':
            if isinstance(self.synapse, nn.Linear):
                stdp_f = stdp_linear_single_step
            elif isinstance(self.synapse, nn.Conv2d):
                stdp_f = stdp_conv2d_single_step
            elif isinstance(self.synapse, nn.Conv1d):
                stdp_f = stdp_conv1d_single_step
            else:
                raise NotImplementedError(self.synapse)
        elif self.step_mode == 'm':
            if isinstance(self.synapse, (nn.Linear, nn.Conv1d, nn.Conv2d)):
                stdp_f = stdp_multi_step
            else:
                raise NotImplementedError(self.synapse)
        else:
            raise ValueError(self.step_mode)

        for _ in range(length):
            in_spike = self.in_spike_monitor.records.pop(0)     # [batch_size, N_in]
            out_spike = self.out_spike_monitor.records.pop(0)   # [batch_size, N_out]

            self.trace_pre, self.trace_post, dw = stdp_f(
                self.synapse, in_spike, out_spike,
                self.trace_pre, self.trace_post, 
                self.tau_pre, self.tau_post,
                self.f_pre, self.f_post
            )
            if scale != 1.:
                dw *= scale

            delta_w = dw if (delta_w is None) else (delta_w + dw)

        if on_grad:
            if self.synapse.weight.grad is None:
                # TypeError here when delta_w is still None (no records were looped over)
                self.synapse.weight.grad = -delta_w
            else:
                self.synapse.weight.grad = self.synapse.weight.grad - delta_w
        else:
            return delta_w

My workaround is to add:

                if delta_w is not None:
                    self.synapse.weight.grad = self.synapse.weight.grad - delta_w

This resolves the type error, but then the training loss no longer changes. How can this be solved?
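For context, here is a minimal, self-contained sketch (plain Python, hypothetical names, no SpikingJelly dependency) of why `delta_w` ends up as `None` and how a guard avoids the unary-minus `TypeError`. The loop over the monitor's records never runs when the records list is empty:

```python
def step_sketch(records, on_grad=True):
    """Sketch of the accumulation pattern in the reported step function."""
    delta_w = None
    for dw in records:          # empty records -> loop body never executes
        delta_w = dw if delta_w is None else delta_w + dw
    if on_grad:
        if delta_w is None:     # guard avoiding "bad operand type for unary -: 'NoneType'"
            return None         # nothing was recorded; leave the gradient untouched
        return -delta_w
    return delta_w

print(step_sketch([]))            # None instead of a TypeError
print(step_sketch([1.0, 2.0]))    # -3.0
```

The guard only masks the symptom, though: as the maintainer notes below, the real question is why the monitor recorded nothing in the first place.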

fangwei123456 commented 2 months ago

This is strange, because the monitor does not actually check whether the data are spikes; it records everything completely. If the monitor recorded no data, perhaps the monitored layer did not actually participate in the network's computation?

fangwei123456 commented 2 months ago

Also, an unchanging training loss is completely normal. STDP is a very weak unsupervised learner; if good performance could be tuned out of it easily, that would be a major scientific breakthrough.
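For readers unfamiliar with the rule being discussed: trace-based STDP updates each weight locally from decaying pre- and post-synaptic eligibility traces, without reference to any loss. The sketch below is a generic illustration in plain Python for a single scalar synapse; all names and constants are illustrative, and it is not SpikingJelly's `stdp_linear_single_step` implementation:

```python
def stdp_step(w, pre_spike, post_spike, trace_pre, trace_post,
              tau_pre=2.0, tau_post=2.0, lr=0.01):
    """One generic trace-based STDP step for a single scalar synapse."""
    # Decay the eligibility traces, then accumulate the new spikes.
    trace_pre = trace_pre - trace_pre / tau_pre + pre_spike
    trace_post = trace_post - trace_post / tau_post + post_spike
    # Potentiate when a post spike arrives while the pre trace is high;
    # depress when a pre spike arrives while the post trace is high.
    dw = lr * (trace_pre * post_spike - trace_post * pre_spike)
    return w + dw, trace_pre, trace_post

# A pre spike followed one step later by a post spike -> potentiation.
w, tp, tq = 0.5, 0.0, 0.0
w, tp, tq = stdp_step(w, 1.0, 0.0, tp, tq)   # pre fires, traces charge up
w, tp, tq = stdp_step(w, 0.0, 1.0, tp, tq)   # post fires, w increases
```

This locality is also why an unchanging loss is unsurprising: the update follows spike-timing correlations, not the gradient of any objective.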