fangwei123456 / spikingjelly

SpikingJelly is an open-source deep learning framework for Spiking Neural Network (SNN) based on PyTorch.
https://spikingjelly.readthedocs.io
Other
1.4k stars 244 forks source link

请教有关SNN在训练过程中的显存占用问题。 #514

Closed humingjie8 closed 8 months ago

humingjie8 commented 8 months ago

您好,我是一名使用SNN进行模拟计算的学生,最近在使用自己之前的工作代码进行SNN模型训练中,发现在计算神经元的spike_rate的循环过程中,每经历一次神经元模拟步长的计算后,显存的占用就会明显提升0.4GB左右,对此感觉非常苦恼,因为显存的大量占用导致batch_size无法设置的过大,训练时间耗费过长。因为时间问题,我没有办法再重新使用spikingjelly重新开始写代码,所以想请教spikingjelly在训练过程中是如何处理神经元的一系列相关数据,以提升在训练中的计算效率的。非常感谢您能花费时间阅读我的问题,希望能得到spikingjelly开发人员的指导和解惑。

下面附上我使用的神经元相关代码。 class LIFNeuron(nn.Module): ''' LIF Neuron model, parameters are extracted from experimental data. Rd: device HRS Cm: parallel capacitor Rs: series resistance Vth: threshold voltage of Neuron device V_reset: reset voltage of Neuron v: membrane potential dt: simulation time step ''' def init(self, batch_size, dim_in, Rd=5.0e3, Cm=3.0e-6, Rs=1.5, Vth=0.8, V_reset=0.0, dt=1.0e-6,): super().init() self.batch_size = batch_size self.dim_in = tuple(dim_in) self.rd = Rd self.cm = Cm self.rs = Rs self.vth = Vth self.v_reset = V_reset self.v = torch.full([self.batch_size, *self.dim_in], self.v_reset, dtype=torch.float, device=device) # tensor full of the v_reset self.dt = dt

    self.tau_in = 1/(self.cm*self.rs)
    self.tau_lk = 1/(self.cm)*(1/self.rd + 1/self.rs) 

@staticmethod
def soft_spike(x):
    a = 2.0
    return torch.sigmoid_(a*x)

def spiking(self):
    if self.training == True:
        spike_hard = torch.gt(self.v, self.vth).float()
        # spike_hard = fire.go(self.v)
        spike_soft = self.soft_spike(self.v - self.vth)
        v_hard = self.v_reset*spike_hard + self.v*(1 - spike_hard)
        v_soft = self.v_reset*spike_soft + self.v*(1 - spike_soft)
        self.v = v_soft + (v_hard - v_soft).detach_()
        return spike_soft + (spike_hard - spike_soft).detach_()
    else:
        spike_hard = torch.gt(self.v, self.vth).float()
        # spike_hard = fire.go(self.v)
        self.v = self.v_reset*spike_hard + self.v*(1 - spike_hard)
        return spike_hard

def forward(self, v_inject):
    '''
    Upgrade membrane potention every time step by differantial equation.
    '''
    # print('v_inject:',v_inject.size())
    # print('v:',self.v.size())
    self.v += (self.tau_in*v_inject - self.tau_lk*self.v) * self.dt
    return self.spiking()

def reset(self):
    '''
    Reset the membrane potential to reset voltage.
    '''
    self.v = torch.full([self.batch_size, *self.dim_in], self.v_reset, dtype=torch.float, device=device)  # tensor full of the v_reset
fangwei123456 commented 8 months ago

发现在计算神经元的spike_rate的循环过程中,每经历一次神经元模拟步长的计算后,显存的占用就会明显提升0.4GB左右

计算的时候忘了开no grad模式?

humingjie8 commented 8 months ago

非常感激您的回复,我刚刚初步尝试了一下启用no grad模式,但是在训练是都是报错显示: element 0 of tensors does not require grad and does not have a grad_fn File "D:\code\vo2_sim_snn\approaches\ewc.py", line 198, in train_epoch loss.backward() File "D:\code\vo2_sim_snn\approaches\ewc.py", line 106, in train self.train_epoch(t,ncla,xtrain,ytrain, e) File "D:\code\vo2_sim_snn\main.py", line 114, in appr.train(task,ncla, xtrain, ytrain, xvalid, yvalid, data, inputsize, taskcla) RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

老师您能不能指出应该在具体的训练的代码中什么位置上开启no grad,谢谢老师,下面附上我具体的训练代码:

def train_epoch(self,t,ncla,x,y, epoch): self.model.train()

self.model.eval()

  r=np.arange(x.size(0))
  np.random.shuffle(r)
  r=torch.LongTensor(r).cuda()

  # Loop batches
  for i in range(0,len(r),self.sbatch):
      if i + self.sbatch > len(r):
          # 如果剩余数据不足一个batch,跳过这些数据
          continue

      b = r[i:i+self.sbatch]
      # print('num_b:',len(b))
      images=x[b]
      targets=y[b]

      # Forward current model
      # input_spike = self.Poisson_encoder(images/255.0)
      spike_num_img = torch.zeros(len(b), ncla).cuda()

      for time in range(args.T_sim):
          outputs = self.model.forward(images)[t]

          with torch.no_grad(): # 增添后运行报错
              spike_num_img += outputs

      spike_rate = spike_num_img/args.T_sim 
      # print('spike_rate-std:',torch.std(spike_rate))           
      loss=self.criterion(t,spike_rate,targets)

      # Backward
      self.optimizer.zero_grad()
      loss.backward()
      self.optimizer.step()

      with torch.no_grad():
          self.model.reset_()    # reset the neuron voltage every batch, to ensure independency between batchs

      del loss
      del images, targets, outputs, spike_num_img, spike_rate        

  return
fangwei123456 commented 8 months ago

如果是作为损失计算的话,不能用 no grad模式 在训练过程中,显存占用是稳定的吗?如果没有持续增加,则说明代码没问题

humingjie8 commented 8 months ago

显存在一个batch训练完成后,就会保持稳定,之后的训练不会增加了,我想知道有方法可以减少显存的占用吗?我感觉他似乎是把每一个模拟步长计算中的的梯度信息都保留了,但我不知道这样是否是合理的,或者说是否有更好的方法可以提高计算效率在面对SNN中这种极度耗费显存资源的方法。谢谢老师。

fangwei123456 commented 8 months ago

我想知道有方法可以减少显存的占用吗?我感觉他似乎是把每一个模拟步长计算中的的梯度信息都保留了

保留所有步长的信息是BPTT,如果想降低显存占用可以考虑RTRL,但训练出的网络性能通常会下降不少

humingjie8 commented 8 months ago

好的,谢谢老师