Closed humingjie8 closed 8 months ago
发现在计算神经元的spike_rate的循环过程中,每经历一次神经元模拟步长的计算后,显存的占用就会明显提升0.4GB左右
计算的时候忘了开no grad模式?
非常感激您的回复,我刚刚初步尝试了一下启用no grad模式,但是在训练是都是报错显示:
element 0 of tensors does not require grad and does not have a grad_fn
File "D:\code\vo2_sim_snn\approaches\ewc.py", line 198, in train_epoch
loss.backward()
File "D:\code\vo2_sim_snn\approaches\ewc.py", line 106, in train
self.train_epoch(t,ncla,xtrain,ytrain, e)
File "D:\code\vo2_sim_snn\main.py", line 114, in
老师您能不能指出应该在具体的训练的代码中什么位置上开启no grad,谢谢老师,下面附上我具体的训练代码:
def train_epoch(self,t,ncla,x,y, epoch): self.model.train()
r=np.arange(x.size(0))
np.random.shuffle(r)
r=torch.LongTensor(r).cuda()
# Loop batches
for i in range(0,len(r),self.sbatch):
if i + self.sbatch > len(r):
# 如果剩余数据不足一个batch,跳过这些数据
continue
b = r[i:i+self.sbatch]
# print('num_b:',len(b))
images=x[b]
targets=y[b]
# Forward current model
# input_spike = self.Poisson_encoder(images/255.0)
spike_num_img = torch.zeros(len(b), ncla).cuda()
for time in range(args.T_sim):
outputs = self.model.forward(images)[t]
with torch.no_grad(): # 增添后运行报错
spike_num_img += outputs
spike_rate = spike_num_img/args.T_sim
# print('spike_rate-std:',torch.std(spike_rate))
loss=self.criterion(t,spike_rate,targets)
# Backward
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
with torch.no_grad():
self.model.reset_() # reset the neuron voltage every batch, to ensure independency between batchs
del loss
del images, targets, outputs, spike_num_img, spike_rate
return
如果是作为损失计算的话,不能用 no grad模式 在训练过程中,显存占用是稳定的吗?如果没有持续增加,则说明代码没问题
显存在一个batch训练完成后,就会保持稳定,之后的训练不会增加了,我想知道有方法可以减少显存的占用吗?我感觉他似乎是把每一个模拟步长计算中的的梯度信息都保留了,但我不知道这样是否是合理的,或者说是否有更好的方法可以提高计算效率在面对SNN中这种极度耗费显存资源的方法。谢谢老师。
我想知道有方法可以减少显存的占用吗?我感觉他似乎是把每一个模拟步长计算中的的梯度信息都保留了
保留所有步长的信息是BPTT,如果想降低显存占用可以考虑RTRL,但训练出的网络性能通常会下降不少
好的,谢谢老师
您好,我是一名使用SNN进行模拟计算的学生,最近在使用自己之前的工作代码进行SNN模型训练中,发现在计算神经元的spike_rate的循环过程中,每经历一次神经元模拟步长的计算后,显存的占用就会明显提升0.4GB左右,对此感觉非常苦恼,因为显存的大量占用导致batch_size无法设置的过大,训练时间耗费过长。因为时间问题,我没有办法再重新使用spikingjelly重新开始写代码,所以想请教spikingjelly在训练过程中是如何处理神经元的一系列相关数据,以提升在训练中的计算效率的。非常感谢您能花费时间阅读我的问题,希望能得到spikingjelly开发人员的指导和解惑。
下面附上我使用的神经元相关代码。 class LIFNeuron(nn.Module): ''' LIF Neuron model, parameters are extracted from experimental data. Rd: device HRS Cm: parallel capacitor Rs: series resistance Vth: threshold voltage of Neuron device V_reset: reset voltage of Neuron v: membrane potential dt: simulation time step ''' def init(self, batch_size, dim_in, Rd=5.0e3, Cm=3.0e-6, Rs=1.5, Vth=0.8, V_reset=0.0, dt=1.0e-6,): super().init() self.batch_size = batch_size self.dim_in = tuple(dim_in) self.rd = Rd self.cm = Cm self.rs = Rs self.vth = Vth self.v_reset = V_reset self.v = torch.full([self.batch_size, *self.dim_in], self.v_reset, dtype=torch.float, device=device) # tensor full of the v_reset self.dt = dt