ctn-waterloo / modelling_ideas

Ideas for models that could be made with Nengo if anyone has time

Expressing backprop as an encoder learning rule? #88

Open tcstewar opened 6 years ago

tcstewar commented 6 years ago

Would it be possible to implement an encoder learning rule in Nengo that uses backprop? Or even feedback alignment (which is basically just backprop, but with a fixed random weight matrix used instead of the transpose of the forward weights)?
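To make the "backprop with a random matrix" relationship concrete, here is a minimal NumPy sketch (not Nengo code; all function names are hypothetical) of the error signal each rule sends to the hidden layer. It assumes ReLU hidden units and a linear output layer with forward weights w2.

```python
import numpy as np

def relu_deriv(j):
    """Derivative of a rectified-linear nonlinearity: 1 where j > 0, else 0."""
    return (j > 0).astype(float)

def hidden_delta_backprop(error, j_hidden, w2):
    """Backprop: send the output error back through w2.T, where w2 is the
    (n_out, n_hidden) forward weight matrix of the second layer."""
    return relu_deriv(j_hidden) * (w2.T @ error)

def hidden_delta_feedback_alignment(error, j_hidden, feedback):
    """Feedback alignment: same formula, but w2.T is replaced by a fixed random
    (n_hidden, n_out) matrix chosen once and never learned."""
    return relu_deriv(j_hidden) * (feedback @ error)

rng = np.random.default_rng(0)
n_hidden, n_out = 5, 1
w2 = rng.standard_normal((n_out, n_hidden))        # forward weights (learned elsewhere)
feedback = rng.standard_normal((n_hidden, n_out))  # fixed random feedback weights
j_hidden = rng.standard_normal(n_hidden)           # hidden-layer input currents
error = np.array([0.5])                            # output error (actual - target)

print(hidden_delta_backprop(error, j_hidden, w2))
print(hidden_delta_feedback_alignment(error, j_hidden, feedback))
```

Passing feedback=w2.T makes the two functions agree exactly. Note also that the backprop version needs the current second-layer weights, while the feedback-alignment version only needs the error and its own fixed matrix.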

What I'm picturing is something like this:

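import nengo

model = nengo.Network()
with model:
    input = nengo.Node(None, size_in=2)  # this could also be an Ensemble
    hidden = nengo.Ensemble(n_neurons=100, dimensions=2)
    output = nengo.Node(None, size_in=1)  # this could also be an Ensemble

    # decoders learned with Nengo's built-in PES rule
    layer2 = nengo.Connection(hidden, output,
                function=lambda x: 0,
                learning_rule_type=nengo.PES())

    # encoders learned with the proposed BackProp rule (hypothetical, not part of Nengo)
    layer1 = nengo.Connection(input, hidden,
                learning_rule_type=BackProp(layer2))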

    error = nengo.Node(None, size_in=1)
    nengo.Connection(error, layer1.learning_rule)
    nengo.Connection(error, layer2.learning_rule)

For FeedbackAlignment, it might be something like this:

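    # proposed FeedbackAlignment rule (hypothetical, not part of Nengo)
    layer1 = nengo.Connection(input, hidden,
                learning_rule_type=FeedbackAlignment(output))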

My understanding is that BackProp would need access to layer2, since it needs the weights of that connection, but FeedbackAlignment might get away with access to just the output. (Actually, do we even need that? Or do we just need access to the error signal?)

One complication is that both of these learning rules need to take the derivative of the hidden layer nonlinearity. I don't think they need to do anything with the output layer's nonlinearity (indeed, in this case there is no nonlinearity there).
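Spiking LIF neurons have no useful derivative, so one common workaround (an assumption here, not something Nengo provides for learning rules) is to differentiate the steady-state LIF rate curve instead. The NumPy sketch below uses the LIF rate equation with Nengo's default tau_rc = 0.02 s and tau_ref = 0.002 s, and checks the analytic derivative against finite differences. When the current is J = gain * (e · x) + bias, the derivative with respect to the encoded input e · x picks up an extra factor of gain; the helper encoded_input_deriv is a hypothetical name for that chain-rule step.

```python
import numpy as np

TAU_RC = 0.02    # membrane time constant in seconds (Nengo's LIF default)
TAU_REF = 0.002  # refractory period in seconds (Nengo's LIF default)

def lif_rate(j):
    """Steady-state LIF firing rate (Hz) for input current j; zero for j <= 1."""
    j = np.asarray(j, dtype=float)
    rate = np.zeros_like(j)
    above = j > 1
    rate[above] = 1.0 / (TAU_REF + TAU_RC * np.log1p(1.0 / (j[above] - 1)))
    return rate

def lif_rate_deriv(j):
    """d(rate)/dJ = tau_rc * rate**2 / (j * (j - 1)) above threshold, 0 below."""
    j = np.asarray(j, dtype=float)
    deriv = np.zeros_like(j)
    above = j > 1
    r = lif_rate(j[above])
    deriv[above] = TAU_RC * r**2 / (j[above] * (j[above] - 1))
    return deriv

def encoded_input_deriv(ex, gain, bias):
    """d(rate)/d(e . x) when J = gain * (e . x) + bias: the chain rule adds a factor of gain."""
    return gain * lif_rate_deriv(gain * ex + bias)

# Compare the analytic derivative with central finite differences
j = np.array([1.2, 2.0, 5.0])
h = 1e-6
numeric = (lif_rate(j + h) - lif_rate(j - h)) / (2 * h)
print(np.allclose(numeric, lif_rate_deriv(j), rtol=1e-4))  # True
```

Below threshold (J <= 1) both the rate and its derivative are zero, just like the ReLU case for J <= 0, so silent neurons receive no update.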

In any case, I was thinking that something like this might be useful, and even somewhat biologically legit, given things like the recent talk on deep learning in pyramidal neurons, based on this paper: https://arxiv.org/pdf/1610.00161.pdf

tcstewar commented 6 years ago

(One interesting complication here is that we also need to adjust the bias, not just the encoders.)
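Since the input current of hidden neuron i is J_i = gain_i * (e_i · x) + bias_i, the bias gradient is just the hidden error signal delta_i, and the encoder gradient is gain_i * delta_i * x. A NumPy sketch, assuming ReLU hidden units, a linear output layer, radius 1, and backprop (feedback alignment would swap w2.T for a fixed random matrix); the function names are hypothetical. The bias gradient is checked against central finite differences.

```python
import numpy as np

def relu(j):
    return np.maximum(j, 0.0)

def loss(x, y, encoders, gain, bias, w2):
    """Squared error of y_hat = w2 @ relu(gain * (encoders @ x) + bias)."""
    j = gain * (encoders @ x) + bias
    return 0.5 * np.sum((w2 @ relu(j) - y) ** 2)

def backprop_hidden_grads(x, y, encoders, gain, bias, w2):
    """Backprop gradients of `loss` with respect to the encoders and the bias."""
    j = gain * (encoders @ x) + bias
    error = w2 @ relu(j) - y                 # actual - target
    delta = (j > 0) * (w2.T @ error)         # dLoss/dJ for each hidden neuron
    d_encoders = np.outer(gain * delta, x)   # dJ_i/de_ik = gain_i * x_k
    d_bias = delta                           # dJ_i/dbias_i = 1
    return d_encoders, d_bias

rng = np.random.default_rng(1)
n_in, n_hidden, n_out = 2, 6, 1
x = rng.uniform(-1, 1, n_in)
y = np.array([x[0] * x[1]])
encoders = rng.standard_normal((n_hidden, n_in))
gain = rng.uniform(0.5, 2.0, n_hidden)
bias = rng.uniform(-1, 1, n_hidden)
w2 = rng.standard_normal((n_out, n_hidden))

d_encoders, d_bias = backprop_hidden_grads(x, y, encoders, gain, bias, w2)

# Central finite differences on each bias entry
eps = 1e-6
numeric_d_bias = np.array([
    (loss(x, y, encoders, gain, bias + eps * np.eye(n_hidden)[i], w2)
     - loss(x, y, encoders, gain, bias - eps * np.eye(n_hidden)[i], w2)) / (2 * eps)
    for i in range(n_hidden)
])
print(np.max(np.abs(numeric_d_bias - d_bias)))
```

Updating the bias therefore costs nothing extra: the same delta that drives the encoder update is applied to the bias directly.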

tcstewar commented 6 years ago

Here's a not-quite-working implementation of the FeedbackAlignment part of this, but using a Node instead of a learning rule:

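import nengo
import numpy as np

class FeedbackAlignment(nengo.Node):
    def __init__(self, ens, error_size, learning_rate=1e-4, seed=10):
        # input is [x, error]; output is one input current per neuron of ens
        super(FeedbackAlignment, self).__init__(self.update,
                                                size_in=ens.dimensions + error_size,
                                                size_out=ens.n_neurons)
        rng = np.random.RandomState(seed=seed)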
        self.encoders = ens.encoders.sample(ens.n_neurons, ens.dimensions, rng=rng)
        self.feedback = ens.encoders.sample(ens.n_neurons, error_size, rng=rng).T
        self.input_dims = ens.dimensions
        self.learning_rate = learning_rate

    def update(self, t, x):
        x, error = x[:self.input_dims], x[self.input_dims:]

        a = np.dot(self.encoders, x)

        # derivative of ReLU; still wrong: doesn't include gain or bias! Should get those from the ensemble instead
        da = np.where(a > 0, 1, 0)

        # project the error through the fixed random feedback weights, gate it by
        # the ReLU derivative, and take the outer product with the input x
        delta = da * np.dot(error, self.feedback)
        dw = np.outer(delta, x)

        self.encoders -= dw * self.learning_rate

        return a

model = nengo.Network(seed=0)
with model:
    scale = 2.0
    input = nengo.Node(lambda t: (np.sin(t*2*np.pi/scale), np.sin(t*2*np.pi*0.7/scale)))

    hidden = nengo.Ensemble(n_neurons=10, dimensions=2,
                            neuron_type=nengo.RectifiedLinear())  # assumed: ReLU, to match da above

    output = nengo.Node(None, size_in=1)
    correct = nengo.Node(None, size_in=1)
    error = nengo.Node(None, size_in=1)

    nengo.Connection(input, correct, function=lambda x: x[0]*x[1], synapse=None)
    nengo.Connection(output, error, synapse=None)
    nengo.Connection(correct, error, synapse=None, transform=-1)

    p_error = nengo.Probe(error)
    p_correct = nengo.Probe(correct)
    p_output = nengo.Probe(output)

    #nengo.Connection(input, hidden, synapse=None)
    fa = FeedbackAlignment(hidden, error_size=1)
    nengo.Connection(input, fa[:2], synapse=None)
    nengo.Connection(error, fa[2:], synapse=0)
    nengo.Connection(fa, hidden.neurons, synapse=None)

    layer2 = nengo.Connection(hidden, output, synapse=None, function=lambda x: 0,
                              learning_rule_type=nengo.PES())
    nengo.Connection(error, layer2.learning_rule, synapse=0)

One thing I like about it is that it uses the normal ens.encoders distribution both to initialize the encoders and to generate the fixed feedback weights.
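Since the Node version is not quite working yet, it may help to check the same idea outside Nengo first. The sketch below is a rate-based NumPy stand-in for the model above, not Nengo code, and several choices are assumptions: ReLU units, inputs drawn uniformly from [-1, 1]^2 instead of the two sine waves, 50 neurons instead of 10, minibatch updates instead of online ones, and arbitrary learning rates. As in the Node, a single unit-vector sampler (normalized Gaussian samples, standing in for ens.encoders) initializes both the encoders and the fixed feedback weights; with a 1-D error, every feedback weight comes out as +1 or -1.

```python
import numpy as np

def sample_unit_vectors(n, d, rng):
    """n random points on the surface of the d-dimensional unit hypersphere."""
    v = rng.standard_normal((n, d))
    return v / np.linalg.norm(v, axis=1, keepdims=True)

def target(x):
    return x[:, :1] * x[:, 1:]   # product of the two inputs, shape (batch, 1)

rng = np.random.default_rng(0)
n_in, n_hidden, n_out = 2, 50, 1

# Same sampler for the learned encoders and the fixed feedback weights
encoders = sample_unit_vectors(n_hidden, n_in, rng)   # (n_hidden, n_in), learned
feedback = sample_unit_vectors(n_hidden, n_out, rng)  # (n_hidden, n_out), fixed
bias = rng.uniform(-0.5, 0.5, size=n_hidden)          # (n_hidden,), learned
decoders = np.zeros((n_out, n_hidden))                # (n_out, n_hidden), learned

def forward(x):
    """Return input currents, ReLU rates, and outputs for a batch of inputs."""
    j = x @ encoders.T + bias
    r = np.maximum(j, 0.0)
    return j, r, r @ decoders.T

x_eval = rng.uniform(-1, 1, size=(1000, n_in))

def evaluate_mse():
    return float(np.mean((forward(x_eval)[2] - target(x_eval)) ** 2))

mse_before = evaluate_mse()
lr_dec, lr_enc, batch = 0.1, 0.005, 32
for _ in range(4000):
    x = rng.uniform(-1, 1, size=(batch, n_in))
    j, r, y = forward(x)
    error = y - target(x)                       # actual - target, as in the Nengo model
    delta = (j > 0) * (error @ feedback.T)      # FA: fixed feedback replaces decoders.T
    decoders -= lr_dec * error.T @ r / batch    # delta rule on the decoders
    encoders -= lr_enc * delta.T @ x / batch    # FA update on the encoders
    bias -= lr_enc * delta.mean(axis=0)         # ...and on the biases
mse_after = evaluate_mse()
print(f"MSE before: {mse_before:.4f}, after: {mse_after:.4f}")
```

The decoders follow a plain delta rule (the rate-based analogue of the PES connection), while the encoders and biases follow feedback alignment through the fixed feedback matrix, so the only difference from backprop is which matrix carries the error back.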

Seanny123 commented 6 years ago

Didn't @hunse work on this to a certain degree?