jeshraghian / snntorch

Deep and online learning with spiking neural networks in Python
https://snntorch.readthedocs.io/en/latest/
MIT License

GPU memory leakage with Leaky neuron #328

Open jagdap opened 5 months ago

jagdap commented 5 months ago

Description

Memory usage in GPU grows unexpectedly when using Leaky() neuron.

What I Did

After noticing unexpected memory growth with my SNN, I tracked down a memory leak (or garbage-collection issue) in the Leaky neuron, demonstrated with the code below:

# SNN -- Leaky neuron memory leakage
import snntorch as snn
import torch

# Create a Leaky neuron on the GPU, delete it, and report allocated memory.
for i in range(3):
    neuron = snn.Leaky(beta=0.5, spike_grad=snn.surrogate.fast_sigmoid(slope=25)).to("cuda")
    del neuron

    torch.cuda.empty_cache()
    print("\n")
    print("mem alloc:", torch.cuda.memory_allocated())
    print("max mem alloc:", torch.cuda.max_memory_allocated())

print("=======================")

# Repeat the same loop to show that the growth continues.
for i in range(3):
    neuron = snn.Leaky(beta=0.5, spike_grad=snn.surrogate.fast_sigmoid(slope=25)).to("cuda")
    del neuron
    torch.cuda.empty_cache()

    print("\n")
    print("mem alloc:", torch.cuda.memory_allocated())
    print("max mem alloc:", torch.cuda.max_memory_allocated())

Output:

mem alloc: 2048
max mem alloc: 4198400

mem alloc: 4096
max mem alloc: 4198400

mem alloc: 6144
max mem alloc: 4198400
=======================

mem alloc: 8192
max mem alloc: 4198400

mem alloc: 10240
max mem alloc: 4198400

mem alloc: 12288
max mem alloc: 4198400

Note that allocated GPU memory grows by 2048 bytes per iteration even though the neuron is explicitly deleted.

Expected Behavior

We would expect that dereferencing the neuron or calling del would release all of the memory it holds. This is the behavior observed with a torch.nn.Linear layer, for example:

# Baseline -- torch.nn.Linear releases its GPU memory after del
import torch

# Create a Linear layer on the GPU, delete it, and report allocated memory.
for i in range(3):
    layer = torch.nn.Linear(1024, 1024).to("cuda")
    del layer

    torch.cuda.empty_cache()
    print("\n")
    print("mem alloc:", torch.cuda.memory_allocated())
    print("max mem alloc:", torch.cuda.max_memory_allocated())

print("=======================")

# Repeat the same loop: allocated memory stays at zero.
for i in range(3):
    layer = torch.nn.Linear(1024, 1024).to("cuda")
    del layer
    torch.cuda.empty_cache()

    print("\n")
    print("mem alloc:", torch.cuda.memory_allocated())
    print("max mem alloc:", torch.cuda.max_memory_allocated())

Output:

mem alloc: 0
max mem alloc: 4198400

mem alloc: 0
max mem alloc: 4198400

mem alloc: 0
max mem alloc: 4198400
=======================

mem alloc: 0
max mem alloc: 4198400

mem alloc: 0
max mem alloc: 4198400

mem alloc: 0
max mem alloc: 4198400

I suspect part of the issue is how instances are handled and tracked by the snntorch.SpikingNeuron class, but the behavior is observed even after deleting all instances.

jagdap commented 5 months ago

Digging a little deeper, I think there are two issues at play. First, SpikingNeuron.instances is being populated but not cleared, so the class variable is growing each time I create a new network with a SpikingNeuron child, causing memory usage to grow. This is especially painful when using Jupyter Notebook, since I tend to re-run cells that create models.
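
The first issue can be illustrated without snntorch. The following is a minimal sketch, assuming a class-level list registry similar in spirit to SpikingNeuron.instances; the class names `Tracked` and `WeaklyTracked` are hypothetical and this is not snntorch's actual code. It shows that a strong reference held in a class variable keeps every object alive after `del`, whereas a `weakref.WeakSet` registry does not.

```python
import gc
import weakref


class Tracked:
    """Hypothetical class that records every instance in a class-level list."""

    instances = []  # strong references: objects listed here are never freed

    def __init__(self):
        Tracked.instances.append(self)


class WeaklyTracked:
    """Hypothetical variant whose registry does not keep instances alive."""

    instances = weakref.WeakSet()

    def __init__(self):
        WeaklyTracked.instances.add(self)


def count_after_churn(cls, n=3):
    """Create and delete n objects, then report how many the registry still holds."""
    for _ in range(n):
        obj = cls()
        del obj
    gc.collect()  # make sure anything collectable has been collected
    return len(cls.instances)


strong_count = count_after_churn(Tracked)
weak_count = count_after_churn(WeaklyTracked)
print("strong registry still holds:", strong_count)
print("weak registry still holds:", weak_count)
```

With a strong registry, explicitly emptying the list (here, `Tracked.instances.clear()`) before rebuilding a model would also release the old objects. Whether that is safe to do with SpikingNeuron.instances depends on how snntorch uses the list, which is not established in this thread.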

The other issue might be a PyTorch issue, but I'm not sure yet. One of the things Leaky does is select which function will be used for the reset mechanism. This is assigned via self.state_function = self._base_sub (or whichever is selected). Assigning a bound method to an instance attribute seems to create some sort of memory leak. I've recreated the problem below:

import torch
import torch.nn as nn

class myModule(nn.Module):
    def __init__(self):
        super().__init__()
        # A registered buffer moves to the GPU along with the module.
        self.register_buffer("test", torch.as_tensor(1.0))
        # Store a bound method on the instance, as Leaky does for its reset function.
        self.which_foo = self.foo

    def forward(self, x):
        return x

    def foo(self, x):
        return

for _ in range(3):
    m = myModule().cuda()

    del m
    torch.cuda.empty_cache()
    print("\n")
    print("mem alloc:", torch.cuda.memory_allocated())
    print("max mem alloc:", torch.cuda.max_memory_allocated())

Output:


mem alloc: 512
max mem alloc: 512

mem alloc: 1024
max mem alloc: 1024

mem alloc: 1536
max mem alloc: 1536

Of note, the error only occurs when I've registered a buffer. If I comment out the self.register_buffer line, there's no longer an issue.
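
One possible explanation, not confirmed in this thread: storing a bound method such as `self.foo` on the instance creates a reference cycle (the instance holds the method, and the method holds the instance). CPython does not free cyclic garbage when `del` drops the last external reference; the object waits for the cyclic garbage collector. Until then any buffer it owns stays allocated, which would only show up in `memory_allocated()` when a buffer exists. The sketch below shows the cycle in plain Python; the class and function names are hypothetical and no GPU is needed.

```python
import gc
import weakref


class WithBoundMethod:
    """Stores one of its own bound methods, which creates a reference cycle."""

    def __init__(self):
        self.which_foo = self.foo  # instance -> bound method -> instance

    def foo(self, x):
        return x


class WithoutBoundMethod:
    """Same class without the self-referencing attribute."""

    def foo(self, x):
        return x


def liveness_after_del(cls):
    """Return (alive right after del, alive after an explicit gc.collect())."""
    gc.disable()  # keep the automatic collector from running mid-check
    try:
        obj = cls()
        ref = weakref.ref(obj)
        del obj
        alive_after_del = ref() is not None
        gc.collect()  # an explicit collection reclaims the cycle
        alive_after_gc = ref() is not None
    finally:
        gc.enable()
    return alive_after_del, alive_after_gc


cycle_result = liveness_after_del(WithBoundMethod)
plain_result = liveness_after_del(WithoutBoundMethod)
print("with bound method:", cycle_result)
print("without bound method:", plain_result)
```

If this is the cause, calling `gc.collect()` before `torch.cuda.empty_cache()` in the reproduction above should bring "mem alloc" back to 0; that has not been verified here.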