lava-nc / lava

A Software Framework for Neuromorphic Computing
https://lava-nc.org

Cannot turn off current decay in LIF neuron #789

Closed stevenabreu7 closed 9 months ago

stevenabreu7 commented 9 months ago

Describe the bug

I cannot run a standard LIF neuron without current decay. Setting the current decay to 1.0 does not disable it, as would be expected from the documentation and the code.

To reproduce current behavior

When I run this code:

data = np.zeros((1, 100))
data[0, 50] = 1

ring_buffer = RingBuffer(data=data)
dense = Dense(weights=np.ones((1, 1)) * 0.5)

du = 1.0
lif = LIF(shape=(1,), du=du, dv=0.4, vth=0.1)

ring_buffer.s_out.connect(dense.s_in)
dense.a_out.connect(lif.a_in)

the membrane potential never increases; see the image below: _lava_issue_du1_0

When setting the current decay to a value very close to 1.0, e.g. 0.999, it does increase the membrane potential: _lava_issue_du=0_999

However, another issue is that the membrane potential gets different values for different values of current decay du, even if the current has the same trace. See below for du=0.9999: _lava_issue_du0_9999

Expected behavior

I would expect that one can turn off current decay completely by simply setting du=1.0, and then the current will be injected for only one timestep, with no leak/decay over future timesteps. The equations in the docstrings of the LIF process also describe such behavior:

    u[t] = u[t-1] * (1-du) + a_in         # neuron current
    v[t] = v[t-1] * (1-dv) + u[t] + bias  # neuron voltage
    s_out = v[t] > vth                    # spike if threshold is exceeded
    v[t] = 0                              # reset at spike

Moreover, I would expect the membrane potential to get updated to the same value given the same current traces.
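
For reference, the docstring equations quoted above can be traced by hand in plain NumPy. This is only a floating-point sketch of those four lines, not Lava's implementation: `simulate_lif` is a hypothetical helper, and the input current is fed in directly rather than through `RingBuffer` and `Dense`, so step timing may be offset from a real Lava run.

```python
import numpy as np


def simulate_lif(a_in, du, dv, vth, bias=0.0):
    """Iterate the LIF docstring equations over a 1-D input trace.

    Hypothetical helper for illustration only; it is not part of Lava.
    Returns the current trace, the post-reset voltage trace, and the spikes.
    """
    n = len(a_in)
    u = np.zeros(n)
    v = np.zeros(n)
    s = np.zeros(n, dtype=bool)
    u_prev = 0.0
    v_prev = 0.0
    for t in range(n):
        u_t = u_prev * (1 - du) + a_in[t]      # neuron current
        v_t = v_prev * (1 - dv) + u_t + bias   # neuron voltage
        spiked = v_t > vth                     # spike if threshold is exceeded
        if spiked:
            v_t = 0.0                          # reset at spike
        u[t], v[t], s[t] = u_t, v_t, spiked
        u_prev, v_prev = u_t, v_t
    return u, v, s


# Same numbers as in the report: one input of 0.5 (spike times weight) at step 50.
a_in = np.zeros(100)
a_in[50] = 0.5

u, v, s = simulate_lif(a_in, du=1.0, dv=0.4, vth=0.1)
print("du=1.0:  u[50:53] =", u[50:53], " v[50:53] =", v[50:53], " spike at 50:", s[50])

u2, v2, s2 = simulate_lif(a_in, du=0.999, dv=0.4, vth=0.1)
print("du=0.999: u[50:53] =", u2[50:53], " v[50:53] =", v2[50:53])
```

In this sketch, `du=1.0` gives a current that lasts exactly one step, as expected. Note that the 0.5 input already exceeds `vth=0.1` on that step, so the sketch spikes and resets immediately and the post-reset voltage trace stays at zero; with `du=0.999`, a small residual current remains on the following steps and shows up in the voltage. Whether this accounts for the plots above has not been verified against Lava itself.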


Full script to reproduce (simply change the value of du)

from lava.proc.monitor.process import Monitor
from lava.magma.core.run_conditions import RunSteps
from lava.magma.core.run_configs import Loihi2SimCfg, Loihi1SimCfg, Loihi2HwCfg, Loihi1HwCfg
from lava.proc.io.source import RingBuffer
from lava.proc.lif.process import LIF
from lava.proc.dense.process import Dense
import matplotlib.pyplot as plt
import numpy as np

data = np.zeros((1, 100))
data[0, 50] = 1

ring_buffer = RingBuffer(data=data)
dense = Dense(weights=np.ones((1, 1)) * 0.5)

du = 1.0 - 1e-4
# du = 1.0
lif = LIF(shape=(1,), du=du, dv=0.4, vth=0.1)

ring_buffer.s_out.connect(dense.s_in)
dense.a_out.connect(lif.a_in)

n_steps = 100

mon_volt = Monitor()
mon_curr = Monitor()
mon_inp = Monitor()
mon_volt.probe(lif.v, n_steps)
mon_curr.probe(lif.u, n_steps)
mon_inp.probe(ring_buffer.s_out, n_steps)
# lif.run(condition=RunSteps(num_steps=n_steps), run_cfg=Loihi1SimCfg())
lif.run(condition=RunSteps(num_steps=n_steps), run_cfg=Loihi2SimCfg())

fig, axs = plt.subplots(3, 1, figsize=(10, 7), sharex=True)
fig.suptitle(f'LIF neuron with du={du}')
mon_volt.plot(axs[0], lif.v)
mon_curr.plot(axs[1], lif.u)
mon_inp.plot(axs[2], ring_buffer.s_out, label='ring buffer output')
axs[2].eventplot(np.argwhere(data.reshape(-1) > 0).reshape(-1), color='r', label='data')
axs[2].legend()
axs[2].set_xticks(np.arange(0, n_steps, 1), minor=True)
axs[0].set_title('LIF voltage')
axs[1].set_title('LIF current')
axs[2].set_title('Input spikes')
plt.tight_layout()
plt.show()