nengo / nengo-extras

Extra utilities and add-ons for Nengo
https://www.nengo.ai/nengo-extras

Tools for saving and loading decoders #35

Open tcstewar opened 7 years ago

tcstewar commented 7 years ago

In https://github.com/nengo/nengo/issues/649 and https://github.com/nengo/nengo/issues/608 (and maybe other places), there have been requests for ways to save and load decoders. There are two common use cases:

1) you've computed decoders in some weird way and would like to use those instead

2) you're using a learning rule and would like to start the learning rule off with the decoders from the end point of a previous run.

Use case 1 can usually be handled by doing nengo.Connection(a.neurons, b, transform=decoders), where decoders is laid out as (b.dimensions, a.n_neurons), i.e. the transpose of the (n_neurons, dimensions) array a Solver returns. However, this doesn't work everywhere: for example, nengo_spinnaker doesn't allow Connections from .neurons. Use case 2 is even more problematic -- there's currently no way to seed the start of a learning rule with anything other than the result of some Solver.
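To see why the transform is the transpose of the solver output, here is a numpy-only sketch. It uses random stand-in activities and targets rather than a real Nengo ensemble, and the sizes are arbitrary. Decoding through the solver's decoders gives the same value as applying the transposed matrix to the neuron activity vector.

```python
import numpy as np

rng = np.random.default_rng(0)

n_neurons, dims, n_eval = 50, 2, 200
# A: stand-in activities of the pre neurons at each evaluation point, (n_eval, n_neurons)
A = rng.uniform(0, 100, size=(n_eval, n_neurons))
# Y: stand-in target values at each evaluation point, (n_eval, dims)
Y = rng.standard_normal((n_eval, dims))

# Least-squares decoders in the (n_neurons, dims) layout a Solver returns
D, *_ = np.linalg.lstsq(A, Y, rcond=None)

# A Connection from a.neurons multiplies its transform by the activity vector,
# so the transform needs shape (dims, n_neurons): the decoders transposed
transform = D.T

a_t = A[0]                          # activity vector at a single instant
decoded_via_solver = a_t @ D        # what a decoded Connection computes
decoded_via_neurons = transform @ a_t  # what a .neurons Connection computes
```

Both routes produce the same decoded value. The only thing that changes between them is which side of the activity vector the matrix sits on.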

The workaround that a few people have implemented is to define a Solver that just returns whatever matrix you've explicitly told it to:

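import nengo


class Explicit(nengo.solvers.Solver):
    """Solver that skips solving and returns a user-supplied matrix.

    For decoders, value should have shape (n_neurons, dimensions).
    """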
    def __init__(self, value, weights=False):
        super(Explicit, self).__init__(weights=weights)
        self.value = value

    def __call__(self, A, Y, rng=None, E=None):
        return self.value, {}

This is quite handy for use case 1, and should at least be put into nengo_extras (and maybe even into core nengo).

This also forms a good basis for a solution to use case 2, and indeed it's not too bad to implement use case 2 explicitly with this Explicit solver:

import numpy as np
import nengo

# Explicit is the solver class defined above
model = nengo.Network(seed=1)
with model:
    stim = nengo.Node(lambda t: np.sin(t*np.pi*2))
    a = nengo.Ensemble(100, 1)
    b = nengo.Ensemble(50, 1)

    filename = 't2.npy'
    try:
        value = np.load(filename)
    except IOError:
        value = np.zeros((100, 1))

    c = nengo.Connection(a, b, learning_rule_type=nengo.PES(), solver=Explicit(value))
    p_c = nengo.Probe(c, 'weights', sample_every=1.0)

    nengo.Connection(stim, c.learning_rule, transform=-1)
    nengo.Connection(b, c.learning_rule, transform=1)
    nengo.Connection(stim, a)    

sim = nengo.Simulator(model)
sim.run(3)

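# the 'weights' probe records (dims, n_neurons) matrices; transpose back to the
# (n_neurons, dims) layout that Explicit expects
np.save(filename, sim.data[p_c][-1].T)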
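The transpose on the last line is easy to get wrong. A `'weights'` probe on a decoded connection records one (dimensions, n_neurons) matrix per sample, while the solver needs (n_neurons, dimensions). Below is a numpy-only sketch of the load / default-to-zeros / save round trip. It uses a hypothetical `load_or_zeros` helper and a random array standing in for `sim.data[p_c]`.

```python
import os
import tempfile

import numpy as np


def load_or_zeros(filename, shape):
    """Hypothetical helper: load saved decoders, or start from zeros."""
    try:
        value = np.load(filename)
    except IOError:  # FileNotFoundError is a subclass of IOError/OSError
        return np.zeros(shape)
    if value.shape != shape:
        raise ValueError("expected shape %s, got %s" % (shape, value.shape))
    return value


n_neurons, dims, n_samples = 100, 1, 3
filename = os.path.join(tempfile.mkdtemp(), "t2.npy")

# First run: no file yet, so the decoders start at zero
first = load_or_zeros(filename, (n_neurons, dims))

# Stand-in for sim.data[p_c]: one (dims, n_neurons) matrix per probe sample
probed = np.random.default_rng(1).standard_normal((n_samples, dims, n_neurons))
np.save(filename, probed[-1].T)  # transpose back to (n_neurons, dims)

# Second run: the saved decoders are picked up
second = load_or_zeros(filename, (n_neurons, dims))
```

If you forget the `.T`, the saved file has shape (1, 100). The shape check then fails loudly on the next run instead of silently loading the wrong thing.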

But that's rather ugly. You have to open the file, handle the initial case when the file doesn't exist, create a probe, and save the data at the end (remembering to take the transpose). So let's make a helper for this:

import numpy as np
import nengo


# loads decoders (or full weights) from a file, defaulting to zeros if it doesn't exist
class LoadFrom(nengo.solvers.Solver):
    def __init__(self, filename, weights=False):
        super(LoadFrom, self).__init__(weights=weights)
        self.filename = filename

    def __call__(self, A, Y, rng=None, E=None):
        if self.weights:
            shape = (A.shape[1], E.shape[1])
        else:
            shape = (A.shape[1], Y.shape[1])

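        try:
            value = np.load(self.filename)
        except IOError:  # no saved file yet (e.g. the first run)
            value = np.zeros(shape)
        if value.shape != shape:
            raise ValueError("%s has shape %s, expected %s"
                             % (self.filename, value.shape, shape))
        return value, {}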

# helper to create the LoadFrom solver and the needed probe and do the saving
class WeightSaver(object):
    def __init__(self, connection, filename, sample_every=1.0, weights=False):
        assert isinstance(connection.pre, nengo.Ensemble)
        if not filename.endswith('.npy'):
            filename = filename + '.npy'
        self.filename = filename
        connection.solver = LoadFrom(self.filename, weights=weights)
        self.probe = nengo.Probe(connection, 'weights', sample_every=sample_every)
        self.connection = connection

    def save(self, sim):
        # the probe stores post-by-pre matrices; transpose to the solver's layout
        np.save(self.filename, sim.data[self.probe][-1].T)
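The two branches in `LoadFrom.__call__` correspond to the two layouts a solver can be asked for:

- decoders of shape (n_pre, dims) when `weights=False`;
- a full (n_pre, n_post) matrix when `weights=True`, where `E` holds the post-population encoders.

This numpy sketch assumes `E` is laid out as (dims, n_post), as the `E.shape[1]` in the code implies. It uses plain least squares in place of Nengo's solvers and a hypothetical `expected_shape` function that mirrors the shape logic.

```python
import numpy as np


def expected_shape(A, Y, E=None, weights=False):
    """Mirror of LoadFrom's shape logic: (n_pre, dims) or (n_pre, n_post)."""
    if weights:
        return (A.shape[1], E.shape[1])
    return (A.shape[1], Y.shape[1])


rng = np.random.default_rng(2)
n_eval, n_pre, n_post, dims = 300, 40, 30, 2
A = rng.uniform(0, 50, size=(n_eval, n_pre))   # stand-in pre-neuron activities
Y = rng.standard_normal((n_eval, dims))        # stand-in target values
E = rng.standard_normal((dims, n_post))        # stand-in scaled post encoders (assumed layout)

D, *_ = np.linalg.lstsq(A, Y, rcond=None)      # decoders
W, *_ = np.linalg.lstsq(A, Y @ E, rcond=None)  # full weights: targets projected onto encoders
```

Because least squares is linear in the targets, the full weight matrix here equals `D @ E`. A weights file can therefore be derived from a decoders file, but not the other way around in general.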

To use this, only two lines need to be added to the original model:

import numpy as np
import nengo

# LoadFrom and WeightSaver are the helpers defined above
model = nengo.Network(seed=1)
with model:
    stim = nengo.Node(lambda t: np.sin(t*np.pi*2))
    a = nengo.Ensemble(100, 1)
    b = nengo.Ensemble(50, 1)
    c = nengo.Connection(a, b, learning_rule_type=nengo.PES())
    nengo.Connection(stim, c.learning_rule, transform=-1)
    nengo.Connection(b, c.learning_rule, transform=1)
    nengo.Connection(stim, a)    

    ws = WeightSaver(c, 'my_weights')   # add this line

sim = nengo.Simulator(model)
sim.run(3)
ws.save(sim)   # and add this line when you're done

One interesting feature of this approach is that the file is loaded at build time (when nengo.Simulator(model) is created), not when the model is defined. I think that's usually what we want...
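`LoadFrom` reads the file inside `__call__`, and that method is only invoked when the simulator builds the connection. So a file saved after the model was defined, but before the simulator is built, is still picked up. Here is a pure-Python sketch of that timing. It uses a hypothetical `LoadAtCall` class rather than a real Nengo solver.

```python
import os
import tempfile

import numpy as np


class LoadAtCall(object):
    """Hypothetical stand-in for LoadFrom: reads the file only when called."""

    def __init__(self, filename, shape):
        self.filename = filename
        self.shape = shape

    def __call__(self):
        try:
            return np.load(self.filename)
        except IOError:
            return np.zeros(self.shape)


filename = os.path.join(tempfile.mkdtemp(), "w.npy")
solver = LoadAtCall(filename, (4, 1))  # "model definition": nothing is read yet

np.save(filename, np.ones((4, 1)))     # e.g. a previous run saves its weights

value = solver()                       # "build": the file is read now
```

Had the file been read in `__init__`, `value` would still be the zeros from before the save.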

tcstewar commented 7 years ago

Oh, and one thing I'm thinking of proposing for nengo_gui is a set of function hooks that give you access to the simulator object, so you could put something like this at the end of your code in nengo_gui:

def on_sim_done(sim): 
    # this will get triggered when the simulation is over
    weight_saver.save(sim)

hunse commented 7 years ago

> Use case 2 is even more problematic -- there's currently no way to seed the start of a learning rule with anything other than the result of some Solver.

Sure you can. You just make a connection from the neurons to the post object, as in case 1. For example:

import matplotlib.pyplot as plt
import numpy as np

import nengo

n_neurons = 200
initial = np.random.uniform(-0.0001, 0.0001, size=(1, n_neurons))

with nengo.Network() as model:
    x = nengo.Node(nengo.processes.WhiteSignal(period=100, high=5))
    ystar = nengo.Node(lambda t, x: x**2, size_in=1)
    nengo.Connection(x, ystar)

    a = nengo.Ensemble(n_neurons, 1)
    y = nengo.Node(size_in=1)

    error = nengo.Node(size_in=1)
    nengo.Connection(y, error)
    nengo.Connection(ystar, error, transform=-1)

    conn = nengo.Connection(a.neurons, y, transform=initial,
                            learning_rule_type=nengo.PES(1e-3))
    nengo.Connection(error, conn.learning_rule)

    xp = nengo.Probe(x, synapse=0.01)
    ystarp = nengo.Probe(ystar, synapse=0.01)
    yp = nengo.Probe(y, synapse=0.01)

with nengo.Simulator(model) as sim:
    sim.run(15.0)

plt.plot(sim.trange(), sim.data[xp])
plt.plot(sim.trange(), sim.data[ystarp])
plt.plot(sim.trange(), sim.data[yp])
plt.show()
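This works because a Connection from `a.neurons` learns directly on its `transform`. Whatever matrix is passed in becomes the starting point for learning. That could be random values, or decoders saved from a previous run and transposed to (dims, n_neurons). Below is a numpy-only sketch of the idea. Hypothetical rectified-linear rate neurons stand in for the ensemble. The update is a delta rule of the PES form (change proportional to minus error times pre activity), assuming a scaling of `learning_rate * dt / n_neurons`.

```python
import numpy as np

rng = np.random.default_rng(3)
n_neurons, dt, learning_rate = 50, 0.001, 0.1

# Hypothetical rectified-linear tuning curves standing in for a Nengo ensemble
encoders = rng.choice([-1.0, 1.0], size=n_neurons)
intercepts = rng.uniform(-0.9, 0.9, size=n_neurons)
max_rates = rng.uniform(50, 100, size=n_neurons)


def rates(x):
    """Firing rates of all neurons for a scalar input x, shape (n_neurons,)."""
    return max_rates * np.maximum(0, encoders * x - intercepts) / (1 - intercepts)


def mse(W):
    """Mean squared error of W @ rates(x) against x**2 on a grid."""
    xs = np.linspace(-1, 1, 101)
    return np.mean([(W @ rates(x) - x ** 2) ** 2 for x in xs])


# Seed the neuron-level transform as in the example above, shape (1, n_neurons)
transform = rng.uniform(-0.0001, 0.0001, size=(1, n_neurons))

initial_mse = mse(transform)
for _ in range(20000):
    x = rng.uniform(-1, 1)
    a = rates(x)
    error = transform @ a - x ** 2  # actual minus target, as fed to conn.learning_rule
    # PES-style update on the transform (scaling assumed, see lead-in)
    transform -= (learning_rate * dt / n_neurons) * np.outer(error, a)
final_mse = mse(transform)
```

Replacing the random `transform` with previously saved weights would simply start this loop closer to the answer.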

tcstewar commented 7 years ago

> Sure you can. You just make a connection from the neurons to the post object as in case 1

Huh. I had no idea you could do that. That makes complete sense in hindsight...