tcstewar opened this issue 8 years ago
A student just asked me about quadratic integrate-and-fire (QIF) neurons, so here's a quick implementation. We might want to include this somewhere as an example.
```python
import numpy as np
import nengo


class QIF(nengo.neurons.NeuronType):
    def __init__(self, threshold=1, reset=0):
        super(QIF, self).__init__()
        self.threshold = threshold
        self.reset = reset

    def step_math(self, dt, J, spiked, voltage):
        # the actual neuron model
        # (note: here dt only scales the input term, not voltage**2)
        voltage += voltage**2 + np.maximum(J, 0) * dt
        spikes = voltage > self.threshold
        spiked[:] = spikes
        voltage[spikes] = self.reset

    def rates(self, x, gain, bias):
        # run the neurons for a bit to estimate firing rates
        J = self.current(x, gain, bias)
        voltage = np.zeros_like(J)
        return nengo.neurons.settled_firingrate(self.step_math,
                                                J, [voltage],
                                                settle_time=0.001, sim_time=1.0)


# connect the neuron model to the nengo builder system
@nengo.builder.Builder.register(QIF)
def build_qif(model, qif, neurons):
    model.sig[neurons]['voltage'] = nengo.builder.Signal(
        np.zeros(neurons.size_in), name="%s.voltage" % neurons)
    model.add_op(nengo.builder.neurons.SimNeurons(
        neurons=qif,
        J=model.sig[neurons]['in'],
        output=model.sig[neurons]['out'],
        states=[model.sig[neurons]['voltage']]))
```
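For reference, here is the model in equation form, written in the dimensionless form that the corrected code later in this thread uses. The notation is introduced here for illustration: v is the voltage, J the input current, and v_th, v_r and Δt stand for `threshold`, `reset` and `dt`. (The first implementation above additionally rectifies the input with `np.maximum(J, 0)`.)

```latex
\begin{aligned}
\frac{dv}{dt} &= v^2 + J, \qquad \text{spike and reset } v \leftarrow v_{\mathrm{r}} \text{ whenever } v > v_{\mathrm{th}} \\
v_{k+1} &= v_k + \left(v_k^2 + J\right)\Delta t \qquad \text{(forward Euler with step } \Delta t\text{)}
\end{aligned}
```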
Here's an example of using it:

```python
model = nengo.Network()
with model:
    a = nengo.Ensemble(n_neurons=50, dimensions=1,
                       gain=nengo.dists.Uniform(1, 1),
                       bias=nengo.dists.Uniform(-0.5, 1.5),
                       neuron_type=QIF(threshold=1, reset=0))
    stim = nengo.Node(lambda t: np.sin(2*np.pi*t))
    nengo.Connection(stim, a)
    p = nengo.Probe(a, synapse=0.03)
    p_spikes = nengo.Probe(a.neurons)

sim = nengo.Simulator(model)
sim.run(2)
```
and the resulting decodes and spikes:
Seems like a good candidate for nengo_extras, unless we see this as important enough for Nengo core.
It is a pretty common neuron, so I could see an argument for putting it in core.
I have a few things to note, with reference to the paper "Neural dynamics, bifurcations and firing rates in a quadratic integrate-and-fire model with a recovery variable. I: deterministic behavior" (Shlizerman and Holmes, 2011):

- The reset should be `-0.1` to follow Figure 1.
- When `J < 0` there are two fixed points (`v = +/- sqrt(-J)`). This means the voltage will either stabilize at `-sqrt(-J)`, or, since the other fixed point is unstable, it may spike once more before finally stabilizing at `-sqrt(-J)`. I also clamp the voltage from below at the reset value (`voltage[voltage < self.reset] = self.reset`). See Figure 2 for a nice summary.
- The `v**2` term is part of the derivative. That is, the delta should be `(voltage**2 + J) * dt` instead of `voltage**2 + (J * dt)`.
- The firing rate has a closed-form solution when `J > 0` (which is fine under a constant-input assumption, given the bifurcation analysis above).
- `dt` should be reduced by at least a factor of `10`, otherwise spikes are under-counted. My example below lowers the firing rates instead, so that it can keep the same `dt`.
- `voltage` can be made probeable.
- Spikes should be scaled by `1/dt` to be consistent with all other neuron models (this keeps the area of the output pulse constant under varying `dt`).
- An `amplitude` parameter should also be included if this is to be added to Nengo.

```python
import numpy as np
import nengo


class QIF(nengo.neurons.NeuronType):
    probeable = ('spikes', 'voltage')

    def __init__(self, threshold=1, reset=-0.1):
        super(QIF, self).__init__()
        self.threshold = threshold
        self.reset = reset

    def step_math(self, dt, J, spiked, voltage):
        # forward Euler: the whole derivative (v**2 + J) is scaled by dt
        voltage += (voltage**2 + J) * dt
        spikes = voltage > self.threshold
        # scale spikes by 1/dt so each output pulse has unit area
        spiked[:] = spikes / dt
        voltage[spikes] = self.reset
        # clamp the voltage from below at the reset value
        voltage[voltage < self.reset] = self.reset

    def rates(self, x, gain, bias):
        # closed-form firing rate for J > 0; the rate is left at 0 elsewhere
        J = self.current(x, gain, bias)
        r = np.zeros_like(J)
        Jmask = J > 0
        sqrtJ = np.sqrt(J[Jmask])
        r[Jmask] = sqrtJ / (np.arctan(self.threshold / sqrtJ) -
                            np.arctan(self.reset / sqrtJ))
        return r


# connect the neuron model to the nengo builder system
@nengo.builder.Builder.register(QIF)
def build_qif(model, qif, neurons):
    # initialize the voltage vector to the reset value
    model.sig[neurons]['voltage'] = nengo.builder.Signal(
        qif.reset * np.ones(neurons.size_in), name="%s.voltage" % neurons)
    model.add_op(nengo.builder.neurons.SimNeurons(
        neurons=qif,
        J=model.sig[neurons]['in'],
        output=model.sig[neurons]['out'],
        states=[model.sig[neurons]['voltage']]))
```
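The closed form in `rates` can be derived by holding J constant and positive and integrating dv/dt = v^2 + J from the reset value v_r to the threshold v_th (notation introduced here for illustration). The time taken is one inter-spike interval T, and the rate is its inverse:

```latex
\begin{aligned}
T &= \int_{v_{\mathrm{r}}}^{v_{\mathrm{th}}} \frac{dv}{v^2 + J}
   = \frac{1}{\sqrt{J}} \left[ \arctan\frac{v_{\mathrm{th}}}{\sqrt{J}} - \arctan\frac{v_{\mathrm{r}}}{\sqrt{J}} \right], \\
r &= \frac{1}{T}
   = \frac{\sqrt{J}}{\arctan\left(v_{\mathrm{th}}/\sqrt{J}\right) - \arctan\left(v_{\mathrm{r}}/\sqrt{J}\right)}, \qquad J > 0.
\end{aligned}
```

This is the idealized continuous-time rate; it ignores both the Euler discretization and the clamp at reset. For J <= 0 with a non-positive reset such as -0.1, the voltage instead moves toward a fixed point at or below zero and never reaches threshold, so the rate is 0, which is what the `Jmask` branch returns.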
```python
import matplotlib.pyplot as plt

tau_probe = 0.03

with nengo.Network() as model:
    stim = nengo.Node(lambda t: np.sin(2*np.pi*t))
    a = nengo.Ensemble(n_neurons=50, dimensions=1, neuron_type=QIF(),
                       max_rates=nengo.dists.Uniform(50, 100))
    nengo.Connection(stim, a, synapse=None)
    p = nengo.Probe(a, synapse=tau_probe)
    p_ideal = nengo.Probe(stim, synapse=tau_probe)
    p_spikes = nengo.Probe(a.neurons, 'spikes')
    p_voltage = nengo.Probe(a.neurons, 'voltage')

with nengo.Simulator(model, dt=1e-3) as sim:
    sim.run(2)

plt.figure()
plt.title("QIF() Communication Channel")
plt.plot(sim.trange(), sim.data[p], label="Actual")
plt.plot(sim.trange(), sim.data[p_ideal], linestyle='--', label="Ideal")
plt.xlabel("Time (s)")
plt.ylabel("Decoded")
plt.legend()
plt.show()
```
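As a standalone check of the closed form against the discretized dynamics, here is a plain-Python sketch (the helpers `qif_rate` and `simulate_qif_rate` are hypothetical names, not part of Nengo). It simulates a single neuron with the `(v**2 + J) * dt` Euler update, the reset to -0.1 and the clamp, using a constant input J = 60 (chosen only for illustration; its closed-form rate is about 55 Hz), and estimates the rate from the inter-spike interval. Because spikes can only be registered on step boundaries, the overshoot past threshold is discarded at reset, and forward Euler lags the true trajectory as v grows, the `dt = 1e-3` run comes out a few percent below the closed form for these parameters, while `dt = 1e-4` lands within 1% of it.

```python
import math


def qif_rate(J, threshold=1.0, reset=-0.1):
    """Closed-form QIF firing rate (Hz) for a constant input J (0 if J <= 0)."""
    if J <= 0:
        return 0.0
    s = math.sqrt(J)
    return s / (math.atan(threshold / s) - math.atan(reset / s))


def simulate_qif_rate(J, dt, t_end=1.0, threshold=1.0, reset=-0.1):
    """Firing rate (Hz) of one QIF neuron with constant input J, estimated
    from the mean inter-spike interval of a forward-Euler simulation."""
    v = reset
    spike_steps = []
    for k in range(int(round(t_end / dt))):
        v += (v**2 + J) * dt       # forward Euler on dv/dt = v**2 + J
        if v > threshold:
            spike_steps.append(k)  # record the step index of the spike
            v = reset
        v = max(v, reset)          # clamp from below at the reset value
    if len(spike_steps) < 2:
        return 0.0                 # too few spikes to measure an interval
    mean_isi_steps = (spike_steps[-1] - spike_steps[0]) / (len(spike_steps) - 1)
    return 1.0 / (mean_isi_steps * dt)


J = 60.0
exact = qif_rate(J)
for dt in (1e-3, 1e-4):
    approx = simulate_qif_rate(J, dt)
    print(f"dt={dt:g}: simulated {approx:.2f} Hz, closed form {exact:.2f} Hz")
```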
However, without modelling any refractory period or recovery dynamics, we essentially end up with linear tuning curves (a spiking ReLU model). So either refractory/recovery dynamics should be included, or the default threshold/reset values need to be modified to obtain less linear curves.
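The near-linearity follows from the closed-form rate: for large J both arctan arguments are small and arctan(x) ≈ x, so the inter-spike interval is about (threshold - reset) / J and the rate is about J / (threshold - reset), i.e. a rectified-linear curve. The square-root behaviour r ≈ sqrt(J) / π only appears when J is much smaller than both threshold² and reset², far below the currents needed for the 50 to 100 Hz rates used above. A NumPy sketch of this, reusing the expression from the `rates` method with the default threshold = 1 and reset = -0.1 (`qif_rate` is a hypothetical helper name):

```python
import numpy as np


def qif_rate(J, threshold=1.0, reset=-0.1):
    """Closed-form QIF rate (Hz), same expression as QIF.rates; 0 where J <= 0."""
    J = np.asarray(J, dtype=float)
    r = np.zeros_like(J)
    mask = J > 0
    sqrtJ = np.sqrt(J[mask])
    r[mask] = sqrtJ / (np.arctan(threshold / sqrtJ) - np.arctan(reset / sqrtJ))
    return r


threshold, reset = 1.0, -0.1

# operating range: currents giving roughly 9 to 180 Hz
J = np.linspace(10, 200, 20)
r = qif_rate(J, threshold, reset)
r_linear = J / (threshold - reset)  # large-J (rectified-linear) approximation
rel_dev = np.abs(r - r_linear) / r  # relative deviation, shrinks as J grows
print(np.round(rel_dev, 4))

# tiny currents: here the rate approaches sqrt(J) / pi instead
J_small = np.array([1e-6, 1e-5])
ratio = qif_rate(J_small, threshold, reset) / (np.sqrt(J_small) / np.pi)
print(ratio)
```

With these defaults the deviation from the linear approximation is already under 3% at J = 10 and keeps shrinking as J grows.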
```python
u = np.linspace(-1, 1)
plt.figure()
plt.title("QIF() Tuning Curves")
plt.plot(u, nengo.builder.ensemble.get_activities(sim.data[a], a, u[:, None]))
plt.xlabel("x")
plt.ylabel("Firing Rate (Hz)")
plt.show()
```
This has come up a couple times, so I thought I'd post it here. I'm not sure if QIF should just be added to Nengo itself, or if it should be used as an example of how to define your own neuron model.