I would like to subclass snn.Leaky to modify its base state function, because I would like beta to be negative. This would bring the update equation closer to how drift diffusion models are defined in psychology.
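For reference, here is the base update computed by _base_state_function before and after flipping the sign of the leak term. U is the membrane potential (mem), X is the input (input_), and β is clamped to [0, 1]. The reset term, which is applied outside _base_state_function, is omitted:

```latex
\begin{aligned}
U[t+1] &= \beta\,U[t] + X[t+1]  && \text{(default base function)} \\
U[t+1] &= -\beta\,U[t] + X[t+1] && \text{(sign-flipped override)}
\end{aligned}
```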
What I Did
I tried to subclass snn.Leaky and override its _base_state_function method:
#!/usr/bin/env python3
import torch
import torch.nn as nn
import snntorch as snn

class MyLeaky(snn.Leaky):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _base_state_function(self, input_):
        base_fn = -self.beta.clamp(0, 1) * self.mem + input_  # only changed the sign
        return base_fn

class Net(nn.Module):
    def __init__(self, beta=0.5):
        super().__init__()
        self.neuron = MyLeaky(beta=beta)

    def forward(self, x, mem=None):
        if mem is None:
            mem = self.neuron.init_leaky()
        return self.neuron(x, mem)
However, when I instantiate Net and call it (which triggers the forward method):
xs = torch.randn(10)
net = Net()
net(xs)
I get the following error:
TypeError: MyLeaky._base_state_function() takes 2 positional arguments but 3 were given
It seems that _base_state_function is also being passed mem as an extra argument, and I don't understand why.
Would you be able to help?