emer / axon

Axon is a spiking, biologically-based neural model driven by predictive error-driven learning, for systems-level models of the brain
BSD 3-Clause "New" or "Revised" License

issues for integrating synaptic inputs: max driver, per synapse.. #28

Closed rcoreilly closed 2 years ago

rcoreilly commented 2 years ago

In the code up to this point: geSyn += geRaw - decay * geSyn -- new raw input adds directly, and it then decays.

This produces unbounded levels of ge, which is not realistic: there is a max, which is all channels open.

Ideally, you would apply the following update rule to each individual synaptic input syn:

syn += (1-syn)*x - decay * syn

where x is the sending spike signal (0 or 1). Thus, you can only get so much out of each individual synapse, which is where the actual constraint lies.
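
To make the contrast concrete, here is a minimal sketch that runs both update rules side by side. The parameter values (a decay of 0.2 per step, 10 synapses with unit weights that all spike on every step) and the function names are assumptions chosen for illustration, not values or code from axon.

```go
package main

import "fmt"

// aggStep applies the aggregate rule: geSyn += geRaw - decay*geSyn.
func aggStep(geSyn, geRaw, decay float32) float32 {
	return geSyn + geRaw - decay*geSyn
}

// synStep applies the per-synapse rule: syn += (1-syn)*x - decay*syn.
func synStep(syn, x, decay float32) float32 {
	return syn + (1-syn)*x - decay*syn
}

// simulate drives nSyn unit-weight synapses with a spike on every step
// (a hypothetical worst case) and returns the aggregate conductance and
// the sum of the individually saturating synapses.
func simulate(nSyn, steps int, decay float32) (agg, perSyn float32) {
	syns := make([]float32, nSyn)
	for t := 0; t < steps; t++ {
		agg = aggStep(agg, float32(nSyn), decay) // geRaw = sum of current spikes
		for i := range syns {
			syns[i] = synStep(syns[i], 1, decay)
		}
	}
	for _, s := range syns {
		perSyn += s
	}
	return agg, perSyn
}

func main() {
	agg, perSyn := simulate(10, 50, 0.2)
	fmt.Printf("aggregate: %.2f  per-synapse sum: %.2f  (all channels open = 10)\n", agg, perSyn)
}
```

With constant input the aggregate rule settles at geRaw / decay, which grows with the number of active inputs and can exceed the all-channels-open level, whereas each per-synapse value stays within [0, 1].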

However, we don't keep track of individual synaptic conductances, and instead aggregate all the punctate spikes when they occur (geRaw is sum of all current spikes), and then decay the aggregate value -- much more efficient.

For Ge (AMPA channels), this is not such a major issue, because their decay constant is around 5 ms, the refractory period for a sending neuron is 3 ms, and neurons rarely fire above about 100 Hz (one spike every 10 ms). Nevertheless, we can efficiently add a presynaptic factor that tracks this decay per-sender, and send that instead of the raw 1 or 0 (i.e., it will be less than 1 if not yet fully decayed).

For NMDA channels, where the decay constant is upward of 100 ms, this is in fact a major issue. We must maintain a sending decay factor (which is needed for learning in any case) and increment NMDA separately from Ge when sending the spike. Might as well get the Ge right as well -- can use that for any other modulation as a function of sending spiking activity (e.g., facilitation, depression). The same goes for Gi.
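
As a rough illustration of why the time constant matters, the sketch below computes how far a per-sender factor has recovered after a given number of non-spiking 1 ms steps, assuming it is reset to 0 on a spike and recovers via s += (1 - s) * dt. The rate constants 1/5 and 1/100 are assumptions taken from the approximate time constants mentioned above, and recoverFor is a hypothetical helper, not an axon function.

```go
package main

import "fmt"

// recoverFor returns the sender factor after n non-spiking steps,
// starting from 0 (fully depleted) and applying s += (1 - s) * dt.
// In closed form this is 1 - (1-dt)^n.
func recoverFor(n int, dt float32) float32 {
	var s float32
	for i := 0; i < n; i++ {
		s += (1 - s) * dt
	}
	return s
}

func main() {
	const geDt = 1.0 / 5.0     // assumed AMPA rate constant (tau around 5 ms)
	const nmdaDt = 1.0 / 100.0 // assumed NMDA rate constant (tau around 100 ms)
	for _, isi := range []int{3, 10, 50, 200} {
		fmt.Printf("%3d steps since last spike: Se=%.3f  Snmda=%.3f\n",
			isi, recoverFor(isi, geDt), recoverFor(isi, nmdaDt))
	}
}
```

Under these assumptions, at a 10 ms inter-spike interval the AMPA factor is already most of the way back to 1, while the NMDA factor has barely recovered, so the NMDA contribution of each spike depends strongly on recent firing history.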

rcoreilly commented 2 years ago

In practice, this introduces a ton of noise. It is similar to synaptic failures, actually, because when a sending neuron spikes, its actual impact is variable depending on history. Interestingly, the SnmdaI inhibitory component that captures the allosteric NMDA dynamics is also noise inducing.

I verified in a simple ra25 test that doubling the size of the hidden layers overcomes the extra noise, and maybe in larger nets it will be a bonus, but clearly we need a flag for smaller nets. I added SeiDeplete and SnmdaDeplete flags in DendParams for this -- can experiment in larger models.

A logical argument for not doing deplete in smaller models is that each neuron is a stand-in for a larger population of similar such neurons, and in this case, the depletion is spread more evenly among the cohort, and is thus not so impactful.

rcoreilly commented 2 years ago

Finally verified in the big lvis model that it is indeed much more resistant to these depletion effects, but they still impair rather than facilitate performance, so there is no point in keeping them in at this point. NMDA depletion is less impactful than Ge / Gi depletion.

Here's the relevant code for future reference.

prjn.go:

// Required an extra GnmdaBuf for separately depleting nmda

    GBuf     []float32   `desc:"Ge or Gi conductance ring buffer for each neuron * Gidx.Len, accessed through Gidx, and length Gidx.Len in size per neuron -- weights are added with conductance delay offsets."`
    GnmdaBuf []float32   `desc:"Gnmda NMDA conductance ring buffer for each neuron * Gidx.Len, accessed through Gidx, and length Gidx.Len in size per neuron -- weights are added with conductance delay offsets."`

// SendESpike sends an excitatory spike from sending neuron index si,
// to add to buffer on receivers.
// Sends proportion of synaptic channels that remain open as function
// of time since last spike, for Ge and Gnmda channels.
func (pj *Prjn) SendESpike(si int, sge, snmda float32) {
    sc := pj.GScale.Scale
    sge *= sc
    snmda *= sc
    del := pj.Com.Delay
    sz := del + 1
    di := pj.Gidx.Idx(del) // index in buffer to put new values -- end of line
    nc := pj.SConN[si]
    st := pj.SConIdxSt[si]
    syns := pj.Syns[st : st+nc]
    scons := pj.SConIdx[st : st+nc]
    for ci := range syns {
        ri := scons[ci]
        pj.GBuf[int(ri)*sz+di] += sge * syns[ci].Wt
        pj.GnmdaBuf[int(ri)*sz+di] += snmda * syns[ci].Wt
    }
}
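
For readers unfamiliar with the delay-line layout used by GBuf and GnmdaBuf, here is a self-contained sketch of a per-neuron ring buffer with a communication delay. The types ringIdx and delayBuf and their methods are hypothetical stand-ins written for this example; they are assumed to behave roughly like Gidx and the buffers above, but they are not the axon implementation.

```go
package main

import "fmt"

// ringIdx is a hypothetical stand-in for Gidx: zero marks the slot
// that is read on the current cycle.
type ringIdx struct {
	zero, n int
}

// idx maps a logical offset (0 = current cycle) to a physical slot.
func (r ringIdx) idx(off int) int { return (r.zero + off) % r.n }

// shift advances the ring by one cycle.
func (r *ringIdx) shift() { r.zero = (r.zero + 1) % r.n }

// delayBuf holds n = delay+1 slots per receiving neuron,
// laid out as neuron*n + slot.
type delayBuf struct {
	ix  ringIdx
	buf []float32
}

func newDelayBuf(nNeurons, delay int) *delayBuf {
	n := delay + 1
	return &delayBuf{ix: ringIdx{n: n}, buf: make([]float32, nNeurons*n)}
}

// send adds val for receiver ri at the end of the delay line.
func (d *delayBuf) send(ri int, val float32) {
	d.buf[ri*d.ix.n+d.ix.idx(d.ix.n-1)] += val
}

// cycle reads and clears the current slot of every neuron, then shifts.
func (d *delayBuf) cycle() []float32 {
	out := make([]float32, len(d.buf)/d.ix.n)
	for ri := range out {
		i := ri*d.ix.n + d.ix.idx(0)
		out[ri] = d.buf[i]
		d.buf[i] = 0
	}
	d.ix.shift()
	return out
}

func main() {
	d := newDelayBuf(2, 2) // 2 receivers, delay of 2 cycles
	d.send(1, 0.5)         // e.g. a depleted sender factor times a weight
	for c := 0; c < 4; c++ {
		fmt.Println(c, d.cycle())
	}
}
```

In this sketch the value sent before cycle 0 is read by receiver 1 on cycle 2, which is the sense in which new values go at the "end of line".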

layer.go:

// SendSpike sends spike to receivers -- last step in Cycle, integrated
// the next time around.
func (ly *Layer) SendSpike(ltime *Time) {
    for ni := range ly.Neurons {
        nrn := &ly.Neurons[ni]
        if nrn.IsOff() || nrn.Spike == 0 {
            ly.Act.SenderGDecay(nrn)
            continue
        }
        for _, sp := range ly.SndPrjns {
            if sp.IsOff() {
                continue
            }
            if sp.Type() == emer.Inhib {
                sp.(AxonPrjn).SendISpike(ni, nrn.Si)
            } else {
                if ly.Act.Dend.SnmdaDeplete {
                    sp.(AxonPrjn).SendESpike(ni, nrn.Se, nrn.Snmda*(1.0-nrn.SnmdaI))
                } else {
                    sp.(AxonPrjn).SendESpike(ni, nrn.Se, nrn.Snmda) // no I either
                }
            }
        }
        ly.Act.SenderGSpiked(nrn)
    }
}

act.go:

// SenderGDecay updates Se, Si, Snmda when the neuron has not
// spiked this time around -- decays the sender channels back open
// in effect.
func (ac *ActParams) SenderGDecay(nrn *Neuron) {
    nrn.Se += (1 - nrn.Se) * ac.Dt.GeDt
    nrn.Si += (1 - nrn.Si) * ac.Dt.GiDt
    nrn.Snmda += (1 - nrn.Snmda) * ac.NMDA.Dt
}

// SenderGSpiked sets Se, Si, Snmda to 0 when the neuron spikes, if doing depletion
func (ac *ActParams) SenderGSpiked(nrn *Neuron) {
    // note that timing for S* factors is prior to communication delay
    // but their effect will be delayed so this is appropriate
    if ac.Dend.SeiDeplete {
        nrn.Se = 0
        nrn.Si = 0
    }
    if ac.Dend.SnmdaDeplete {
        nrn.Snmda = 0
    }
}

// DendParams are the parameters for updating dendrite-specific dynamics
type DendParams struct {
    GbarExp      float32 `def:"0.2,0.5" desc:"dendrite-specific strength multiplier of the exponential spiking drive on Vm -- e.g., .5 makes it half as strong as at the soma (which uses Gbar.L as a strength multiplier per the AdEx standard model)"`
    GbarR        float32 `def:"3,6" desc:"dendrite-specific conductance of Kdr delayed rectifier currents, used to reset membrane potential for dendrite -- applied for Tr msec"`
    SeiDeplete   bool    `desc:"When a sending spike occurs, deplete the Se and Si factors to track availability of each synapse's channels based on time since last spiking.  This introduces noise, similar to synaptic failure -- works well for larger nets but is detrimental to small ones."`
    SnmdaDeplete bool    `desc:"When a sending spike occurs, deplete the Snmda factor to track availability of each synapse's channels based on time since last spiking.  This introduces significant noise in NMDA dynamics due to long time constant, similar to synaptic failure -- suitable for larger nets but likely detrimental to small ones."`
}

rcoreilly commented 2 years ago

ps. last release with this in it, for future reference, is v1.4.13, b7d51c2553440ac1fe7340c95a48fd64ed2b9823: https://github.com/emer/axon/commit/b7d51c2553440ac1fe7340c95a48fd64ed2b9823