jeshraghian / snntorch

Deep and online learning with spiking neural networks in Python
https://snntorch.readthedocs.io/en/latest/
MIT License
1.28k stars · 217 forks

Subclass Leaky #299

Closed JackCaster closed 6 months ago

JackCaster commented 6 months ago

Description

I would like to subclass Leaky to modify its base state function so that beta can be negative. This would bring the update equation closer to how drift-diffusion models are defined in psychology.

What I Did

I tried to subclass snn.Leaky and override the _base_state_function method:

#!/usr/bin/env python3
import torch
import torch.nn as nn
import snntorch as snn

class MyLeaky(snn.Leaky):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _base_state_function(self, input_):
        base_fn = -self.beta.clamp(0, 1) * self.mem + input_ # only changed the sign
        return base_fn

class Net(nn.Module):
    def __init__(self, beta=0.5):
        super().__init__()

        self.neuron = MyLeaky(beta=beta)

    def forward(self, x, mem=None):
        if mem is None:
            mem = self.neuron.init_leaky()

        return self.neuron(x, mem)

However, when I called the network (which triggers the forward method)

xs = torch.randn(10)
net = Net()

net(xs)

I got:

TypeError: MyLeaky._base_state_function() takes 2 positional arguments but 3 were given

because _base_state_function also receives mem as a positional argument, for reasons I don't understand.
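The mismatch can be reproduced without snnTorch: if the parent class's forward pass invokes `self._base_state_function(input_, mem)` with two positional arguments, then an override that accepts only `input_` fails, while one matching the parent's signature works. A minimal sketch of this mechanism (plain-Python stand-ins, not the real snnTorch classes):

```python
class Base:
    """Stand-in for snn.Leaky: forward passes mem into _base_state_function."""

    def __init__(self, beta):
        self.beta = beta

    def _base_state_function(self, input_, mem):
        # decay the membrane potential, then integrate the input
        return self.beta * mem + input_

    def forward(self, input_, mem):
        # the parent supplies BOTH input_ and mem positionally,
        # so any override must accept both
        return self._base_state_function(input_, mem)


class Broken(Base):
    # missing the mem parameter -> TypeError when the parent calls it
    def _base_state_function(self, input_):
        return -self.beta * input_


class Fixed(Base):
    # signature matches the parent; only the sign of beta is flipped,
    # drift-diffusion style
    def _base_state_function(self, input_, mem):
        return -self.beta * mem + input_


try:
    Broken(beta=0.5).forward(1.0, 0.0)
except TypeError as e:
    print("TypeError:", e)  # takes 2 positional arguments but 3 were given

print(Fixed(beta=0.5).forward(1.0, 2.0))  # -0.5 * 2.0 + 1.0 = 0.0
```

Assuming snnTorch's Leaky calls the method the same way, the original MyLeaky should work once its override is declared as `_base_state_function(self, input_, mem)` and uses the passed-in `mem` rather than `self.mem`.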

Would you be able to help?