Open LumenScopeAI opened 3 months ago
self.synapse.weight.grad = self.synapse.weight.grad - delta_w
length = self.in_spike_monitor.records.__len__()
for _ in range(length):
delta_w
on_grad
self.synapse.weight.grad
def step(self, on_grad: bool = True, scale: float = 1.):
    """Consume all recorded spikes and accumulate the resulting STDP weight update.

    Every record currently stored in ``self.in_spike_monitor`` and
    ``self.out_spike_monitor`` is popped (oldest first) and passed, together
    with the current traces, to the STDP function matching ``self.step_mode``
    and the type of ``self.synapse``. The per-record updates are summed into
    ``delta_w``.

    :param on_grad: if ``True``, ``delta_w`` is subtracted from
        ``self.synapse.weight.grad`` (or stored as ``-delta_w`` when the
        gradient is ``None``) and nothing is returned; if ``False``, the
        gradient is left untouched and ``delta_w`` is returned instead.
    :param scale: factor applied to each per-record update; the
        multiplication is skipped when it equals ``1.``.
    :return: ``delta_w`` when ``on_grad`` is ``False``, otherwise ``None``.
        ``None`` is also returned when the monitors held no records.
    :raises NotImplementedError: if ``self.synapse`` is not a supported layer
        type for the current step mode.
    :raises ValueError: if ``self.step_mode`` is neither ``'s'`` nor ``'m'``.
    """
    length = len(self.in_spike_monitor.records)
    delta_w = None

    # Resolve the STDP kernel first so that a misconfigured learner still
    # fails loudly even when there is nothing to process.
    if self.step_mode == 's':
        if isinstance(self.synapse, nn.Linear):
            stdp_f = stdp_linear_single_step
        elif isinstance(self.synapse, nn.Conv2d):
            stdp_f = stdp_conv2d_single_step
        elif isinstance(self.synapse, nn.Conv1d):
            stdp_f = stdp_conv1d_single_step
        else:
            raise NotImplementedError(self.synapse)
    elif self.step_mode == 'm':
        if isinstance(self.synapse, (nn.Linear, nn.Conv1d, nn.Conv2d)):
            stdp_f = stdp_multi_step
        else:
            raise NotImplementedError(self.synapse)
    else:
        raise ValueError(self.step_mode)

    for _ in range(length):
        in_spike = self.in_spike_monitor.records.pop(0)    # [batch_size, N_in]
        out_spike = self.out_spike_monitor.records.pop(0)  # [batch_size, N_out]

        self.trace_pre, self.trace_post, dw = stdp_f(
            self.synapse, in_spike, out_spike,
            self.trace_pre, self.trace_post,
            self.tau_pre, self.tau_post,
            self.f_pre, self.f_post
        )
        if scale != 1.:
            dw *= scale

        delta_w = dw if (delta_w is None) else (delta_w + dw)

    if delta_w is None:
        # The monitors were empty, so no update was computed. Negating or
        # subtracting ``None`` would raise a TypeError; leave the gradient
        # unchanged and report that there is no update.
        return None

    if on_grad:
        if self.synapse.weight.grad is None:
            self.synapse.weight.grad = -delta_w
        else:
            self.synapse.weight.grad = self.synapse.weight.grad - delta_w
    else:
        return delta_w
我的处理方式是加入:
if delta_w is not None: self.synapse.weight.grad = self.synapse.weight.grad - delta_w
此时可以解决数据类型异常,但训练loss不变,请问如何解决?
这个问题比较奇怪,因为监视器实际上不会管数据是否为脉冲,都完整的记录下来。监视器没有记录到数据,可能是被监视的那个层实际上没有参与网络的计算?
另外训练loss不变是很正常的,stdp是非常弱的无监督学习器,如果很容易就能调出好的性能,那就是重大科学突破了
STDP学习器step函数存在数值类型异常
`self.synapse.weight.grad = self.synapse.weight.grad - delta_w`,在以下完整代码中,`length = self.in_spike_monitor.records.__len__()` 的值为 0,代表本轮并未接收到脉冲信号,即不会进入 `for _ in range(length):` 循环,此时 `delta_w` 恒为 None。但当 `on_grad` 为 True 且 `self.synapse.weight.grad` 不为 None 时,就会出现数据类型异常:张量在和 None 做减法。我的处理方式是加入:
`if delta_w is not None: self.synapse.weight.grad = self.synapse.weight.grad - delta_w`
此时可以解决数据类型异常,但训练loss不变,请问如何解决?