fangwei123456 / spikingjelly

SpikingJelly is an open-source deep learning framework for Spiking Neural Network (SNN) based on PyTorch.
https://spikingjelly.readthedocs.io

Memory issue arises from neuronal_reset #310

Open gwgknudayanga opened 1 year ago

gwgknudayanga commented 1 year ago

Hi, I am always getting the following error from neuronal_reset. Is there any clue as to the cause?

  File "/spikingjelly/activation_based/neuron.py", line 198, in neuronal_reset
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File ".anaconda3/envs/snn_train/lib/python3.9/site-packages/spikingjelly-0.0.0.0.13-py3.9.egg/spikingjelly/activation_based/neuron.py", line 127, in fallback_cuda_fuser
    @torch.jit.script
    def jit_hard_reset(v: torch.Tensor, spike: torch.Tensor, v_reset: float):
        v = (1. - spike) * v + spike * v_reset
        return v
RuntimeError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 7.92 GiB total capacity; 6.78 GiB already allocated; 24.44 MiB free; 6.79 GiB reser
mountains-high commented 1 year ago

Hi,

Did you call functional.reset_net(model)? I faced a similar issue previously, and it was solved after resetting the net. Hope this solves your issue. A minimal sketch of where the reset goes is below.
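This is only an illustrative sketch, not the poster's code: the network, data, and hyperparameters below are placeholders. It shows the usual pattern of calling `functional.reset_net(net)` once per batch, after the optimizer step, so membrane potentials do not carry over (and keep their autograd graphs alive) across batches.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from spikingjelly.activation_based import neuron, functional

T, N = 4, 16  # time-steps and batch size (illustrative values only)
net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), neuron.LIFNode())
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

for step in range(10):                      # dummy data in place of a real DataLoader
    img = torch.rand(N, 1, 28, 28)
    label = torch.randint(0, 10, (N,))

    out_fr = 0.
    for t in range(T):                      # accumulate output spikes over T time-steps
        out_fr = out_fr + net(img)
    loss = F.cross_entropy(out_fr / T, label)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    functional.reset_net(net)               # clear all neuron states before the next batch
```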

fangwei123456 commented 1 year ago

(image attachment)

gwgknudayanga commented 1 year ago

Thank you both for the answer. In my code I do call reset_net() after the optimization step. However, when I reduce the batch size from 16 to 8, the problem does not arise, but reducing the batch size also reduces accuracy. Is there a way to handle this issue? Thanks!

mountains-high commented 1 year ago

Hi, in that case, maybe you could also play with T (the number of time-steps), which plays a crucial role in both accuracy and memory allocation.

However, it looks strange: if you have reset the net, the error above should not be thrown. I do not think it is because of the batch size; at least in my experience, I have not faced such an issue.

fangwei123456 commented 1 year ago

when i reduce the batch size from 16 to 8 this problem did not arise.

The GPU memory consumption is proportional to T * N, where T is the number of time-steps and N is the batch size. You can try to reduce T or N.
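A rough way to see this scaling for yourself (again a sketch, with an illustrative one-layer net and image size rather than anything from this issue): because the network is evaluated once per time-step, the autograd graph holds roughly T copies of each layer's activations for every sample in the batch, so halving either T or N roughly halves activation memory.

```python
import torch
import torch.nn as nn
from spikingjelly.activation_based import neuron, functional

# Small illustrative model; requires a CUDA device.
net = nn.Sequential(nn.Conv2d(3, 32, 3, padding=1), neuron.LIFNode()).cuda()

def peak_mem_mib(T, N):
    functional.reset_net(net)
    torch.cuda.reset_peak_memory_stats()
    out = 0.
    for t in range(T):                      # one forward pass per time-step
        out = out + net(torch.rand(N, 3, 128, 128, device='cuda'))
    out.sum().backward()                    # backward needs all T stored activations
    functional.reset_net(net)
    return torch.cuda.max_memory_allocated() / 2 ** 20

print(peak_mem_mib(T=10, N=16))             # baseline
print(peak_mem_mib(T=5,  N=16))             # halving T roughly halves peak memory
print(peak_mem_mib(T=10, N=8))              # halving N has a similar effect
```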

gwgknudayanga commented 1 year ago

Hi both, thanks for the reply. By reducing the number of time-steps the problem does not occur, but accuracy also drops. Here I am using images of 224x224 resolution with Poisson coding to generate spikes, and batch size = 16. When I set the number of time-steps above 10, the memory issue appears. I call functional.reset_net(net) at the end of the step() function of my PyTorch Lightning module. Thanks.
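For reference, a hedged sketch of what such a Lightning training step could look like (the poster's actual module is not shown in this thread; the class name, net, and T=10 below are assumptions). It Poisson-encodes the 224x224 input at every time-step and calls functional.reset_net at the end of the step, as described above.

```python
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from spikingjelly.activation_based import encoding, functional

class LitSNN(pl.LightningModule):           # hypothetical module, for illustration only
    def __init__(self, net, T=10):
        super().__init__()
        self.net = net
        self.T = T
        self.encoder = encoding.PoissonEncoder()   # emits a Bernoulli spike map per call

    def training_step(self, batch, batch_idx):
        img, label = batch                          # img: [N, 3, 224, 224] in [0, 1]
        out_fr = 0.
        for t in range(self.T):                     # new Poisson sample at every time-step
            out_fr = out_fr + self.net(self.encoder(img))
        loss = F.cross_entropy(out_fr / self.T, label)
        functional.reset_net(self.net)              # reset at the end of the step
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```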