SpikingNetwork / TrainSpikingNet.jl

train a spiking recurrent neural network
BSD 3-Clause "New" or "Revised" License
14 stars 4 forks source link

Can't generate a Potjans connectome that is interoperable with TrainSpikingNet.jl for plastic weights #1

Closed russelljjarvis closed 1 year ago

russelljjarvis commented 1 year ago

Hi there @bjarthur

Actually I lied in this issue title. I can now get the Potjans connectome to work as a static matrix, for kind=:init, after hacking params.jl and init.jl. I will send details and links to the hacked code soon.

To reproduce this error:

run bash workflow.sh

In the base directory at: https://github.com/JuliaWSU/TrainSpikingNet.jl

Relevant contents of param.jl https://github.com/JuliaWSU/TrainSpikingNet.jl/blob/master/src/param.jl

genPlasticWeights_file = "genWeightsPotjans.jl"
genStaticWeights_file = "genWeightsPotjans.jl"

genPlasticWeights_args = Dict(:Ncells => Ncells, :frac => frac, :Ne => Ne, :L => L, :Lexc => Lexc, :Linh => Linh, :Lffwd => Lffwd,
                              :wpee => 2.0 * taue * g / wpscale,
                              :wpie => 2.0 * taue * g / wpscale,
                              :wpei => -2.0 * taue * g / wpscale,
                              :wpii => -2.0 * taue * g / wpscale,
                              :wpffwd => 0)

genStaticWeights_args = Dict(:Ncells => Ncells, :Ne => Ne,
                             :pree => 0.1, :prie => 0.1, :prei => 0.1, :prii => 0.1)
./tsn.sh init ${PWD}/src

Relevant contents of https://github.com/JuliaWSU/TrainSpikingNet.jl/blob/master/src/genWeightsPotjans.jl

using Revise
using SparseArrays
using ProgressMeter

"""
    potjans_params()

Return the connection-probability table and cell-count bookkeeping for the
Potjans-Diesmann cortical microcircuit (8 populations: layers 2/3, 4, 5, 6,
each split into excitatory "E" and inhibitory "I").

Returns the tuple `(cumulative, ccu, ccuf, layer_names, columns_conn_probs,
conn_probs)` where:
- `cumulative`: layer name => the disjoint, contiguous block of global cell
  indices assigned to that population,
- `ccu`: layer name => downscaled cell count for that population,
- `ccuf`: layer name => that population's row of `conn_probs`,
- `layer_names`: the 8 population labels, in `conn_probs` row order,
- `columns_conn_probs`, `conn_probs`: the probability table (see below).
"""
function potjans_params()
    # conn_probs[i][j]: probability that a cell in source population i
    # projects to a cell in target population j (row/column order matches
    # layer_names below).
    conn_probs = [[0.1009,  0.1689, 0.0437, 0.0818, 0.0323, 0.,     0.0076, 0.    ],
                [0.1346,   0.1371, 0.0316, 0.0515, 0.0755, 0.,     0.0042, 0.    ],
                [0.0077,   0.0059, 0.0497, 0.135,  0.0067, 0.0003, 0.0453, 0.    ],
                [0.0691,   0.0029, 0.0794, 0.1597, 0.0033, 0.,     0.1057, 0.    ],
                [0.1004,   0.0622, 0.0505, 0.0057, 0.0831, 0.3726, 0.0204, 0.    ],
                [0.0548,   0.0269, 0.0257, 0.0022, 0.06,   0.3158, 0.0086, 0.    ],
                [0.0156,   0.0066, 0.0211, 0.0166, 0.0572, 0.0197, 0.0396, 0.2252],
                [0.0364,   0.001,  0.0034, 0.0005, 0.0277, 0.008,  0.0658, 0.1443]]

    # NOTE: `eachcol` on a Vector-of-Vectors yields the whole vector as its
    # single "column", so this is just `conn_probs` itself — kept for
    # compatibility with the original return tuple.
    columns_conn_probs = [col for col in eachcol(conn_probs)][1]
    layer_names = ["23E","23I","4E","4I","5E", "5I", "6E", "6I"]

    ccuf = Dict(
        k=>v for (k,v) in zip(layer_names,columns_conn_probs)
    )

    # Full-scale cell counts per population (Potjans-Diesmann model).
    ccu = Dict("23E"=>20683, "23I"=>5834,
                "4E"=>21915, "4I"=>5479,
                "5E"=>4850, "5I"=>1065,
                "6E"=>14395, "6I"=>2948)

    # Downscale the network ~35x to a tractable size.
    ccu = Dict(k => ceil(Int64, v / 35.0) for (k, v) in pairs(ccu))

    # Assign each population a contiguous, *disjoint* block of cell indices.
    # BUGFIX: the original used `collect(v_old:v+v_old)`, which has v+1
    # elements and shares its endpoint with the start of the next block; the
    # overlapping, over-long ranges produced cell indices one past the total
    # cell count — the source of the BoundsError (index 522 into a 521-sized
    # dimension) reported in this issue.
    cumulative = Dict{String,Vector{Int64}}()
    v_old = 1
    for (k, v) in pairs(ccu)
        cumulative[k] = collect(v_old:v_old + v - 1)
        v_old += v
    end
    return (cumulative,ccu,ccuf,layer_names,columns_conn_probs,conn_probs)
end

"""
    potjans_weights(Ncells)

Build the static connectivity of the downscaled Potjans microcircuit by
Bernoulli-sampling each source/target pair with the population-pair
connection probability. The `Ncells` argument is ignored: the cell count is
recomputed from the Potjans population sizes (kept in the signature for
caller compatibility).

Weight values are taken from the global parameter struct `p` (`p.je`,
`p.jx`, `p.ji`); excitatory sources get positive weights, inhibitory
sources negative ones.

Returns `(edge_dict, w0Weights, w0Index_, Ne, Ni)`:
- `edge_dict`: source cell => vector of its postsynaptic targets,
- `w0Weights`: sparse Ncells x Ncells weight matrix, indexed [tgt, src],
- `w0Index_`: sparse matrix with the target index stored at [tgt, src],
- `Ne`, `Ni`: counts of excitatory/inhibitory connections drawn
  (NOTE(review): these count *edges*, not cells — confirm against callers
  which may expect cell counts).
"""
function potjans_weights(Ncells)
    (cumulative,ccu,ccuf,layer_names,columns_conn_probs,conn_probs) = potjans_params()

    # Total cell count is dictated by the connectome itself; +1 leaves one
    # spare index beyond the populations.
    Ncells = sum([i for i in values(ccu)])+1
    w0Index_ = spzeros(Int,Ncells,Ncells)
    w0Weights = spzeros(Float32,Ncells,Ncells)

    # Adjacency list (src => targets) and polarity tag (src => "E"/"I").
    edge_dict = Dict{Int64,Vector{Int64}}()
    polarity = Dict{Int64,String}()
    # BUGFIX: initialize over the *local* Ncells, not the global p.Ncells.
    # When p.Ncells differs from the connectome's cell count, source indices
    # drawn from `cumulative` were missing from edge_dict and raised a
    # KeyError (or, if p.Ncells was larger, silently padded the dict).
    for src in 1:Ncells
        edge_dict[src] = Int64[]
        polarity[src] = ""
    end
    Ne = 0
    Ni = 0
    @showprogress for (i,(k,v)) in enumerate(pairs(cumulative))
        for src in v
            for (j,(k1,v1)) in enumerate(pairs(cumulative))
                for tgt in v1
                    # No autapses (self-connections).
                    if src!=tgt
                        prob = conn_probs[i][j]
                        if rand()<prob
                            # Sign and magnitude depend on the source (k)
                            # and target (k1) population polarities.
                            if occursin("E",k)
                                if occursin("E",k1)
                                    # TODO replace synaptic weight values.
                                    # w_mean = 87.8e-3  # nA
                                    w0Weights[tgt,src] = p.je
                                elseif occursin("I",k1)
                                    w0Weights[tgt,src] = p.jx
                                end
                                polarity[src]="E"
                                Ne+=1
                            elseif occursin("I",k)
                                if occursin("E",k1)
                                    w0Weights[tgt,src] = -p.jx
                                elseif occursin("I",k1)
                                    w0Weights[tgt,src] = -p.ji
                                end
                                polarity[src]="I"
                                Ni+=1
                            end
                            append!(edge_dict[src],tgt)
                            # Store the target index itself so nonzeros of a
                            # column enumerate that cell's targets.
                            w0Index_[tgt,src] = tgt
                        end
                    end
                end
            end
        end
    end

    return (edge_dict,w0Weights,w0Index_,Ne,Ni)
end

"""
    genPlasticWeights(args, w0Index, nc0, ns0)

TrainSpikingNet plugin entry point: generate the plastic-weight structures
from the Potjans connectome.

NOTE: the `w0Index`, `nc0`, and `ns0` arguments are ignored and rebuilt
from the connectome; they remain in the signature only to satisfy the
plugin interface. Likewise the `Ne` unpacked from `args` is immediately
shadowed by the value returned from `potjans_weights`.

Returns `(wpWeightFfwd, wpWeightIn, wpIndexIn, ncpIn)`.
"""
function genPlasticWeights(args, w0Index, nc0, ns0)
    Ncells, frac, Ne, L, Lexc, Linh, Lffwd, wpee, wpie, wpei, wpii, wpffwd = map(x->args[x],
    [:Ncells, :frac, :Ne, :L, :Lexc, :Linh, :Lffwd, :wpee, :wpie, :wpei, :wpii, :wpffwd])

    (edge_dict,w0Weights,w0Index_,Ne,Ni) = potjans_weights(Ncells)

    # nc0Max is the maximum number of postsynaptic targets of any one cell
    # (a bound on the out-degree), computed from the adjacency list rather
    # than assumed up front.
    nc0Max = maximum(length, values(edge_dict); init=0)

    # Every column of w0Index is padded to nc0Max rows; nc0[i] records how
    # many entries of column i are meaningful (here uniformly nc0Max).
    nc0 = Int.(nc0Max*ones(Ncells))
    w0Index = spzeros(Int,nc0Max,Ncells)
    # NOTE(review): Ncells here is the value from `args`, while edge_dict is
    # keyed by the connectome's own cell count — confirm the two agree, or
    # this loop KeyErrors / truncates.
    for pre_cell = 1:Ncells
        post_cells = edge_dict[pre_cell]
        w0Index[1:length(post_cells),pre_cell] = post_cells
    end

    wpIndexIn = w0Index
    # NOTE(review): wpWeightIn is the full Ncells x Ncells weight matrix
    # while wpIndexIn is nc0Max x Ncells — the shapes disagree with the
    # ragged-column convention (ncpIn counting meaningful rows) discussed
    # in this issue; confirm against what the train script expects.
    wpWeightIn = w0Weights
    ncpIn = nc0
    # rng and p are globals supplied by TrainSpikingNet's parameter
    # environment; feed-forward plastic weights are Gaussian scaled by wpffwd.
    wpWeightFfwd = randn(rng, p.Ncells, p.Lffwd) * wpffwd

    return wpWeightFfwd, wpWeightIn, wpIndexIn, ncpIn
end

"""
    genStaticWeights(args)

TrainSpikingNet plugin entry point: generate the static connectivity from
the Potjans connectome.

Returns `(w0Index, w0Weights, nc0)`:
- `w0Index`: sparse nc0Max x Ncells matrix; column i lists cell i's
  postsynaptic targets, zero-padded to the maximum out-degree,
- `w0Weights`: sparse Ncells x Ncells weight matrix, indexed [tgt, src],
- `nc0`: per-cell count of meaningful rows in `w0Index` (uniformly nc0Max).
"""
function genStaticWeights(args)
    # Only :Ncells is actually used. BUGFIX: the original also unpacked
    # :jee/:jie/:jei/:jii, which are absent from the genStaticWeights_args
    # Dict defined in param.jl and raised a KeyError before any work began.
    Ncells = args[:Ncells]

    (edge_dict,w0Weights,w0Index_,Ne,Ni) = potjans_weights(Ncells)

    # Maximum out-degree over all cells, computed from the adjacency list.
    nc0Max = maximum(length, values(edge_dict); init=0)

    nc0 = Int.(nc0Max*ones(Ncells))
    w0Index = spzeros(Int,nc0Max,Ncells)
    # NOTE(review): Ncells comes from `args`, while edge_dict is keyed by
    # the connectome's own cell count — confirm the two agree.
    for pre_cell = 1:Ncells
        post_cells = edge_dict[pre_cell]
        w0Index[1:length(post_cells),pre_cell] = post_cells
    end

    # BUGFIX: the original also built wpIndexIn/wpWeightIn/ncpIn and
    # wpWeightFfwd here; none were returned, and wpWeightFfwd referenced
    # `wpffwd`, which is undefined in this function's scope, so that line
    # raised an UndefVarError. Dead code removed.
    return w0Index, w0Weights, nc0
end

The connection matrix gets made by the above function stack (as you can see with this ascii art): image

russelljjarvis commented 1 year ago

This is an output of a basic simulation via kind=:init

I have not yet appropriately sorted the connection matrix into top half excitatory bottom half inhibitory.

image

russelljjarvis commented 1 year ago

The final error that is hit before I am able to make plastic weights which are informed by the Potjans connectome is as follows:


Progress: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| Time: 0:00:03
mean excitatory firing rate: 57.04253393665159 Hz
mean inhibitory firing rate: 0.07420814479638009 Hz
Progress: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| Time: 0:00:03
ERROR: LoadError: BoundsError: attempt to access 521×2210 SparseMatrixCSC{Int64, Int64} with 403862 stored entries at index [522, 1]
Stacktrace:
 [1] throw_boundserror(A::SparseMatrixCSC{Int64, Int64}, I::Tuple{Int64, Int64})
   @ Base ./abstractarray.jl:703
 [2] checkbounds
   @ ./abstractarray.jl:668 [inlined]
 [3] getindex(A::SparseMatrixCSC{Int64, Int64}, i0::Int64, i1::Int64)
   @ SparseArrays ~/.julia/juliaup/julia-1.8.5+0.x64.linux.gnu/share/julia/stdlib/v1.8/SparseArrays/src/sparsematrix.jl:2228
 [4] top-level scope
   @ ~/git/ben/TrainSpikingNet.jl/src/init.jl:132
in expression starting at /home/rjjarvis/git/ben/TrainSpikingNet.jl/src/init.jl:129

The code it hits when it fails is here:

https://github.com/JuliaWSU/TrainSpikingNet.jl/blob/master/src/init.jl#L133-L140

bjarthur commented 1 year ago

hard to debug remotely, but the error is saying you're trying to index a dimension of a matrix that is of size 521 with an index of 522. some of the indices in the connection matrix must exceed the number of neurons.

and why did you modify the dimensions of wpIndexConvert in init.jl in your fork? you shouldn't have to make modifications to any of the code, other than your params.jl file and custom plugins, to get a custom adjacency matrix to work.

more generally speaking, it would be great to have a set of plugins which could convert from the allen format (or any other format of interest, e.g. sonata) into that required for trainspiking net. am happy to help with this and would merge a PR

russelljjarvis commented 1 year ago

This is a good point you made.

https://github.com/JuliaWSU/TrainSpikingNet.jl/blob/master/src/init.jl#L121

I will try to revert some of these unnecessary experimental changes to see if that fixes things.

I can also redo the work against the latest source code, even if GPU tests were breaking I only need CPU code to test this with.

Before I migrate to the new source code: on this line I introduced a condition so that postsynaptic cell indices can't be zero: https://github.com/JuliaWSU/TrainSpikingNet.jl/blob/master/src/cpu/loop.jl#L269

I can think of a possibly important difference between biological connectomes and the erdos_random ones.

The Allen and Potjans connectomes don't have fixed outdegree per neuron. You can't easily artificially pack these connectomes into dense containers, as efferent degree can vary a lot between cells in the biological cases (1 axonal efferent projecting synapse, or 10,000). The sparse matrix has some mandatory zero elements. I have included a spy plot that shows how some cells don't necessarily project onto others.

image

I am coming around to the view that I still don't probably properly understand the roles of the different matrices.

Maybe the index containing matrices can be free from zero values, but they might have a ragged array shape for example: w0Index, but the matrix: w0Weights, might be rectangular and can contain zeros?

I am sometimes unsure if there is a conceptual constraint of the current approach where the connectome must be zero value free/dense.

russelljjarvis commented 1 year ago

hard to debug remotely, but the error is saying you're trying to index a dimension of a matrix that is of size 521 with an index of 522. some of the indices in the connection matrix must exceed the number of neurons.

and why did you modify the dimensions of wpIndexConvert in init.jl in your fork? you shouldn't have to make modifications to any of the code, other than your params.jl file and custom plugins, to get a custom adjacency matrix to work.

more generally speaking, it would be great to have a set of plugins which could convert from the Allen format (or any other format of interest, e.g. sonata) into that required for TrainSpikingNet. am happy to help with this and would merge a PR

Oh cool, yes I am working on this:

more generally speaking, it would be great to have a set of plugins which could convert from the Allen format (or any other format of interest, e.g. sonata) into that required for TrainSpikingNet net. am happy to help with this and would merge a PR

I will try to send through some code options, maybe as public gists before escalating to PRs.

There are a few options:

(i) and (iii) seem the most pragmatic, they will lead to quicker results and better lean on the Sonata file specification. The .json/.py network configurations in Sonata, are succinct files, and the Sonata adaptor seems to lead to a way of sub-setting/ downscaling biological networks proportionately.

bjarthur commented 1 year ago

The Allen and Potjans connectomes don't have fixed outdegree per neuron. You can't easily artificially pack these connectomes into dense containers, as efferent degree can vary a lot between cells in the biological cases (1 axonal efferent projecting synapse, or 10,000). The sparse matrix has some mandatory zero elements. I have included a spy plot that shows how some cells don't necessarily project onto others.

Maybe the index containing matrices can be free from zero values, but they might have a ragged array shape for example: w0Index, but the matrix: w0Weights, might be rectangular and can contain zeros?

I am sometimes unsure if there is a conceptual constraint of the current approach where the connectome must be zero value free/dense.

ragged! yes, wpWeightIn is ragged, and ncpIn tells how many elements of each column are meaningful. the rest are ignored.

at least they should be! i just discovered yesterday that the train script in fact assumes that wpWeightIn is not ragged and ignores ncpIn. am working on a fix now, but the init script should work fine. so this is not (yet) your problem.

one possible fix i'm entertaining is to make wpWeightIn a vector of vectors. for the CPU code this would be more memory efficient for architectures which have wildly different fan-in degrees. and more intuitive to new users. would need to think about how best to refactor the GPU code to handle this data structure without incurring a speed penalty.

russelljjarvis commented 1 year ago

ragged! yes, wpWeightIn is ragged, and ncpIn tells how many elements of each column are meaningful. the rest are ignored.

at least they should be! i just discovered yesterday that the train script in fact assumes that wpWeightIn is not ragged and ignores ncpIn. am working on a fix now, but the init script should work fine. so this is not (yet) your problem.

one possible fix i'm entertaining is to make wpWeightIn a vector of vectors. for the CPU code this would be more memory efficient for architectures which have wildly different fan-in degrees. and more intuitive to new users. would need to think about how best to refactor the GPU code to handle this data structure without incurring a speed penalty.

Okay well I have solved the issue, some of the conversation points are ongoing.

It would be awesome if the user only had to supply a connection matrix, indices of excitatory neurons and indices of inhibitory neurons. That would mean the user could achieve wiring by unpacking a JLD matrix instead of writing the code that defines the connectivity.

I tried to do self PRs to update my forks of TrainSpikingNet.jl but resolving those merges could take too long.

The following lines of the init file achieve something like a basic Potjans network simulation with the plastic synapse connectivity being sort of like a stochastic copy of the static weights.

These lines show how I am just trying to test basic simulation ability of Potjans connectome without training and testing. https://github.com/JuliaWSU/TrainSpikingNet.jl/blob/master/src/init.jl#L72-L242

https://github.com/JuliaWSU/TrainSpikingNet.jl/blob/master/enter_potjans_connectome.sh

Some of this code should be version independent of TrainSpikingNet.jl, so I have put the code in gists. This code is starting to get more mature now: https://gist.github.com/russelljjarvis/1aa265c749b6885489d38eb111fb1c61

and this code is an intervention to get the other code to run https://gist.github.com/russelljjarvis/6af437483449682abbc92069d7e9d50e

I was thinking I could fast track to the latest version of the source code, by doing a mix and match of a vanilla call to init.jl, and doing a JLD read of outputs from the potjans_matrix code that is in a gist atm.

This is an unfortunate hack I ended up using to get the simulation to complete. https://github.com/JuliaWSU/TrainSpikingNet.jl/blob/master/src/init.jl#L143

The main error looks like a propagating off-by-one index error; rather than figuring it out, I decided it was okay if the simulated plastic weights were missing a whole row or column.

It makes more sense for me to try to re-interface with the newer revision of the code you describe above, rather than perfect my approach against the old code.

I can polish this and make it a file in the TrainSpikingNet.jl somewhere.

russelljjarvis commented 1 year ago

This issue could optionally be closed now too.

I wonder if there is scope for a Discussions page? I have more technical code questions to ask, but they are not "issues" ie they are not code breaks.

bjarthur commented 1 year ago

I wonder if there is scope for a Discussions page? I have more technical code questions to ask, but they are not "issues" ie they are not code breaks.

great idea! i have turned on the Discussions feature for this github repo. you could also post to https://discourse.julialang.org/ too for general questions. but github is probably better.

bjarthur commented 1 year ago

It would be awesome if the user only had to supply a connection matrix, indices of excitatory neurons and indices of inhibitory neurons.

supplying such a connection matrix is all you should have to do. what makes you think more is needed?

That would mean the user could achieve wiring by unpacking a JLD matrix instead of writing the code that defines the connectivity.

reading in a connection matrix from a JLD file, that is created elsewhere, is a totally valid thing to do inside the gen{Static,Plastic}Weights plugins. in the end, all that matters is that those plugins return connection matrices. whether they compute them on the fly, or read them from a JLD, doesn't matter.

Some of this code should be version independent of TrainingSpikingNet.jl, so I have put the code in gists.

ideally all of your code should be independent of TrainSpikingNet.jl! users shouldn't have to fork and modify it. that's the entire point of the plugin system.

I was thinking I could fast track to the latest version of the source code, by doing a mix and match of a vanilla call to init.jl, and doing a JLD read of outputs from the potjans_matrix code that is in a gist atm.

definitely recommend upgrading to the new version, and using your potjans gist as a custom plugin. again, you shouldn't have to modify init.jl , or anything else in TrainSpikingNet.jl, to get this to work. simply split the gist into two files, genStaticWeights-potjans.jl and genPlasticWeights-potjans.jl, and then in your param.jl file set genStaticWeights_file = genStaticWeights-potjans.jl etc.

I can polish this and make it a file in the TrainSpikingNet.jl somewhere.

i'm definitely amenable to creating a new "contrib/" directory for custom plugins the community has written.

russelljjarvis commented 1 year ago

i'm definitely amenable to creating a new "contrib/" directory for custom plugins the community has written.

Okay I will change the name of the directory I made in the draft PR from connectomes to contrib

definitely recommend upgrading to the new version, and using your potjans gist as a custom plugin. again, you shouldn't have to modify init.jl , or anything else in TrainSpikingNet.jl, to get this to work. simply split the gist into two files, genStaticWeights-potjans.jl and genPlasticWeights-potjans.jl, and then in your param.jl file set genStaticWeights_file = genStaticWeights-potjans.jl etc.

I will make an example of code that uses the plugin system just for a static connectome (no plastic weights added yet; plots spikes and quits). The code might fail, but it would be nice to get your input about what went wrong, otherwise I might spend days hacking the wrong bit of code trying to accommodate an unforeseen edge case. It's probably better if I can get the example (or counterexample) runnable on your end.

russelljjarvis commented 1 year ago

reading in a connection matrix from a JLD file, that is created elsewhere, is a totally valid thing to do inside the gen{Static,Plastic}Weights plugins. in the end, all that matters is that those plugins return connection matrices. whether they compute them on the fly, or read them from a JLD, doesn't matter.

This might work now, I will test it out somehow.

russelljjarvis commented 1 year ago

The new usage pattern in the README is so awesome!

russelljjarvis commented 1 year ago

I made a PR, that shows how using the latest code changes, I can get pretty close to using plugin style connectivity matrices. https://github.com/SpikingNetwork/TrainSpikingNet.jl/pull/4/files StaticMatrices.jl almost work, PlasticWeightMatrices fail.

I wonder if it is possible to opt out of plastic simulations, by providing an empty Plastic weight Matrix?

The entry point to the code I provided is this:

TrainSpikingNet.jl/src/contrib$ julia configAndRunPotjans.jl 

This could also be regarded as a draft PR too. I might update it later.

Inside my code where I build the connection matrix, I save the matrix but I need to probably save it in the data-dir that the program user specifies, instead I save it in the pwd which might be src/contrib.

I should update the code so any matrices are stored in the user specified data_dir.

russelljjarvis commented 1 year ago

definitely recommend upgrading to the new version, and using your potjans gist as a custom plugin. again, you shouldn't have to modify init.jl , or anything else in TrainSpikingNet.jl, to get this to work. simply split the gist into two files, genStaticWeights-potjans.jl and genPlasticWeights-potjans.jl, and then in your param.jl file set genStaticWeights_file = genStaticWeights-potjans.jl

I have mostly taken on this approach.

I was thinking I don't really need to write a plastic connection-matrix rule myself straight away. I can fall back on the Erdős–Rényi one as an intermediate step to debugging this code approach.