fangwei123456 / spikingjelly

SpikingJelly is an open-source deep learning framework for Spiking Neural Networks (SNNs) based on PyTorch.
https://spikingjelly.readthedocs.io

functional.multi_step_forward does not release GPU memory after use #425

Closed Met4physics closed 1 year ago

Met4physics commented 1 year ago

SpikingJelly version

latest

Description

As the title says, I want to test the effect of ann2snn conversion. After reading just one batch from the dataloader, GPU memory usage jumps by about 20 GB, and I can't tell where the memory leak is.

Minimal code to reproduce the error/bug

import torch
from spikingjelly.activation_based import ann2snn, functional

# Model, criterion, train_dataloader, device and num_classes are defined elsewhere.
T = 16
batch_size = 2
model_converter = ann2snn.Converter(mode='max', dataloader=train_dataloader)
model = Model(in_channels=3, num_classes=num_classes, base_c=32)
model.to(device)
snn_model = model_converter(model)
snn_model.to(device)
print(snn_model)

loss_weight = torch.as_tensor([1.0, 2.0], device=device)

losses = []
i = 0
for x, y in train_dataloader:
    x = x.repeat(T, 1, 1, 1, 1)  # replicate the input along a new time dimension: [T, N, C, H, W]
    x = x.to(device)
    y = y.to(device)
    outputs = functional.multi_step_forward(x, snn_model)
    outputs = outputs.mean(0)  # average the outputs over the T time steps
    loss = criterion(outputs, y, loss_weight, dice=True, num_classes=num_classes, ignore_index=255)
    losses.append(loss.item())
    functional.reset_net(snn_model)
    print(i)
    i += 1

Error output:

Traceback (most recent call last):
  File "/root/DRIVE/ann2snn.py", line 49, in <module>
    outputs = functional.multi_step_forward(x, snn_model)
  File "/root/miniconda3/lib/python3.8/site-packages/spikingjelly/activation_based/functional.py", line 563, in multi_step_forward
    y_seq.append(single_step_module(x_seq[t]))
  File "/root/miniconda3/lib/python3.8/site-packages/torch/fx/graph_module.py", line 630, in wrapped_call
    raise e.with_traceback(None)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: CUDA out of memory. Tried to allocate 30.00 MiB (GPU 0; 23.70 GiB total capacity; 20.86 GiB already allocated; 16.56 MiB free; 21.69 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

I have confirmed that the dataloader itself has no memory leak. The first batch from train_loader runs fine; the second one raises the error. Debugging shows that GPU memory usage is about 20 GB at the end of the first loop iteration, and when the second batch is read, the memory from the previous iteration has not been released. Why is that?

fangwei123456 commented 1 year ago

Computation in PyTorch tracks gradients by default, so the computation graph is kept in memory the whole time. Add with torch.no_grad().
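
For reference, a minimal sketch of the suggested fix, assuming the same T, device, snn_model, criterion, loss_weight, num_classes, and train_dataloader as in the code above:

import torch
from spikingjelly.activation_based import functional

losses = []
with torch.no_grad():  # disable autograd: no computation graph is built or retained
    for x, y in train_dataloader:
        x = x.repeat(T, 1, 1, 1, 1).to(device)  # [T, N, C, H, W]
        y = y.to(device)
        outputs = functional.multi_step_forward(x, snn_model).mean(0)
        loss = criterion(outputs, y, loss_weight, dice=True,
                         num_classes=num_classes, ignore_index=255)
        losses.append(loss.item())
        functional.reset_net(snn_model)  # reset neuron states between batches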

Met4physics commented 1 year ago

Computation in PyTorch tracks gradients by default, so the computation graph is kept in memory the whole time. Add with torch.no_grad().

In other words, if I want to train with a BP method such as surrogate gradient descent, rather than use ann2snn, do I have to accept this memory cost?

fangwei123456 commented 1 year ago

Yes.

Met4physics commented 1 year ago

Got it, thanks!

fangwei123456 commented 1 year ago

The cost depends only on the batch size and T. In the code above, the computation graphs for the entire training set are kept in memory, which is not reasonable. In normal training, the computation graph is destroyed after each backward() call, so the largest graph held in memory covers only a single batch of data.
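
For illustration, a minimal sketch of such a training loop under the stated assumptions: an Adam optimizer and learning rate are assumed, snn_model stands in for a surrogate-gradient SNN, and the other names are reused from the code above.

import torch
from spikingjelly.activation_based import functional

optimizer = torch.optim.Adam(snn_model.parameters(), lr=1e-3)  # assumed optimizer and learning rate
for x, y in train_dataloader:
    x = x.repeat(T, 1, 1, 1, 1).to(device)
    y = y.to(device)
    outputs = functional.multi_step_forward(x, snn_model).mean(0)
    loss = criterion(outputs, y, loss_weight, dice=True,
                     num_classes=num_classes, ignore_index=255)
    optimizer.zero_grad()
    loss.backward()   # frees this batch's computation graph
    optimizer.step()
    functional.reset_net(snn_model)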

Met4physics commented 1 year ago

The cost depends only on the batch size and T. In the code above, the computation graphs for the entire training set are kept in memory, which is not reasonable. In normal training, the computation graph is destroyed after each backward() call, so the largest graph held in memory covers only a single batch of data.

Understood. I indeed never call loss.backward(); that must be the reason.