neuromorphicsystems / GENN_On_Jetson

A basic guide for configuring the Nvidia Jetson Nano to run the GENN implementation of the Potjans et al. cortical model developed by Knight et al.

Get this running on deeplearning1.westernsydney.edu.au #2

Open russelljjarvis opened 2 years ago

russelljjarvis commented 2 years ago

Rationale:

Machine learning models usually apply meta-parameter optimization to find the learning parameters that lead to better learning and better cross-validation. Similarly, to get a bio-plausible model to learn prediction, we may need to explore parameters for:

This leads to a combinatorial explosion of meta-parameters to explore, and to do this systematically we need to search the parameter space with genetic algorithms on a fast-to-execute network model (pure GENN C++/CUDA code), as sketched below.
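A purely illustrative sketch of such a search, assuming nothing about the real model: the parameter names, ranges, and the evaluate() placeholder are invented, and in practice evaluate() would launch the compiled GENN model and return an error score.

import random

# Hypothetical meta-parameter ranges (placeholders, not the real model's parameters)
PARAM_RANGES = {"learning_rate": (1e-4, 1e-1), "weight_scale": (0.1, 10.0)}

def random_individual():
    return {k: random.uniform(*bounds) for k, bounds in PARAM_RANGES.items()}

def evaluate(params):
    # placeholder fitness: would run the fast GENN model and return an error score
    return sum(params.values())

def mutate(params, rate=0.2):
    child = dict(params)
    for k, (lo, hi) in PARAM_RANGES.items():
        if random.random() < rate:
            child[k] = random.uniform(lo, hi)
    return child

population = [random_individual() for _ in range(20)]
for generation in range(10):
    scored = sorted(population, key=evaluate)       # lower error = fitter
    parents = scored[: len(scored) // 2]            # keep the fittest half
    population = parents + [mutate(random.choice(parents)) for _ in parents]
print("best meta-parameters:", min(population, key=evaluate))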

Goal:

Try to roughly port the predictive coding model to GENN using Brian2GeNN. It's expected that this will only partially translate the Brian/teili code to GENN code.
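A minimal sketch of the Brian2GeNN workflow, assuming the standard set_device('genn') route; the leaky-integrator group below is only a placeholder, not the teili predictive-coding model:

import brian2genn                       # registers the 'genn' code-generation device
from brian2 import NeuronGroup, SpikeMonitor, run, set_device, ms, mV

set_device('genn')                      # compile the Brian2 network as GeNN C++/CUDA

# placeholder leaky integrator standing in for the predictive-coding model
group = NeuronGroup(100, 'dv/dt = -v / (10*ms) : volt',
                    threshold='v > 10*mV', reset='v = 0*mV', method='exact')
monitor = SpikeMonitor(group)
run(100*ms)                             # triggers code generation, build and execution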

russelljjarvis commented 2 years ago

Hi @pabloabur,

I was thinking more about this.

I think it's a good idea because we can presume that the mammalian brain routinely changes its connectome structure to maximise predictive coding. If the brain is optimizing itself to maximise predictive coding, then it's reasonable to use machine optimization to do the same thing.

Without going into a full-blown optimization routine, if you can just send me Python pickles I can calculate the multivariate spike distance matrix for you.

I can also calculate the spike distance between the input events (like your rotating-bar spikes) and the readout layer.

When the spike distance between the input and the readout is high, that is a large error for optimization; when the spike distance is low, that is a small error.

Here are some examples:
https://github.com/mariomulansky/PySpike
https://github.com/russelljjarvis/2Dto3DSpikes/blob/main/julia/read_spikes.jl#L222-L262

Note that we wouldn't have to re-implement your teili/Brian2 simulation; we could just add in spike-distance errors first and then gradually build in optimization.
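As a rough sketch of that error term with PySpike (the spike times below are made up, not from the actual simulation):

import pyspike as spk
from pyspike import SpikeTrain

# made-up spike times (ms) for an input channel and a readout channel
input_train = SpikeTrain([10.0, 55.0, 120.0, 300.0], edges=(0.0, 500.0))
readout_train = SpikeTrain([12.0, 80.0, 310.0], edges=(0.0, 500.0))

# bivariate SPIKE distance in [0, 1]: high distance = large error, low = small error
error = spk.spike_distance(input_train, readout_train)
print("optimization error (SPIKE distance): %.6f" % error)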

https://github.com/russelljjarvis/SpikeNetOpt.jl/blob/main/img/net_compare_unicode.png

pabloabur commented 2 years ago

I'm still working on that, as it looks like we will end up doing more things than just running these after the simulations, but I pickled the spikes of the recurrent network in case you want to play with it: spikes.zip

russelljjarvis commented 2 years ago

Excellent!

russelljjarvis commented 2 years ago

import pickle

import numpy as np
import pyspike as spk
from pyspike import SpikeTrain as SpikeTrainPy

# Load the pickled spike times and the matching neuron indices
with open("spike_times", "rb") as f:
    spike_times = pickle.load(f)
with open("spike_indices", "rb") as f:
    sind = pickle.load(f)

# Group the spikes by neuron index into one PySpike SpikeTrain per neuron
sts = []
DURATION_MS = np.max(spike_times)
for i in set(sind):
    spike_times_ = [st for st, ind in zip(spike_times, sind) if ind == i]
    sts.append(SpikeTrainPy(spike_times_, edges=(0.0, DURATION_MS)))

# Bivariate SPIKE distance between each remaining train and the reference train sts[1]
for i, st in enumerate(sts):
    if i > 0:
        spike_profile = spk.spike_profile(st, sts[1])
        print("SPIKE distance: %.8f" % spike_profile.avrg())

russelljjarvis commented 2 years ago

With SPADE (pattern detection, similar to cell assembly detection):

import pickle

import numpy as np
import quantities as pq
import matplotlib
matplotlib.use('Agg')  # write figures to files; no interactive display
import matplotlib.pyplot as plt

import pyspike as spk
from pyspike import SpikeTrain as SpikeTrainPy
from neo import SpikeTrain as SpikeTrainN

import elephant.conversion as conv
import elephant.spade as spade
from elephant.statistics import cv, isi
import viziphant
from viziphant.rasterplot import rasterplot_rates

# Load the pickled spike times and the matching neuron indices
with open("spike_times", "rb") as f:
    spike_times = pickle.load(f)
with open("spike_indices", "rb") as f:
    sind = pickle.load(f)

sts = []   # PySpike spike trains
stsn = []  # neo spike trains (for elephant/viziphant)
DURATION_MS = np.max(spike_times)
for i in set(sind):
    spike_times_ = [st for st, ind in zip(spike_times, sind) if ind == i]
    # a list of PySpike spike trains
    sts.append(SpikeTrainPy(spike_times_, edges=(0.0, DURATION_MS)))
    # a list of neo object spike trains
    stsn.append(SpikeTrainN(sts[-1].get_spikes_non_empty() * pq.ms, t_stop=DURATION_MS * pq.ms))

# Get spike distances between pairs, i.e. between inputs and outputs
sum_over_all = []
for i, st in enumerate(sts):
    if i > 0:
        spike_profile = spk.spike_profile(st, sts[1])
        print("SPIKE distance: %.8f" % spike_profile.avrg())
        sum_over_all.append(spike_profile.avrg())
print("the final scalar value to use if these really were input-versus-readout bivariate spike distances")
print("total {0}".format(np.sum(sum_over_all)))

# Multivariate SPIKE distance profile over all spike trains
spike_profile = spk.spike_profile(sts)
x, y = spike_profile.get_plottable_data()
plt.plot(x, y, '--k')
avg = spike_profile.avrg()
print("SPIKE distance: %.8f" % avg)
plt.savefig("spike_distance_profile.png")

# Pairwise SPIKE distance matrix
spike_distance = spk.spike_distance_matrix(sts, interval=(0, DURATION_MS))
plt.figure()
plt.imshow(spike_distance, interpolation='none')
plt.savefig("spike_distance.png")

# SPADE pattern detection on the neo spike trains
bin_size = 50 * pq.ms
patterns = spade.spade(stsn, bin_size=bin_size, winlen=1)['patterns']
axes = viziphant.patterns.plot_patterns(stsn, patterns)
plt.savefig("patterns0.png")

figure = plt.figure()
rasterplot_rates(stsn)
plt.savefig("raster_rate_rug.png")

figure = plt.figure()
axes = viziphant.patterns.plot_patterns_statistics(patterns)
plt.savefig("patterns_all.png")

# Coefficient of variation of the inter-spike intervals, per neuron
figure = plt.figure()
cv_list = [cv(isi(spiketrain)) for spiketrain in stsn]
plt.hist(cv_list)
plt.savefig("cv_hist.png")

np.random.seed(0)
# Binned representation (not used further in this snippet)
binned_spiketrains = conv.BinnedSpikeTrain(stsn, bin_size=5 * pq.ms)

# Overlay the first detected pattern on a raster plot of the data
pattern = patterns[0]
plt.figure()
for neu in pattern['neurons']:
    if neu == 0:
        plt.plot(pattern['times'] * bin_size, [neu] * len(pattern['times']),
                 'ro', label='pattern')
    else:
        plt.plot(pattern['times'] * bin_size, [neu] * len(pattern['times']), 'ro')
# Raster plot of the data
for st_idx, st in enumerate(stsn):
    if st_idx == 0:
        plt.plot(st.rescale(pq.ms), [st_idx] * len(st), 'k.', label='spikes')
    else:
        plt.plot(st.rescale(pq.ms), [st_idx] * len(st), 'k.')
plt.ylim([-1, len(stsn)])
plt.xlabel('time (ms)')
plt.ylabel('neuron ids')
plt.legend()
plt.savefig("pattern_raster.png")