brian-team / brian2

Brian is a free, open source simulator for spiking neural networks.
http://briansimulator.org

Implement surrogate gradient descent method #1207

Open · thesamovar opened 4 years ago

thesamovar commented 4 years ago

We should implement the PyTorch backend (#1014) and an option for fully end-to-end differentiable operation, so that the method of Neftci et al. (2019) can be used to train spiking neural networks with surrogate gradient descent. Some changes would be relatively simple (reset, threshold), while dealing with synaptic propagation is likely to be complicated, not least because it may require changing the internal data structure to a dense matrix, and we might not be able to use SpikeQueue.

It would be good to get comments from @fzenke and discuss this with him when we get around to it. See also his SPyTorch tutorial.

Snow-Crash commented 3 years ago

It seems that doing this in brian2 would be a massive amount of work, since you would need to implement automatic differentiation. Perhaps it is easier to provide an interface/backend to PyTorch that passes the network topology and parameters, so that PyTorch can build the same network. The backprop can then be handled easily by PyTorch.

Implementing surrogate gradients is quite easy in PyTorch. I have implemented this for LIF neurons with different synapse types in PyTorch. One only has to define the forward function of the neuron/synapse and override the backward function of the threshold operation. The forward function is the step function, which advances the ODE by one unit of time. Since PyTorch syntax is similar to NumPy, and Brian already supports NumPy, translating a network to PyTorch seems easier than implementing automatic differentiation in Brian.
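A minimal plain-Python sketch of the approach described above, assuming a current-based LIF neuron with exponential decay per time step and reset to zero after a spike. The surrogate is the fast-sigmoid derivative 1/(β|x|+1)², similar to the one used in the SPyTorch tutorial; the value β = 100 and all function names here are illustrative assumptions, not Brian2 or PyTorch API. In PyTorch, `spike_forward` and `spike_backward` would become the `forward` and `backward` static methods of a `torch.autograd.Function` subclass; they are written without a framework here so the logic is visible.

```python
import math

# Illustrative steepness of the fast-sigmoid surrogate (assumed value).
BETA = 100.0


def spike_forward(v, threshold=1.0):
    """Forward pass of the threshold: a Heaviside step (1.0 if v exceeds threshold)."""
    return 1.0 if v - threshold > 0 else 0.0


def spike_backward(v, grad_output, threshold=1.0, beta=BETA):
    """Backward pass: the true derivative of the step is zero almost everywhere,
    so it is replaced by the derivative of a fast sigmoid, 1 / (beta*|x| + 1)**2."""
    x = v - threshold
    return grad_output / (beta * abs(x) + 1.0) ** 2


def lif_step(v, i_syn, input_current, dt=1e-3, tau_mem=10e-3, tau_syn=5e-3, threshold=1.0):
    """Advance a hypothetical current-based LIF neuron by one time step of size dt.

    Returns (new_v, new_i_syn, spike)."""
    alpha = math.exp(-dt / tau_syn)  # synaptic current decay factor per step
    beta = math.exp(-dt / tau_mem)   # membrane potential decay factor per step
    spike = spike_forward(v, threshold)
    new_i_syn = alpha * i_syn + input_current
    new_v = (beta * v + i_syn) * (1.0 - spike)  # reset to zero if the neuron spiked
    return new_v, new_i_syn, spike


# Usage: a neuron above threshold spikes and is reset.
v, i_syn, s = lif_step(1.2, 0.3, 0.0)
print(v, s)                        # 0.0 1.0
print(spike_backward(1.0, 1.0))    # surrogate is largest (1.0) at the threshold
```

Because only the backward pass is changed, the simulated dynamics are identical to an ordinary LIF simulation; the surrogate merely gives the optimiser a non-zero gradient near the threshold.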

thesamovar commented 3 years ago

That's the idea! It's relatively easy to do for any given model, but there are some tricky aspects to think about to do it for an arbitrary model.

fzenke commented 3 years ago

It would be great to have surrogate gradient support in Brian2. One thing we have grappled with in the past is sparse connectivity matrices, since autograd support for them is relatively poor in both PyTorch and TensorFlow (at least the last time I checked). How are connectivity matrices currently implemented in Brian2, and do you think this may cause issues? I'd be keen to learn about any solutions.

thesamovar commented 3 years ago

They're sparse matrices in a custom format, but it's easy to export a CSR matrix from the way we store them. However, what I had in mind was converting to a dense format as an intermediate step.
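A rough plain-Python illustration of the conversion described above, assuming synapses are available as parallel lists of presynaptic indices `i`, postsynaptic indices `j` and weights `w` (Brian2's `Synapses` objects expose `i` and `j`; the weight variable name depends on the model, and the helper functions here are hypothetical). It builds CSR arrays and a dense matrix, then shows dense propagation as a vector-matrix product, the kind of operation autodiff frameworks differentiate well. This sketch ignores synaptic delays (what SpikeQueue handles) and merges multiple synapses between the same pair of neurons by summing their weights.

```python
def to_csr(i, j, w, n_pre):
    """Build CSR arrays (indptr, indices, data) with one row per presynaptic neuron."""
    counts = [0] * n_pre
    for pre in i:
        counts[pre] += 1
    indptr = [0] * (n_pre + 1)
    for r in range(n_pre):
        indptr[r + 1] = indptr[r] + counts[r]
    indices = [0] * len(i)
    data = [0.0] * len(i)
    next_slot = list(indptr[:-1])  # next free position in each row
    for pre, post, weight in zip(i, j, w):
        k = next_slot[pre]
        indices[k] = post
        data[k] = weight
        next_slot[pre] += 1
    return indptr, indices, data


def to_dense(i, j, w, n_pre, n_post):
    """Dense n_pre x n_post weight matrix; duplicate (pre, post) synapses are summed."""
    W = [[0.0] * n_post for _ in range(n_pre)]
    for pre, post, weight in zip(i, j, w):
        W[pre][post] += weight
    return W


def propagate(spikes, W):
    """Dense synaptic propagation: input to each postsynaptic neuron is spikes @ W."""
    n_post = len(W[0])
    return [sum(spikes[r] * W[r][c] for r in range(len(W))) for c in range(n_post)]


# Usage: three neurons, three synapses.
i, j, w = [0, 0, 2], [1, 2, 0], [0.5, 0.25, 1.0]
print(to_csr(i, j, w, n_pre=3))      # ([0, 2, 2, 3], [1, 2, 0], [0.5, 0.25, 1.0])
W = to_dense(i, j, w, n_pre=3, n_post=3)
print(propagate([1, 0, 1], W))       # [1.0, 0.5, 0.25]
```

The dense form trades memory (n_pre × n_post entries regardless of sparsity) for simple, well-supported gradients, which is the compromise being discussed in this thread.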

arashgmn commented 1 year ago

It would be very useful to have this feature. Is there any update on this?

femtomc commented 12 months ago

@thesamovar any update on this? I have interested collaborators who would like to use Brian2, but want some AD functionality (e.g. if possible, a JAX backend).

If you can point me to some breadcrumbs for "backends" (I'm not sure whether Brian2 uses anything like this, or how it works), I could try to hack something up; I'm a reasonably competent JAX programmer.

thesamovar commented 11 months ago

Nothing on the timeline right now I'm afraid, and I don't think it's that easy. You could take a look at the package brian2genn which runs Brian via the GeNN GPU simulator to give an idea of what would be involved?