nengo / nengo-extras

Extra utilities and add-ons for Nengo
https://www.nengo.ai/nengo-extras
Other
5 stars 8 forks source link

Add QIF neuron #92

Open tcstewar opened 8 years ago

tcstewar commented 8 years ago

This has come up a couple of times, so I thought I'd post it here. I'm not sure whether QIF should just be added to Nengo itself, or whether it should be used as an example of how to define your own neuron model.

tcstewar commented 6 years ago

A student just asked me about QIF neurons, so here's a quick implementation. We might want to include this somewhere as an example.

class QIF(nengo.neurons.NeuronType):
    """Quadratic integrate-and-fire (QIF) spiking neuron.

    The membrane voltage follows ``dv/dt = v**2 + max(J, 0)`` and is
    integrated with a forward-Euler step. When the voltage exceeds
    ``threshold`` the neuron spikes and its voltage is set to ``reset``.

    Parameters
    ----------
    threshold : float
        Voltage above which a spike is emitted.
    reset : float
        Voltage assigned to a neuron immediately after it spikes.
    """

    def __init__(self, threshold=1, reset=0):
        super(QIF, self).__init__()
        self.threshold = threshold
        self.reset = reset

    def step_math(self, dt, J, spiked, voltage):
        """Advance the neurons by one timestep, updating state in place.

        Parameters
        ----------
        dt : float
            Simulation timestep.
        J : ndarray
            Input current for each neuron; negative currents are clipped
            to zero before integration.
        spiked : ndarray
            Output buffer; set to ``1 / dt`` for neurons that spiked on
            this step and ``0`` otherwise.
        voltage : ndarray
            Membrane voltage state, modified in place.
        """
        # The whole derivative (v**2 + J) must be scaled by dt. Leaving
        # v**2 outside the product makes the update depend on the step
        # size and no longer an Euler step of dv/dt = v**2 + J.
        voltage += (voltage**2 + np.maximum(J, 0)) * dt
        spikes = voltage > self.threshold
        # Spikes are emitted with amplitude 1/dt (Nengo's convention for
        # spiking neuron output), so averaging the output over time gives
        # a rate in Hz.
        spiked[:] = spikes / dt
        voltage[spikes] = self.reset

    def rates(self, x, gain, bias):
        """Estimate firing rates by running the neurons for a bit.

        The input current is computed from ``x``, ``gain`` and ``bias``,
        and the neuron model is then stepped from a zero initial voltage
        using ``nengo.neurons.settled_firingrate``.
        """
        J = self.current(x, gain, bias)
        voltage = np.zeros_like(J)
        return nengo.neurons.settled_firingrate(
            self.step_math, J, [voltage],
            settle_time=0.001, sim_time=1.0)

# connect the neuron model to the nengo builder system
@nengo.builder.Builder.register(QIF)
def build_qif(model, qif, neurons):
    """Builder hook for ``QIF``: create the voltage state and add the op.

    A ``voltage`` signal with one zero-initialised entry per neuron is
    stored under ``model.sig[neurons]``, and a ``SimNeurons`` operator is
    added that reads the ``in`` signal, writes the ``out`` signal, and
    carries the voltage as its only state.
    """
    signals = model.sig[neurons]

    # one voltage entry per neuron, starting at zero
    initial_voltage = np.zeros(neurons.size_in)
    signals['voltage'] = nengo.builder.Signal(
        initial_voltage, name="%s.voltage" % neurons)

    sim_op = nengo.builder.neurons.SimNeurons(
        neurons=qif,
        J=signals['in'],
        output=signals['out'],
        states=[signals['voltage']])
    model.add_op(sim_op)

Here's an example of using it:

# A 50-neuron QIF ensemble representing a 1 Hz sine wave. Gains are all 1
# and biases are drawn uniformly from [-0.5, 1.5].
model = nengo.Network()
with model:
    a = nengo.Ensemble(n_neurons=50, dimensions=1,
                       gain=nengo.dists.Uniform(1, 1),
                       bias=nengo.dists.Uniform(-0.5, 1.5),
                       neuron_type=QIF(threshold=1, reset=0))
    stim = nengo.Node(lambda t: np.sin(2*np.pi*t))
    nengo.Connection(stim, a)

    # decoded value (filtered with a 30 ms synapse) and raw neuron output
    p = nengo.Probe(a, synapse=0.03)
    p_spikes = nengo.Probe(a.neurons)

# Run the simulator as a context manager so it is closed when the run
# finishes; an unclosed Simulator holds on to its resources. Probed data
# is still read from ``sim.data`` afterwards.
with nengo.Simulator(model) as sim:
    sim.run(2)

and the resulting decodes and spikes:

image

image

tbekolay commented 6 years ago

Seems like a good candidate for nengo_extras, unless we see this as important enough for Nengo core.

celiasmith commented 6 years ago

It is a pretty common neuron, so I could see an argument for putting it in core.

arvoelke commented 6 years ago

I have a few things to note, with reference to the paper "Neural dynamics, bifurcations and firing rates in a quadratic integrate-and-fire model with a recovery variable. I: deterministic behavior" (Shlizerman and Holmes, 2011).

class QIF(nengo.neurons.NeuronType):
    """Quadratic integrate-and-fire neuron with an analytic rate curve.

    The voltage is integrated as ``dv/dt = v**2 + J``. Crossing
    ``threshold`` produces a spike and sends the voltage back to
    ``reset``; the voltage is also never allowed to fall below ``reset``.

    Parameters
    ----------
    threshold : float
        Voltage above which a spike is emitted.
    reset : float
        Post-spike voltage, which is also the lower bound on the voltage.
    """

    probeable = ('spikes', 'voltage')

    def __init__(self, threshold=1, reset=-0.1):
        super(QIF, self).__init__()
        self.reset = reset
        self.threshold = threshold

    def step_math(self, dt, J, spiked, voltage):
        """Take one Euler step, writing spikes and voltage in place."""
        dv = dt * (J + voltage**2)
        voltage += dv

        fired = voltage > self.threshold
        spiked[:] = fired / dt

        # Neurons that fired go back to the reset voltage, and any voltage
        # that dropped below the reset level is clamped to it.
        needs_reset = fired | (voltage < self.reset)
        voltage[needs_reset] = self.reset

    def rates(self, x, gain, bias):
        """Return steady-state firing rates for input ``x``.

        Neurons with non-positive current get a rate of zero. For
        positive current the rate is the closed-form expression
        ``sqrt(J) / (arctan(threshold / sqrt(J)) - arctan(reset / sqrt(J)))``.
        """
        J = self.current(x, gain, bias)
        out = np.zeros_like(J)

        active = J > 0
        root = np.sqrt(J[active])
        upper = np.arctan(self.threshold / root)
        lower = np.arctan(self.reset / root)
        out[active] = root / (upper - lower)
        return out

# connect the neuron model to the nengo builder system
@nengo.builder.Builder.register(QIF)
def build_qif(model, qif, neurons):
    """Builder hook for ``QIF``: create the voltage state and add the op.

    The ``voltage`` signal has one entry per neuron, each starting at the
    neuron type's reset value. A ``SimNeurons`` operator is then added
    that reads the ``in`` signal, writes the ``out`` signal, and carries
    the voltage as its only state.
    """
    signals = model.sig[neurons]

    # every neuron starts at the reset voltage
    initial_voltage = np.ones(neurons.size_in) * qif.reset
    signals['voltage'] = nengo.builder.Signal(
        initial_voltage, name="%s.voltage" % neurons)

    sim_op = nengo.builder.neurons.SimNeurons(
        neurons=qif,
        J=signals['in'],
        output=signals['out'],
        states=[signals['voltage']])
    model.add_op(sim_op)
# Synapse time constant shared by the decoded and ideal probes, so the
# two traces are filtered identically.
tau_probe = 0.03

model = nengo.Network()
with model:
    # 1 Hz sine-wave stimulus
    stim = nengo.Node(lambda t: np.sin(2*np.pi*t))
    a = nengo.Ensemble(
        n_neurons=50,
        dimensions=1,
        neuron_type=QIF(),
        max_rates=nengo.dists.Uniform(50, 100),
    )
    # unfiltered connection from the stimulus into the ensemble
    nengo.Connection(stim, a, synapse=None)

    p = nengo.Probe(a, synapse=tau_probe)
    p_ideal = nengo.Probe(stim, synapse=tau_probe)
    p_spikes = nengo.Probe(a.neurons, 'spikes')
    p_voltage = nengo.Probe(a.neurons, 'voltage')

# simulate two seconds at a 1 ms timestep
with nengo.Simulator(model, dt=1e-3) as sim:
    sim.run(2)

# Compare the decoded ensemble output against the identically filtered
# stimulus.
plt.figure()
plt.title("QIF() Communication Channel")
plt.plot(sim.trange(), sim.data[p], label="Actual")
plt.plot(sim.trange(), sim.data[p_ideal], linestyle='--', label="Ideal")
plt.xlabel("Time (s)")
plt.ylabel("Decoded")
# Without a legend the "Actual"/"Ideal" labels are never drawn.
plt.legend()
plt.show()

simulation

However, in the end, without modelling any refractory period or recovery dynamics, we basically end up with linear tuning curves (a spiking ReLU model). So either refractory/recovery dynamics should be included, or the default threshold/reset values need to be modified to obtain less linear curves.

# Evaluate each neuron's activity over the represented range [-1, 1] and
# plot one curve per neuron.
u = np.linspace(-1, 1)
eval_points = u[:, None]  # column vector: one 1-D evaluation point per row
activities = nengo.builder.ensemble.get_activities(
    sim.data[a], a, eval_points)

plt.figure()
plt.title("QIF() Tuning Curves")
plt.plot(u, activities)
plt.xlabel("x")
plt.ylabel("Firing Rate (Hz)")
plt.show()

curves