SpikeInterface / spikeinterface-gui

GUI for spikeinterface objects
MIT License
21 stars 7 forks source link

Sleeping error in spikeinterface_gui/controller.py #61

Open rat-h opened 5 months ago

rat-h commented 5 months ago

I try to create a fake spike sorting which just picks huge spikes and ignores everything else. So the code for this task is pretty simple and works well:

from numpy import *
import psutil, os
import spikeinterface.full as si
from probeinterface import read_probeinterface
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.core.basesorting import minimum_spike_dtype
from spikeinterface.core import NumpySorting

job_kwargs = {
        "n_jobs": 14,
        "total_memory": f"{int(ceil(psutil.virtual_memory()[1]*0.75))//1024//1024//1024:d}G",
        "progress_bar": True
    }
prob = read_probeinterface("probes/A4x32-Poly2-5mm-23s-200-177-after-mapping.json").probes[0]

d = array([ sqrt( sum( (p0-p1)**2 ) ) for pid,p0 in enumerate(prob.contact_positions) for p1 in prob.contact_positions[pid+1:,:] ])
dmin, dmax = amin(d), amax(d)
recording = si.BinaryRecordingExtractor( "continuous.dat",30000.0, 'int16', num_channels=128)
recording.set_probe(prob,in_place=True)
preproc = [recording ]#.remove_channels([17])]
preproc.append(
    si.filter(preproc[-1],btype="bandpass",band=[1500.,6000.])
)
preproc.append(
    #si.common_reference(preproc[-1],reference="global",operator="median")
    si.common_reference(preproc[-1],reference="local",operator="median",local_radius=(0,200))
)

rec = preproc[-1].save( folder="detect-highamp/preprocessed", chunk_duration='1m',overwrite=True,**job_kwargs)

# rec = si.load_extractor("detect-highamp/preprocessed")
pdk = detect_peaks(
    rec,
    # si.zscore(rec),
    method='locally_exclusive_torch',
    peak_sign='neg',          # Default 'neg'
    detect_threshold=5.5,      # Default 5
    exclude_sweep_ms=2.0,      # Default 0.1
    radius_um=int(ceil(dmin)), # Default ?
)

print('===========================')
print(f'Number of spikes = {pdk.size: 7d}')
print('===========================')
# print(minimum_spike_dtype)
# exit(0)
final_spikes = zeros(pdk.size, dtype=minimum_spike_dtype)
final_spikes["sample_index"] = pdk['sample_index']
final_spikes["unit_index"  ] = pdk['channel_index']
labels_set = unique(pdk['channel_index'])
# final_spikes["segment_index"] = spikes["segment_index"]

srt = NumpySorting(final_spikes, 30000.0, labels_set)
os.system('rm -fR detect-highamp/sorting-saved detect-highamp/waveforms detect-highamp/phy-extractor detect-highamp/phy')
srt = srt.save(folder="detect-highamp/sorting-saved")
we = si.extract_waveforms(
            rec, srt, 'detect-highamp/waveforms',
            max_spikes_per_unit=500,
            ms_before=1.5, ms_after=2.5,
            **job_kwargs
        )

I read the results of this "sortering" by my usual tool:

import sys,os
import spikeinterface.full as si
import spikeinterface_gui

#

# This cerate a Qt app
app = spikeinterface_gui.mkQApp() 
# reload the waveform folder
try:
    rc  = si.load_extractor(sys.argv[1]+"/preprocessed")
except BaseException as e:
    print(f"Cannot preprocessed data from {sys.argv[1]}/preprocessed: {e}")
    exit(1)

if os.path.isdir(sys.argv[1]+'/waveforms-clean'):
    try:
        we  = si.WaveformExtractor.load_from_folder(sys.argv[1]+'/waveforms-clean')
    except BaseException as e:
        print(f"Cannot load waveform from {sys.argv[1]}/waveforms-clean: {e}")
        exit(1)
    print('have read waveforms-clean')
elif os.path.isdir(sys.argv[1]+'/waveforms'):
    try:
        we  = si.WaveformExtractor.load_from_folder(sys.argv[1]+'/waveforms')
    except BaseException as e:
        print(f"Cannot load waveform from {sys.argv[1]}/waveforms: {e}")
        exit(1)
    print('have read waveforms')
else:
    exit(1)
pca = si.compute_principal_components(we, load_if_exists=True, mode='by_channel_local', n_components=3)
ccg = si.compute_correlograms(we,load_if_exists=True, window_ms=2000, bin_ms=1)
isi = si.compute_isi_histograms(we,load_if_exists=True, window_ms=2000, bin_ms=1)
#DB>>
sr  = si.load_extractor(sys.argv[1]+'/sorting-saved')
srspikes = sr.to_spike_vector()
wespikes = we.sorting.to_spike_vector()
print(srspikes.size)
print(wespikes.size)
#<<DB
# create the mainwindow and show
win = spikeinterface_gui.MainWindow(we)
win.show()
# run the main Qt6 loop
app.exec()

The problem is that:

  1. if I set detect_threshold in the "sorter" below 5., which gives "spikes" at each electrode - the viewer works just fine
  2. If I set detect_threshold higher than 5, means only a few spikes are selected, it returns an error
python read-sorting.py detect-highamp
have read waveforms
10419
10419
/home/rth/.local/apps/spikes/lib/python3.10/site-packages/spikeinterface/postprocessing/unit_localization.py:366: RuntimeWarning: invalid value encountered in divide
  com = np.sum(wf_data[:, np.newaxis] * local_contact_locations, axis=0) / np.sum(wf_data)
Traceback (most recent call last):
  File "/home/rth/spikeinterface/read-sorting.py", line 44, in <module>
    win = spikeinterface_gui.MainWindow(we)
  File "/home/rth/.local/apps/spikes/lib/python3.10/site-packages/spikeinterface_gui/mainwindow.py", line 23, in __init__
    self.controller = SpikeinterfaceController(waveform_extractor, verbose=verbose)
  File "/home/rth/.local/apps/spikes/lib/python3.10/site-packages/spikeinterface_gui/controller.py", line 115, in __init__
    self.spikes['sample_index'] = spikes_['sample_index']
ValueError: could not broadcast input array from shape (10419,) into shape (9201,)

Note, that number of spikes are the same in saved sorter and in we.sorting, but somehow it is different inside the GUI. Any ideas what it may be?