SpikeInterface / spikeinterface

A Python-based module for creating flexible and robust spike sorting pipelines.
https://spikeinterface.readthedocs.io
MIT License

Unable to open kilosort4 results with phy #2710

Closed: Sachuriga closed this issue 7 months ago

Sachuriga commented 7 months ago

Dear,

I'm using spikeinterface==0.100.5 and kilosort4==4.0.3 on a 64-channel, 6-shank probe. The phy output works perfectly with ironclust-generated results.

However, I have encountered some issues when opening the kilosort4 phy output folder. I saw some similar reports, so I tried the following things:

  1. Calling sorting.remove_empty_units() and scur.remove_excess_spikes(sorting_no_empty, rec_corrected1) before waveform extraction
  2. Setting sparse=True or sparse=False, or sparse=True, method="by_property", by_property="group" in si.extract_waveforms()
  3. Setting sparsity=None or sparsity=True in sexp.export_to_phy()

None of the above worked, and phy gives me the error message below. I'm wondering if there is any solution or any parameter I can adjust?

Error message:

19:51:55.272 [E] __init__:62          An error has occurred (AssertionError):
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "C:\Users\sachur\AppData\Local\anaconda3\envs\phy2\Scripts\phy.exe\__main__.py", line 7, in <module>
    sys.exit(phycli())
             ^^^^^^^^
  File "C:\Users\sachur\AppData\Local\anaconda3\envs\phy2\Lib\site-packages\click\core.py", line 1157, in __call__
    return self.main(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\sachur\AppData\Local\anaconda3\envs\phy2\Lib\site-packages\click\core.py", line 1078, in main
    rv = self.invoke(ctx)
         ^^^^^^^^^^^^^^^^
  File "C:\Users\sachur\AppData\Local\anaconda3\envs\phy2\Lib\site-packages\click\core.py", line 1688, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\sachur\AppData\Local\anaconda3\envs\phy2\Lib\site-packages\click\core.py", line 1434, in invoke
    return ctx.invoke(self.callback, **ctx.params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\sachur\AppData\Local\anaconda3\envs\phy2\Lib\site-packages\click\core.py", line 783, in invoke
    return __callback(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\sachur\AppData\Local\anaconda3\envs\phy2\Lib\site-packages\click\decorators.py", line 33, in new_func
    return f(get_current_context(), *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\sachur\AppData\Local\anaconda3\envs\phy2\Lib\site-packages\phy\apps\__init__.py", line 159, in cli_template_gui
    template_gui(params_path, **kwargs)
  File "C:\Users\sachur\AppData\Local\anaconda3\envs\phy2\Lib\site-packages\phy\apps\template\gui.py", line 209, in template_gui
    model = load_model(params_path)
            ^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\sachur\AppData\Local\anaconda3\envs\phy2\Lib\site-packages\phylib\io\model.py", line 1433, in load_model
    return TemplateModel(**get_template_params(params_path))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\sachur\AppData\Local\anaconda3\envs\phy2\Lib\site-packages\phylib\io\model.py", line 339, in __init__
    self._load_data()
  File "C:\Users\sachur\AppData\Local\anaconda3\envs\phy2\Lib\site-packages\phylib\io\model.py", line 358, in _load_data
    assert self.amplitudes.shape == (ns,)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError

QWidget: Must construct a QApplication before a QWidget

zm711 commented 7 months ago

I think we would probably need to see the full script: how you're generating the waveform_extractor, calculating amplitudes, etc., just to see if anything stands out. That error is saying that the number of amplitudes does not match the number of spikes, so we need to see how you're generating these values.
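
As a quick sanity check in the meantime, you can compare the array lengths directly in the exported folder. A minimal sketch (the folder name is a placeholder for your export_to_phy output):

import numpy as np
from pathlib import Path

phy_folder = Path("phy_ks4")  # placeholder: your phy export folder
spike_times = np.load(phy_folder / "spike_times.npy")
amplitudes = np.load(phy_folder / "amplitudes.npy")
# phy asserts these have the same length; a mismatch reproduces the error above
print(spike_times.shape, amplitudes.shape)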

Sachuriga commented 7 months ago

Hey,

here is the script I'm using.

import spikeinterface as si
import spikeinterface.extractors as se
import spikeinterface.preprocessing as spre
import spikeinterface.sorters as ss
import spikeinterface.postprocessing as spost
import spikeinterface.qualitymetrics as sqm
import spikeinterface.comparison as sc
import spikeinterface.exporters as sexp
import spikeinterface.widgets as sw
import spikeinterface.curation as scur
from spikeinterface.preprocessing import (bandpass_filter, notch_filter, common_reference, highpass_filter, zscore,
                                          remove_artifacts, preprocesser_dict, normalize_by_quantile, center,
                                          correct_motion, load_motion_info)
import numpy as np
import os
from pathlib import Path
import warnings
import sys
from spikeinterface.sorters import installed_sorters
import probeinterface as pi
from spikeinterface.postprocessing import compute_principal_components

##Load the data
base_folder = Path(".")
folder=r'file_path'
current_file_name_path = os.path.basename(folder)
current_file_name = str(folder)
stream_name = 'Record Node 101#OE_FPGA_Acquisition_Board-100.Rhythm Data'
recording = se.read_openephys(folder, stream_name=stream_name, load_sync_timestamps=True)

##load the probe file
manufacturer = 'cambridgeneurotech'
probe_name = 'ASSY-236-F'
probe = pi.get_probe(manufacturer, probe_name)
mapping_to_device = [
    # connector J2 TOP
    41, 39, 38, 37, 35, 34, 33, 32, 29, 30, 28, 26, 25, 24, 22, 20,
    46, 45, 44, 43, 42, 40, 36, 31, 27, 23, 21, 18, 19, 17, 16, 14,
    # connector J1 BOTTOM
    55, 53, 54, 52, 51, 50, 49, 48, 47, 15, 13, 12, 11, 9, 10, 8,
    63, 62, 61, 60, 59, 58, 57, 56, 7, 6, 5, 4, 3, 2, 1, 0
]
probegroup = pi.ProbeGroup()
probegroup.add_probe(probe)
pi.write_prb(f"{probe_name}.prb", probegroup, group_mode="by_shank")
recording_prb = recording.set_probe(probe, group_mode="by_shank")

##prep for the wf extraction
recording_f0 = bandpass_filter(recording_prb, freq_min=600, freq_max=8000)
rec_cmr = common_reference(recording_f0, reference='global', operator='median')
recording_sub = rec_cmr

preprocessed = "_" + "preprocessed_kss4"
recording_saved = recording_sub.save(folder=base_folder / preprocessed, overwrite=True)
recording_rec = si.load_extractor(base_folder / preprocessed)

## Sorting
recording_rec = si.load_extractor(base_folder / preprocessed)
ks4 = current_file_name + "_" + "ks4"
para_ks4 = {'skip_kilosort_preprocessing': False,
            'nblocks': 0,
            'nearest_templates': 9,
            'min_template_size': 17,
            'batch_size': 120000,
            'do_correction': False,
            'dmin': 15,
            'dminx': 16.5,
            'whitening_range': 8}

sorting = ss.run_sorter_by_property('kilosort4',
                                    recording=recording_prb,
                                    working_folder=base_folder / ks4,
                                    grouping_property='group', **para_ks4)

sorting_no_empty = sorting.remove_empty_units()
s_clean = scur.remove_excess_spikes(sorting_no_empty, recording_rec)  # remove_excess_spikes lives in spikeinterface.curation

##extract wf
waveforms = "waveforms_ks4"
we = si.extract_waveforms(rec_cmr,
                          s_clean,
                          folder=base_folder / waveforms,  # save here so load_waveforms below can find it
                          sparse=True,
                          method="by_property",
                          by_property="group")

we = si.load_waveforms(base_folder / waveforms, 
                       with_recording=True, 
                       sorting=s_clean)

unit_locations = spost.compute_unit_locations(waveform_extractor=we,
                                              method='grid_convolution')
metrics = sqm.compute_quality_metrics(waveform_extractor=we)

#export
phy_ks4 = current_file_name + "_" + "phy_ks4"
sexp.export_to_phy(we,
                   output_folder=base_folder / phy_ks4,
                   remove_if_exists=True,
                   compute_amplitudes=True,
                   copy_binary=True,
                   compute_pc_features=False)

zm711 commented 7 months ago

One thing I'm wondering is KS4 has some bugs where it returns negative spike times. Could you check that?

spike_vector = we.sorting.to_spike_vector()
spike_times_samples = spike_vector['sample_index']
print(spike_times_samples[:10]) # first x spike times.

I think remove_excess_spikes was originally written for excess spikes at the end of the recording and may not account for negative spike times if they exist. So we should check that first, since it is a known KS4 bug.
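
To check both ends at once, you could try something like this (a quick sketch, assuming a mono-segment recording and that the waveform extractor kept its recording):

import numpy as np

spike_vector = we.sorting.to_spike_vector()
samples = spike_vector['sample_index']
num_samples = we.recording.get_num_samples(segment_index=0)
print('negative spike times:', np.sum(samples < 0))
print('spikes past the end:', np.sum(samples > num_samples - 1))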

Sachuriga commented 7 months ago

Hey,

You're right! Indeed, there are negative values.

[-10 -10 -10 -10 -10 -10 -10 -10 -10 -10]

and with

print(np.where(spike_times_samples<0))

output: (array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23], dtype=int64),)

alejoe91 commented 7 months ago

Let's patch the remove_excess_spikes to deal with this too!

Sachuriga commented 7 months ago

Thanks!

zm711 commented 7 months ago

@alejoe91 do we want this in the 0.100 bug fixes and in main?

alejoe91 commented 7 months ago

Probably! Same for the MS5 fix!

zm711 commented 7 months ago

I opened a PR for main. Do you want one big PR with a bunch of these sorter fixes on 100.bug.fixes, or separate PRs to port all of these over? (We have MS5, we have KS4 for 4.0.3, and we have these excess-spike fixes.)

alirza67 commented 7 months ago

Hi there, I have updated my spikeinterface (v0.101.0) to solve the problem of negative spike times, but I see no effect. Here is what I have; can you please let me know what I am doing wrong?

sorting = si.read_sorter_folder(base_folder + '/KS4out')
rec=si.load_extractor(base_folder + '/multirec')

sorting2 = si.remove_excess_spikes(sorting, rec)
spike_vector = sorting2.to_spike_vector()
spike_times_samples = spike_vector['sample_index']
print(spike_times_samples[:10]) # first x spike times

The output is the same as for "sorting":

[ -2 18 20 37 43 45 60 84 86 106]

zm711 commented 7 months ago

I see what the issue is. Because you have no spikes at exactly 0, the indexing we did can still pull out one last negative spike time. Thanks for the report! We will fix the fix :) @alejoe91 any ideas for this? I did side='left' because if we have spikes at 0 they could technically be real, but maybe we just do side='right' at the risk of losing one extra spike?

EDIT: @alejoe91 I just added a PR that we can check and decide on, or close if we come up with a better solution!
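
For reference, the side= distinction at play looks roughly like this (a toy illustration, not the actual patched code):

import numpy as np

samples = np.array([-10, -2, 0, 0, 5])  # sorted spike sample indices
print(np.searchsorted(samples, 0, side='left'))   # 2 -> slicing from here keeps the spikes at 0
print(np.searchsorted(samples, 0, side='right'))  # 4 -> slicing from here drops the spikes at 0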

Sachuriga commented 7 months ago

Hi @zm711 ,

I have also tried the updated version, but it's not working for me either. I thought it might be easier if I send you a link to download my data so you could test on it. Would that be ok with you?

zm711 commented 7 months ago

Do either of you @sachuriga283 or @alirza67 feel comfortable installing SI from a PR/branch and testing my new fix (you would install #2727)? That would actually be easiest. If not, if you could just send me a link to the data to recreate the sorting, that would be enough for me to test locally :)

Sachuriga commented 7 months ago

Hi, I would like to test your fixes! But I'm not sure if I'm installing your updates properly. Would you please let me know how I should install SI from your PR/branch?

zm711 commented 7 months ago

How are you using git? You would just check out my branch (see the git docs). If you're using a git GUI (like the one GitHub provides), I can include pictures.

zm711 commented 7 months ago

Close, but we want the pr branch.

git clone --single-branch --branch neg-spikes-v2 https://github.com/zm711/spikeinterface.git

If you already have an editable install you could also switch to the branch, but if it is easier to just clone this branch separately and then delete it, that is fine too.
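
After cloning, the usual editable install would be something like this (assuming a fresh environment):

cd spikeinterface
pip install -e .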

Sachuriga commented 7 months ago

Cool! Thanks, I will try it out.

alirza67 commented 7 months ago

Hi @zm711, thanks for the prompt responses. I cloned your branch and reinstalled SI, but the problem remains!

Sachuriga commented 7 months ago

@zm711 Hi, I have tried the new fixes and it seems they're not working. I'm not sure if it's just because I'm not installing them properly, so I will paste the link to download the data; if you could try this locally it would be helpful. https://drive.google.com/drive/folders/1kP5wu5OenDif8KBdLA0WnErN3wExgqqi?usp=drive_link

zm711 commented 7 months ago

I'll check out your data @sachuriga283. Thanks. Could you both post which negative spike is still present with my new patch? I'm also trying to figure out where the problem is remaining.

zm711 commented 7 months ago

@sachuriga283, the drive link is not open for access. Was that intentional?

Sachuriga commented 7 months ago

Hi @zm711, sorry, that was not intentional. I had forgotten to open the access; you should have it now.

zm711 commented 7 months ago

Thanks!

In order to load just the continuous.dat without the full folder structure, I need to know the dtype, channel number, and sampling rate. Could you supply that for me @sachuriga283?
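
For reference, loading just a raw binary in SpikeInterface looks roughly like this (the path, dtype, channel count, and sampling rate below are placeholders, not the actual values):

import spikeinterface as si

rec = si.read_binary("continuous.dat", sampling_frequency=30000.0,
                     dtype="int16", num_channels=64)  # placeholder parameters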

Never mind, my internet gave out so it didn't download the other file...

zm711 commented 7 months ago

@alirza67 and @sachuriga283 I figured it out. I hadn't updated the has_exceeding_spikes function that checks before removing spikes. I've now updated that in the PR, if either of you could test it! I tested with @sachuriga283's data and it works for me now!

zm711 commented 7 months ago

We merged my fix into main, so go ahead and install from main and make sure it works.

alirza67 commented 7 months ago

Hi @zm711 ,

Unfortunately it did not solve my problem!

On the other hand, I updated my kilosort4 (now I have v4.0.5), and the negative spike time problem is solved on their side. What I understood is that KS4 initially had some unintended behavior for the first batch, which is fixed now. I confirmed that by running KS4 on my data without using SI.

However, when I use SI with KS4, I still get that one negative spike time. Apart from removing invalid spike times in the postprocessing steps, do you have any guess why we have them in the first place?

I have 2 OpenEphys recording sessions. For the "KS4 only" pipeline, I concatenate the binary files using numpy.memmap and run the KS4 GUI. For the "SI" pipeline, I load the single recordings into SI, concatenate them as a multirecording, and then feed it to KS4 as follows:

import os
import spikeinterface.full as si

base_folder = 'D:/ephys/R002/2024-03-13'

def get_subfolders(directory):
    # keep only session folders whose names end in two digits
    subfolders = [f.path for f in os.scandir(directory) if f.is_dir() and f.name[-2:].isdigit()]
    return subfolders

rec_subfolders = get_subfolders(base_folder)

recordings_list = []
for subfolder in rec_subfolders:
    recordings_list.append(si.read_openephys(subfolder, stream_id='1'))

# concatenate recordings
multirecording = si.concatenate_recordings(recordings_list)
single_rec = si.read_openephys(rec_subfolders[-1], stream_id='1')

# set a probe (taken from the last-loaded session)
multirecording = multirecording.set_probe(single_rec.get_probe())

# save to binary recording to be used in KS4
job_kwargs = dict(n_jobs=30, chunk_duration='1s', progress_bar=True)
rec = multirecording.save(folder=base_folder + '/multirec', format='binary', **job_kwargs)

# spike sorting
sorting = si.run_sorter('kilosort4', rec, output_folder=base_folder + '/KS4out_nosipp2',
                        docker_image=False, verbose=True)

# check the negative spike times
spike_vector = sorting.to_spike_vector()
spike_times_samples = spike_vector['sample_index']
print(spike_times_samples[:10])  # first x spike times

zm711 commented 7 months ago

@alirza67, sorry that's annoying.

The negative spike is completely a bug on the Kilosort side. We can't change that, so we have remove_excess_spikes to patch for those types of sorter errors. It seems that they fixed it, which is good.

Could you do

from spikeinterface.core.waveform_tools import has_exceeding_spikes

has_exceeding_spikes(multirecording, sorting)
has_exceeding_spikes(single_rec, sorting)

And let me know what those two function calls return?

Sachuriga commented 7 months ago

@zm711 Hi, thanks for working on this, it's solved on my data! But I have not exported it to phy yet; I will let you know how it works.

zm711 commented 7 months ago

@sachuriga283, awesome! I assumed that fixing it for your data would mean it works all over the place, but maybe different data sets need more help. I'll close this when I'm done helping @alirza67. If you run into a new issue not related to negative spikes, just open a new issue :)

alirza67 commented 7 months ago

Thanks @zm711, using it with the multirecording can't find them!

has_exceeding_spikes(multirecording, sorting)

False

has_exceeding_spikes(single_rec, sorting)

True

zm711 commented 7 months ago

Now we are cooking.

    import numpy as np

    # set recording = multirecording (or single_rec) before running
    spike_vector = sorting.to_spike_vector()
    for segment_index in range(recording.get_num_segments()):
        # the slice of the spike vector belonging to this segment
        start_seg_ind, end_seg_ind = np.searchsorted(spike_vector["segment_index"], [segment_index, segment_index + 1])
        spike_vector_seg = spike_vector[start_seg_ind:end_seg_ind]
        if len(spike_vector_seg) > 0:
            print(spike_vector_seg["sample_index"][0])
            # any spike past the end of the segment?
            if spike_vector_seg["sample_index"][-1] > recording.get_num_samples(segment_index=segment_index) - 1:
                print(True)
            # any negative spike time at the start?
            if spike_vector_seg["sample_index"][0] < 0:
                print(True)
    print(False)

this is how has_exceeding_spikes works. Could you run this with recording=multirecording and print(spike_vector_seg["sample_index"][0]) for each iteration? If that is too confusing, I would probably need you to send me data so I could see the exact issue from my computer instead. I'm wondering if with the multirecording the spike times aren't sorted, so we can't just check the first sample.

Edit: I edited the script above for what I want rather than what the function had.
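
A quick way to test that sortedness hypothesis (a sketch; run it on the sorting from the multirecording):

import numpy as np

spike_vector = sorting.to_spike_vector()
samples = spike_vector['sample_index'].astype(np.int64)
# False here would mean the sample indices are not monotonically increasing
print(np.all(np.diff(samples) >= 0))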

alirza67 commented 7 months ago

@zm711, the data is quite big.

Info for multi recording: ConcatenateSegmentRecording: 384 channels - 30.0kHz - 1 segments - 64,713,960 samples 2,157.13s (35.95 minutes) - int16 dtype - 46.29 GiB

And the output for your provided code is:

-2
True
False

Edit: As concatenate_recordings() mimics a mono-segment object that concatenates all segments, my multirecording is one segment; therefore, the loop is executed only once!

Edit 2: Running your code on single_rec gives the following output:

-2
True
True
False
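
For completeness, this is easy to confirm directly:

print(multirecording.get_num_segments())  # concatenate_recordings returns a mono-segment recording -> 1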

zm711 commented 7 months ago

But the issue is that the actual function should return True (I switched the return to a print for our little test). So I'm not sure why, when you run the function, it is returning False instead...?

Could you do:

curated_sorting = si.remove_excess_spikes(sorting, multirecording)
print(curated_sorting)

Could you also do:

si.__version__

just so we can see.

alirza67 commented 7 months ago

@zm711, I think I figured it out. I was testing the debugged version of SI in a new conda env, and apparently I had not installed jupyter notebook properly in it. Thus, jupyter was running from the base environment with another version of SI (but still v0.101.0). After reinstalling jupyter notebook I have:

has_exceeding_spikes(multirecording, sorting)

True

zm711 commented 7 months ago

So @alirza67 if you run the actual removal it works now?

alirza67 commented 7 months ago

> So @alirza67 if you run the actual removal it works now?

Yes, thanks!

zm711 commented 7 months ago

Perfect. Closing this.