lava-nc / lava-dl

Deep Learning library for Lava
https://lava-nc.org
BSD 3-Clause "New" or "Revised" License

TypeError when using adrf neurons #277

Closed felixthoe closed 3 months ago

felixthoe commented 8 months ago

Describe the bug

There seems to be a small bug in the implementation of the adrf neuron (lava-dl/src/lib/dl/slayer/neuron/adrf.py). Neuron.forward() calls self.spike(real, imag, threshold + refractory), but Neuron.spike() in the same file is defined with separate threshold and refractory arguments, so the call should be self.spike(real, imag, threshold, refractory).
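The mismatch can be illustrated without lava-dl installed. The sketch below is only a stand-in: ToyNeuron, forward_buggy, forward_fixed and the comparison inside spike() are hypothetical and are not taken from adrf.py. It only shows how passing threshold + refractory as a single argument to a method that expects both separately produces this kind of TypeError.

```python
class ToyNeuron:
    """Hypothetical stand-in for illustration; not the lava-dl adrf Neuron."""

    def spike(self, real, imag, threshold, refractory):
        # Placeholder logic (an assumption, not the library's spike function):
        # fire when the real part reaches the combined threshold.
        return real >= threshold + refractory

    def forward_buggy(self, real, imag, threshold, refractory):
        # Mirrors the reported call: three arguments for a four-parameter method.
        return self.spike(real, imag, threshold + refractory)

    def forward_fixed(self, real, imag, threshold, refractory):
        # Proposed call: threshold and refractory are passed separately.
        return self.spike(real, imag, threshold, refractory)


neuron = ToyNeuron()

try:
    neuron.forward_buggy(1.5, 0.0, 1.0, 0.2)
    error_message = None
except TypeError as err:
    error_message = str(err)

fixed_result = neuron.forward_fixed(1.5, 0.0, 1.0, 0.2)

print(error_message)
print(fixed_result)
```

With the arguments passed separately, the call goes through. Whether the real spike() combines the two values in the same way is not something this sketch claims.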

To reproduce current behavior

Steps to reproduce the behavior:

  1. When I run this code...
    import lava.lib.dl.slayer as slayer
    import torch

    # choose arbitrary neuron parameters
    my_neuron = slayer.neuron.adrf.Neuron(
        threshold=1,
        threshold_step=0.2,
        period=4,
        decay=0.3,
        threshold_decay=0.3,
        refractory_decay=0.3,
    )

    # some random, complex input of shape ((N,C,T),(N,C,T))
    input = (torch.randn((16, 10, 50)), torch.randn((16, 10, 50)))
    output = my_neuron(input)
  2. I get this error ...
    File "/home/ujrrg/miniconda3/envs/lava/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
      return forward_call(*args, **kwargs)
    File "/home/ujrrg/miniconda3/envs/lava/lib/python3.9/site-packages/lava/lib/dl/slayer/neuron/adrf.py", line 677, in forward
      return self.spike(real, imag, threshold + refractory)
    TypeError: spike() missing 1 required positional argument: 'refractory'

PhilippPlank commented 4 months ago

Great catch. Feel free to create a pull request for this fix :)