nengo / nengo

A Python library for creating and simulating large-scale brain models
https://www.nengo.ai/nengo

Handling tau_ref < dt for LIF (a.k.a. multiple spikes per time-step) #1427

Open arvoelke opened 6 years ago

arvoelke commented 6 years ago

As mentioned in #975, the current implementation of LIF assumes that dt <= tau_ref. Violating this assumption produces a drastic difference between the LIF and LIFRate model activities. With tau_ref=0 and everything else left at its default value (i.e., dt = 0.001), the LIF model fires on average 15 Hz slower than LIFRate!

import matplotlib.pyplot as plt
import nengo
import numpy as np

tau = 0.1
seed = 0

with nengo.Network() as model:
    u = nengo.Node(output=lambda t: 1)
    x_spike = nengo.Ensemble(500, 1, neuron_type=nengo.LIF(tau_ref=0), seed=seed)
    x_rate = nengo.Ensemble(500, 1, neuron_type=nengo.LIFRate(tau_ref=0), seed=seed)
    nengo.Connection(u, x_spike, synapse=None)
    nengo.Connection(u, x_rate, synapse=None)
    p_x_spike_a = nengo.Probe(x_spike.neurons, synapse=tau)
    p_x_rate_a = nengo.Probe(x_rate.neurons, synapse=tau)

with nengo.Simulator(model) as sim:
    sim.run(1.0)

plt.figure()
plt.title("LIF - LIFRate")
plt.plot(sim.trange(), np.mean(sim.data[p_x_spike_a] - sim.data[p_x_rate_a], axis=1))
plt.xlabel("Time (s)")
plt.ylabel("Mean Error [Hz]")
plt.show()

[Figure: LIF - LIFRate, mean error (Hz) over time]

This occurs because the model stops processing the input after a spike is discovered. When tau_ref >= dt this is okay, because the input won't have any more effect during that time-step. But when tau_ref < dt, the input can still affect the voltage after the first spike, and in general the neuron could even spike multiple times. This leads to a systematic under-counting of the number of spikes, which tends to bias the representation toward under-approximating the desired function. This can have pretty serious consequences in the context of Principle 3. For instance, the 2D oscillator example that comes with Nengo will always flat-line quite rapidly (after fixing the transform and setting tau_ref=0):

[Figure: 2D oscillator example with tau_ref=0]

Note that increasing the number of neurons does not help this kind of problem!
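To make the under-counting concrete, here is a minimal, dependency-free sketch of a single neuron driven by a constant input. It assumes the usual normalized LIF (threshold 1, tau_rc=0.02) and an update patterned on the behaviour described above: reset to 0 after the first spike and ignore the rest of the step. It is not Nengo's actual implementation, and `lif_rate` and `count_spikes_one_per_step` are hypothetical helpers for illustration.

```python
import math


def lif_rate(J, tau_rc=0.02, tau_ref=0.0):
    """LIFRate firing rate (Hz) for a constant normalized input current J."""
    if J <= 1:
        return 0.0
    return 1.0 / (tau_ref + tau_rc * math.log1p(1.0 / (J - 1)))


def count_spikes_one_per_step(J, dt, t_end, tau_rc=0.02, tau_ref=0.0):
    """Count spikes from an LIF update that allows at most one spike per step.

    Hypothetical sketch: after a spike the voltage is reset to 0 and the rest
    of the step is discarded, which under-counts spikes when tau_ref < dt.
    """
    v, refractory, n_spikes = 0.0, 0.0, 0
    for _ in range(int(round(t_end / dt))):
        refractory -= dt
        # portion of this step during which the neuron is not refractory
        delta_t = min(max(dt - refractory, 0.0), dt)
        # zero-order-hold update of the membrane voltage
        v -= (J - v) * math.expm1(-delta_t / tau_rc)
        if v > 1:
            n_spikes += 1
            # spike time, measured from the start of the step
            t_spike = dt + tau_rc * math.log1p(-(v - 1) / (J - 1))
            v = 0.0
            refractory = tau_ref + t_spike
    return n_spikes


J, dt = 6.5, 0.001
print(count_spikes_one_per_step(J, dt, t_end=1.0))  # 250
print(lif_rate(J))  # roughly 299.3 Hz
```

Because each reset throws away the remainder of the step, the effective interspike interval is rounded up to a whole number of steps (4 ms instead of about 3.34 ms here), so this neuron emits 250 spikes in one second instead of roughly 299. Adding more neurons does not remove this bias, since every neuron is rounded in the same direction.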

My suggestion is to issue a warning if tau_ref < dt for the LIF model. It should be okay for LIFRate.

Another possibility is to extend the LIF model to allow multiple spikes within a given time-step. This would be consistent with the new SpikingRectifiedLinear model. I think it would be valuable for someone to try this and benchmark it (I'm worried about the speed of an inner for-loop in Python)!

arvoelke commented 6 years ago

Actually, there is a way to do it without a for loop. I also mentioned this in #975 with my "O(1) additional processing" comment but failed to remember or elaborate. Basically, it should be possible to count the total number of spikes without explicitly looping through them all. The LIFRate equation should help here. The awkward complications come from properly handling everything before the first spike and after the last spike. I'll let this idea sit and if nobody bites I'll try it out.
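A minimal sketch of the counting idea, assuming the input J is held constant over the step and the neuron has just spiked with `t_left` seconds of the step remaining: every subsequent spike costs exactly one LIFRate interspike interval, so the number of spikes follows from a single `floor` instead of a loop. Both counting functions are hypothetical illustrations, not Nengo code.

```python
import math


def isi(J, tau_rc=0.02, tau_ref=0.0):
    """Interspike interval (1 / LIFRate rate) for a constant input J > 1."""
    return tau_ref + tau_rc * math.log1p(1.0 / (J - 1))


def spikes_by_loop(t_left, interval):
    """Count the first spike plus each extra spike by stepping through them."""
    n = 1
    while t_left >= interval:
        t_left -= interval
        n += 1
    return n


def spikes_closed_form(t_left, interval):
    """Same count in O(1): the first spike plus floor(t_left / interval)."""
    return 1 + math.floor(t_left / interval)


interval = isi(6.5)
for t_left in (0.0005, 0.0066589, 0.05):
    print(t_left, spikes_by_loop(t_left, interval), spikes_closed_form(t_left, interval))
```

The loop's cost grows with the number of spikes in the step (which becomes large for big dt), whereas the closed form stays constant. What remains is the bookkeeping before the first spike and after the last one.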

arvoelke commented 6 years ago

I've gone ahead and implemented a version of the LIF model that can spike multiple times per time-step. I've also tested it across a generous range of overlapping tau_ref and dt values. In all cases, the total spike count differs from the count expected from LIFRate by an absolute error of less than 1, at every point in time. In other words, it never under- or over-counts by even a single spike, even for dt as large as 100 ms!

[Figure: spike count error versus LIFRate]

The model is quite a bit more complicated now. To help understand it, refer to the following diagram:

[Figure: MultiLIF timing diagram showing t_left(a) and t_left(b)]

I've also tried my best to document the code thoroughly. Note that t_left(a) in the diagram refers to the initial value of t_left (before the update), and t_left(b) refers to the updated value of t_left (after subtracting the time taken by all the extra spikes and one additional refractory period). The most complicated part comes from the fact that t_left(b) can be either positive or negative. If it's negative, then the refractory period extends into the next time-step (as is usually the case for LIF, and as depicted in the diagram). However, it can now also stay positive (not shown), in which case the membrane voltage must be updated once again (from 0, to whatever it becomes after the positive t_left elapses within the remainder of the time-step).
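The two cases for t_left(b) can be checked numerically with a scalar version of the update for one neuron over one time-step. This is a sketch under the same assumptions as the vectorized code in this comment (normalized LIF with threshold 1, J held constant over the step), plus the extra assumption that the neuron is not refractory at the start of the step. `one_step` is a hypothetical helper that returns the spike count, t_left(b), and the end-of-step voltage.

```python
import math


def one_step(v, J, dt, tau_rc=0.02, tau_ref=0.0):
    """Advance one non-refractory neuron by dt, allowing multiple spikes.

    Returns (n_spikes, t_left_b, v_end). A negative t_left_b means the last
    refractory period extends -t_left_b seconds into the next step.
    """
    v += (v - J) * math.expm1(-dt / tau_rc)  # ZOH update over the full step
    if v <= 1:
        return 0, None, v
    inv_J = 1.0 / (J - 1)
    # t_left(a): time remaining after the first spike
    t_left = -tau_rc * math.log1p((1 - v) * inv_J)
    interval = tau_ref + tau_rc * math.log1p(inv_J)
    n_extra = math.floor(t_left / interval)
    # t_left(b): subtract the extra spikes and one final refractory period
    t_left -= interval * n_extra + tau_ref
    # recharge from 0 for any positive time left (otherwise v_end is 0)
    v_end = -J * math.expm1(-max(t_left, 0.0) / tau_rc)
    return 1 + n_extra, t_left, v_end


# Case 1: tau_ref = 0, so t_left(b) stays positive and the voltage recharges
print(one_step(0.0, J=6.5, dt=0.01, tau_ref=0.0))
# Case 2: tau_ref = 2 ms, so the refractory period spills into the next step
print(one_step(0.0, J=6.5, dt=0.01, tau_ref=0.002))
```

In both cases two spikes occur in the 10 ms step. In the second case the second spike lands at 2 * T + tau_ref (with T the time to charge from 0 to threshold), so its refractory period ends 2 * (T + tau_ref) - dt, roughly 0.68 ms, into the next step, which is exactly -t_left(b).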

TODO before PR:

import nengo
import numpy as np


class MultiLIF(nengo.LIF):

    def step_math(self, dt, J, spiked, voltage, refractory_time):
        # reduce all refractory times by dt
        refractory_time -= dt

        # compute effective dt for each neuron, based on remaining time.
        # note that refractory times that have completed midway into this
        # timestep will be given a partial timestep, and moreover these will
        # be subtracted to zero at the next timestep (or reset by a spike)
        delta_t = (dt - refractory_time).clip(0, dt)

        # update voltage using zero-order hold (ZOH) discretized
        # lowpass filter <=> v(t) = v(0) + (J - v(0))*(1 - exp(-t/tau))
        # by assuming J is held constant across the time-step
        voltage += (voltage - J) * np.expm1(-delta_t / self.tau_rc)

        # determine which neurons spiked
        spiked_mask = voltage > 1

        # set v(0) = 1 and solve for t (using the ZOH equation above)
        # to get the time that has elapsed between the time of spike
        # up until the end of the time-step. in other words, this gives
        # the time remaining after the first spike, for each neuron
        J_spiked = J[spiked_mask]  # reusable term
        inv_J = 1 / (J_spiked - 1)  # reusable term
        t_left = -self.tau_rc * np.log1p((1 - voltage[spiked_mask]) * inv_J)

        # determine the interspike interval (1/rate) for the neurons that
        # spiked, based on LIFRate equation
        isi = self.tau_ref + self.tau_rc * np.log1p(inv_J)

        # compute the number of extra spikes that have also occurred
        # during this time-step, given the remaining time (not including the
        # first spike)
        n_extra = np.floor(t_left / isi)

        # update t_left to account for any extra spikes, and the final
        # refractory period. if it becomes negative this implies the refractory
        # period has extended into the next time-step (by -t_left seconds)
        # otherwise a positive value indicates there is still time remaining!
        t_left -= isi * n_extra + self.tau_ref

        # use any positive time remaining to update the voltage (from zero)
        # via the ZOH equation once again. note that if there is no positive time
        # remaining then the voltage will become 0
        voltage[spiked_mask] = -J_spiked * np.expm1(-t_left.clip(min=0) / self.tau_rc)

        # rectify negative voltages
        voltage[voltage < self.min_voltage] = self.min_voltage

        # additively output amplitude / dt for each spike, including the first
        spiked[:] = 0
        spiked[spiked_mask] = (1 + n_extra) * self.amplitude / dt

        # set refractory time to a full time-step (since dt will be
        # subtracted away on the next call to this function) plus
        # the time that tau_ref extended into the next time-step
        # (-t_left for negative t_left)
        refractory_time[spiked_mask] = dt - t_left.clip(max=0)
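As a dependency-free cross-check of the claim that the total spike count stays within one spike of LIFRate, here is a scalar re-expression of the `MultiLIF.step_math` logic for a single neuron with a constant input. It is a sketch rather than a replacement for the vectorized class: `count_spikes_multi` and `lif_rate` are hypothetical helpers, `tau_rc=0.02` is assumed, and the `min_voltage` rectification is omitted because the input here is positive.

```python
import math


def lif_rate(J, tau_rc=0.02, tau_ref=0.0):
    """LIFRate firing rate (Hz) for a constant normalized input current J."""
    if J <= 1:
        return 0.0
    return 1.0 / (tau_ref + tau_rc * math.log1p(1.0 / (J - 1)))


def count_spikes_multi(J, dt, t_end, tau_rc=0.02, tau_ref=0.0):
    """Scalar version of the MultiLIF update, returning the total spike count."""
    v, refractory, n_spikes = 0.0, 0.0, 0
    for _ in range(int(round(t_end / dt))):
        refractory -= dt
        delta_t = min(max(dt - refractory, 0.0), dt)
        v += (v - J) * math.expm1(-delta_t / tau_rc)
        if v > 1:
            inv_J = 1.0 / (J - 1)
            # t_left(a): time remaining after the first spike
            t_left = -tau_rc * math.log1p((1 - v) * inv_J)
            interval = tau_ref + tau_rc * math.log1p(inv_J)
            n_extra = math.floor(t_left / interval)
            # t_left(b): after the extra spikes and the final refractory period
            t_left -= interval * n_extra + tau_ref
            v = -J * math.expm1(-max(t_left, 0.0) / tau_rc)
            n_spikes += 1 + n_extra
            refractory = dt - min(t_left, 0.0)
    return n_spikes


J = 6.5
for tau_ref in (0.0, 0.0005, 0.002):
    for dt in (0.001, 0.01, 0.1):
        n = count_spikes_multi(J, dt, 1.0, tau_ref=tau_ref)
        print(tau_ref, dt, n, round(lif_rate(J, tau_ref=tau_ref), 3))
```

For a neuron starting at v=0 the spikes fall at k * isi - tau_ref, so after t_end seconds the exact count is floor((t_end + tau_ref) / isi), independent of dt. For example, with tau_ref=0 this gives 299 spikes against a LIFRate rate of roughly 299.3 Hz, whether dt is 1 ms or 100 ms.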

Testing:

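import matplotlib.pyplot as plt
import nengo
import numpy as np

# MultiLIF is the class defined above


def test(tau_ref, dt, seed, n_neurons=100, amplitude=1.0, t=1.0, verbose=False):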
    max_rates = nengo.dists.Uniform(.1/tau_ref, 1/tau_ref) if tau_ref > 0 else nengo.params.Default

    with nengo.Network() as model:
        u = nengo.Node(output=lambda t: 1)

        x_spike = nengo.Ensemble(
            n_neurons, 1, seed=seed, max_rates=max_rates,
            neuron_type=MultiLIF(tau_ref=tau_ref, amplitude=amplitude))
        x_rate = nengo.Ensemble(
            n_neurons, 1, seed=seed, max_rates=max_rates,
            neuron_type=nengo.LIFRate(tau_ref=tau_ref, amplitude=amplitude))

        nengo.Connection(u, x_spike, synapse=None)
        nengo.Connection(u, x_rate, synapse=None)

        p_x_spike_a = nengo.Probe(x_spike.neurons, synapse=None)
        p_x_rate_a = nengo.Probe(x_rate.neurons, synapse=None)

    with nengo.Simulator(model, dt=dt, progress_bar=False) as sim:
        sim.run(t, progress_bar=False)

    error = (np.cumsum(sim.data[p_x_spike_a] - sim.data[p_x_rate_a], axis=0) *
             dt / amplitude)

    if verbose:
        plt.figure()
        plt.title("LIF - LIFRate")
        plt.plot(sim.trange(), error)
        plt.xlabel("Time (s)")
        plt.ylabel("Spike Count Error")
        plt.show()

    # Verify that the error between the spiking model and rate model
    # (in terms of the total number of spikes versus the expected amount)
    # never under-counts or over-counts by a single spike
    assert np.all(np.abs(error) < 1)

for tau_ref in (0, 0.0005, 0.001, 0.002, 0.01):
    for dt in (0.001, 0.002, 0.01, 0.1):
        test(tau_ref, dt, seed=0)

drasmuss commented 6 years ago

Yeah I'd be interested in seeing the benchmarking, to see how this compares to LIF/FastLIF. I think this is very cool and definitely something we want available either way, just curious how much it actually ends up costing in simulation speed.

arvoelke commented 6 years ago

Thanks! I will find time to look into benchmarking/etc later.

I just wanted to make a quick follow-up note that may be of theoretical interest to some. This also leads to an interesting re-characterization of LIFRate as the limit of LIF as dt -> infinity. This is a bit backwards from our usual way of thinking about it, but it makes perfect sense: the model outputs the number of spikes that occur during a time window divided by the length of that window (dt), which converges to the rate for that input as dt -> infinity.
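The limit argument can be checked with the closed-form count, with no simulation involved. Assuming a constant input J, a neuron starting at v=0, and a single multi-spike step of length dt, spike k lands at k * isi - tau_ref, so the step contains floor((dt + tau_ref) / isi) spikes. Dividing by dt gives the model's output, which differs from the LIFRate rate by less than 1/dt. `interspike_interval`, `lif_rate`, and `window_rate` are hypothetical helpers for illustration, and the defaults tau_rc=0.02 and tau_ref=0.002 are assumed.

```python
import math


def interspike_interval(J, tau_rc=0.02, tau_ref=0.002):
    """LIFRate interspike interval (s) for a constant normalized input J > 1."""
    return tau_ref + tau_rc * math.log1p(1.0 / (J - 1))


def lif_rate(J, tau_rc=0.02, tau_ref=0.002):
    """LIFRate firing rate (Hz)."""
    return 1.0 / interspike_interval(J, tau_rc, tau_ref)


def window_rate(J, dt, tau_rc=0.02, tau_ref=0.002):
    """Spikes emitted in one step of length dt (starting at v = 0), divided by dt.

    Spike k lands at k * isi - tau_ref, so the count is floor((dt + tau_ref) / isi).
    """
    isi = interspike_interval(J, tau_rc, tau_ref)
    return math.floor((dt + tau_ref) / isi) / dt


J = 1.5
for dt in (0.01, 0.1, 1.0, 10.0):
    print(f"dt={dt}: {window_rate(J, dt):.3f} Hz vs LIFRate {lif_rate(J):.3f} Hz")
```

The bound follows because the floor is off by less than one spike and the extra tau_ref / isi term is also below one, so the error in Hz shrinks like 1/dt as the window grows.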

A curious way to demonstrate this insight is to sweep the input across [-1, +1] and observe that the spiking output of LIF itself provides the tuning curves (no filtering required)! For the graph below I set dt=0.5 (half a second). Unfortunately, larger values of dt currently lead to numerical issues in the step that determines t_left(a). A rewrite is needed to support even larger dt.

[Figure: spiking output with dt=0.5 alongside the tuning curves]

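import matplotlib.pyplot as plt
import nengo
import numpy as np

# MultiLIF is the class defined earlier in this thread


def spike_tuning(seed, n_neurons=20, n_samples=1000, dt=0.5):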
    with nengo.Network() as model:
        u = nengo.Node(output=lambda t: 2*t/dt/n_samples-1)
        x = nengo.Ensemble(n_neurons, 1, seed=seed, neuron_type=MultiLIF())
        nengo.Connection(u, x, synapse=None)
        p = nengo.Probe(x.neurons, synapse=None)

    with nengo.Simulator(model, dt=dt, progress_bar=False) as sim:
        sim.run(n_samples*dt, progress_bar=False)

    fig, ax = plt.subplots(1, 2, figsize=(14, 7))
    ax[0].set_title("LIF()")
    ax[0].plot(sim.trange(), sim.data[p])
    ax[0].set_xlabel("Time (s)")
    ax[0].set_ylabel("Spiking Output")
    ax[1].set_title("Tuning Curves")
    eval_points = np.linspace(-1, 1, n_samples)
    ax[1].plot(eval_points, nengo.builder.ensemble.get_activities(
        sim.data[x], x, eval_points[:, None]))
    ax[1].set_xlabel("x")
    ax[1].set_ylabel("Firing Rate (Hz)")
    fig.show()

In other words, for large enough dt there isn't much of a difference between using the LIF model and the LIFRate model. This does mean, however, that caution should be taken when interpreting simulation results that use this model with really large dt. It is a bit like having a perfect integrator across each entire time-step. Or, more accurately, it is like delaying all of the spikes to occur at the very end of the window, stacking them linearly on top of one another, and making each one a rectangular pulse of width dt. Although this more accurately reflects the LIFRate model, it does not reflect what would happen for the same network with smaller time-steps, in which the timing of the spikes across the interval can matter.

My current intuition is that for dt on the order of the smallest synaptic time-constant in the network, the accuracy of the network could actually improve, albeit in a manner that could be considered "cheating" since we're introducing an implicit integrator across a longer time-scale than the synapse itself. However, if the time-constant of each synapse is large enough relative to dt, the behaviour of the network should be pretty similar.

arvoelke commented 5 years ago

I was doing some miscellaneous Nengo benchmarking with different dt and noticed a significant change in my results for dt = 1 / 499 versus dt <= 1 / 500 due to this old issue, since 1 / tau_ref == 500. It would be helpful to raise an error or warning for the case when the dt is too large for the particular neuron model to keep up with (i.e., dt > tau_ref in the current LIF implementation).
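A minimal sketch of the kind of check being proposed. `check_lif_timestep` is a hypothetical helper, not part of Nengo's API; it warns when dt exceeds tau_ref, the regime in which an LIF update that emits at most one spike per step drops time after a spike. With the default tau_ref of 2 ms it flags dt = 1/499 but not dt = 1/500.

```python
import warnings


def check_lif_timestep(dt, tau_ref=0.002):
    """Hypothetical check: warn if dt > tau_ref, where a one-spike-per-step
    LIF model can under-count spikes. Returns True when dt is safe."""
    if dt > tau_ref:
        warnings.warn(
            f"dt={dt:g} exceeds tau_ref={tau_ref:g}; an LIF model that emits at "
            "most one spike per time-step may under-count spikes.",
            RuntimeWarning,
        )
        return False
    return True


print(check_lif_timestep(1 / 500))  # True: dt == tau_ref
print(check_lif_timestep(1 / 499))  # False, and a RuntimeWarning is emitted
```

Whether this should be a warning or an error, and whether it should also consider the neuron's actual firing rates, is a design choice left open here.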