SpikeInterface / spikeinterface

A Python-based module for creating flexible and robust spike sorting pipelines.
https://spikeinterface.readthedocs.io
MIT License

Fastest way to get channel ID with the largest amplitude for _each_ spike #2666

Open rat-h opened 5 months ago

rat-h commented 5 months ago

I need a quick way to get, for each spike, the index of the channel with the highest amplitude, preferably without a WaveformExtractor (we), just from the sorting object. All spikes are lumped into units, and spikes in the same unit quite often have different amplitudes on different channels. So I want to know which channel has the highest amplitude for each individual spike. Also, the WaveformExtractor is a pretty slow process, so if it is possible to extract this information directly from the sorting, that would be best.

zm711 commented 5 months ago

@rat-h, unfortunately something like amplitude will also be relatively slow because it requires raw waveform data. So basically you need the waveform extractor to 1) get the waveforms, 2) compute the templates, and 3) determine amplitudes from the templates. At a fundamental level, a sorting is just spike timestamps and spike identities (i.e., knowing the identity or time tells us nothing about the amplitude on any channel). Does that make sense?
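The three steps above can be sketched in plain NumPy on synthetic data. This is a minimal illustration, not SpikeInterface's implementation: `traces` is assumed to be a (num_samples, num_channels) array, spikes are assumed to lie away from the trace edges, and `extract_waveforms_np` / `template_extremum_channels` are hypothetical helper names.

```python
import numpy as np


def extract_waveforms_np(traces, sample_indices, nbefore, nafter):
    """Step 1: cut one (nbefore + nafter, num_channels) snippet per spike.

    Assumes every spike is at least nbefore / nafter samples away from the edges.
    """
    return np.stack([traces[s - nbefore:s + nafter] for s in sample_indices])


def template_extremum_channels(traces, spike_samples, spike_labels, nbefore=20, nafter=40):
    """Steps 2 and 3: average each unit's snippets into a template, then return
    {unit: channel with the largest absolute template amplitude}."""
    extremum = {}
    for unit in np.unique(spike_labels):
        wfs = extract_waveforms_np(traces, spike_samples[spike_labels == unit], nbefore, nafter)
        template = wfs.mean(axis=0)  # shape (nbefore + nafter, num_channels)
        extremum[int(unit)] = int(np.argmax(np.abs(template).max(axis=0)))
    return extremum
```

Every spike of a unit ends up with the same channel here, which is exactly the per-unit limitation discussed in this thread.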

That being said, the new SortingAnalyzer (if you install from source) seeks to decouple some of the slow steps of the waveform_extractor (for example, correlograms do not depend on waveforms, so you shouldn't have to wait for waveforms to do a calculation that only needs a sorting). So my advice would be to test out the SortingAnalyzer, which should be a bit faster than the waveform_extractor (and we are working on making it faster still). You would still only get the extremum channel on a per-unit basis, though; we will have to see whether @alejoe91 or @samuelgarcia know an easy way to do it on a per-spike basis.
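To illustrate why a correlogram only needs the sorting: it is a histogram of spike-time differences, so spike sample indices are enough and no traces are touched. A minimal NumPy sketch; the function name, the millisecond parameters and the binning are illustrative assumptions, not SpikeInterface's `compute_correlograms` implementation.

```python
import numpy as np


def cross_correlogram(times_a, times_b, window_ms, bin_ms, fs):
    """Histogram of lags (t_b - t_a) within +/- window_ms, from sample indices only."""
    window = int(round(window_ms * fs / 1000))
    bin_size = int(round(bin_ms * fs / 1000))
    times_b = np.sort(times_b)
    lags = []
    for t in times_a:
        # spikes of b that fall inside [t - window, t + window]
        lo = np.searchsorted(times_b, t - window, side="left")
        hi = np.searchsorted(times_b, t + window, side="right")
        lags.append(times_b[lo:hi] - t)
    lags = np.concatenate(lags) if lags else np.array([], dtype=np.int64)
    edges = np.arange(-window, window + bin_size, bin_size)
    counts, _ = np.histogram(lags, bins=edges)
    return counts, edges
```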

rat-h commented 5 months ago

@zm711 Thank you so much for the detailed explanation. Unfortunately, what I want to do needs information per spike.

I wonder whether I can use detect_peaks from spikeinterface.sortingcomponents.peak_detection to get the channel and amplitude, but only at the samples where a spike sorter marked a spike?

Also, please correct me if I'm wrong: a spike sorter does not specify the channel(s) where it found a spike, just the time points (i.e., sample_index). Finding the actual channels and amplitudes is the job of extract_waveforms or SortingAnalyzer. Is that right?
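Since a sorting stores only spike times and unit labels, a per-spike channel has to be read back from the traces. A minimal NumPy sketch of the idea above (looking for the peak only around each sorter spike time), assuming `traces` is a (num_samples, num_channels) array; `max_channel_per_spike`, `half_window` and the output dtype are hypothetical and this is not how `detect_peaks` works internally.

```python
import numpy as np

PEAK_DTYPE = [("sample_index", "int64"), ("channel_index", "int64"), ("amplitude", "float64")]


def max_channel_per_spike(traces, sample_indices, half_window=5, peak_sign="neg"):
    """For each sorter spike time, search +/- half_window samples on all channels
    and return the sample, channel and value of the most extreme point."""
    num_samples = traces.shape[0]
    out = np.zeros(len(sample_indices), dtype=PEAK_DTYPE)
    for i, s in enumerate(sample_indices):
        start = max(int(s) - half_window, 0)
        stop = min(int(s) + half_window + 1, num_samples)
        chunk = traces[start:stop]
        if peak_sign == "neg":
            score = -chunk
        elif peak_sign == "pos":
            score = chunk
        else:  # "both"
            score = np.abs(chunk)
        # position of the best score inside the (time, channel) window
        t, ch = np.unravel_index(np.argmax(score), chunk.shape)
        out[i] = (start + t, ch, chunk[t, ch])
    return out
```

The time window absorbs small offsets between the sorter's reported sample and the actual trough; it does have to scan the traces, as noted in the next reply.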

samuelgarcia commented 5 months ago

Hi, I have very little time to explain it today, but in node_pipeline.py we have some very undocumented internal machinery that provides a way to do this, with this class: https://github.com/SpikeInterface/spikeinterface/blob/main/src/spikeinterface/core/node_pipeline.py#L145

In any case, you need to go through the entire trace to do this. And normally SortingAnalyzer.compute("spike_amplitudes", n_jobs=-1) should do the same thing as fast as it can.

rat-h commented 5 months ago

@samuelgarcia Thank you for the suggestion.

There is a bit of a problem with SpikeRetriever. The class documentation says:

channel_from_template: bool, default: True
...
        If False, the max channel is computed for each spike given a radius around the template max channel.

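Following these instructions, we should be able to extract spikes like this:

import spikeinterface.full as si
from spikeinterface.core.node_pipeline import SpikeRetriever

# rec: the recording, loaded earlier in the script
sr = si.load_extractor('consensus-spikes.tmp/hdsort-saved')
r = SpikeRetriever(rec, sr, channel_from_template=False, radius_um=120)
print(r)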

But it raises an error:

Traceback (most recent call last):
  File "/home/rth/spikeinterface/consensus-spikes.py", line 81, in <module>
    r = SpikeRetriever(rec,sr,channel_from_template=False,
  File "/home/rth/.local/apps/spikes/lib/python3.10/site-packages/spikeinterface/core/node_pipeline.py", line 174, in __init__
    assert extremum_channel_inds is not None, "SpikeRetriever needs the extremum_channel_inds dictionary"
AssertionError: SpikeRetriever needs the extremum_channel_inds dictionary

samuelgarcia commented 5 months ago

Hi, I was really tired yesterday when I was trying to help.

You can use get_template_extremum_channel() to get the max channel IDs per unit (https://spikeinterface.readthedocs.io/en/latest/api.html#spikeinterface.core.get_template_extremum_channel). To get it per spike, there are two alternatives: either the max channel comes from the template (average), which is very fast because it is deduced from the extremum_channels dict (case 1), or you find the exact true maximum for each spike, which requires going back to the traces spike by spike and finding the max (case 2).
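The two cases can be illustrated with a small NumPy sketch. It mimics the behaviour described for SpikeRetriever (`channel_from_template`, `radius_um`, negative peaks) but is not its implementation; `per_spike_channels` is a hypothetical helper, `extremum_channel_inds` is assumed to be a dict {unit_id: channel_index} like the one returned by `get_template_extremum_channel(..., outputs='index')`, and `channel_locations` is assumed to be a (num_channels, 2) array in µm.

```python
import numpy as np


def per_spike_channels(traces, spike_samples, spike_labels, extremum_channel_inds,
                       channel_locations, radius_um=50.0, channel_from_template=True):
    """Return one channel index per spike.

    Case 1 (channel_from_template=True): the unit's template extremum channel.
    Case 2 (channel_from_template=False): among channels within radius_um of that
    template channel, the one with the most negative value at the spike sample.
    """
    template_chans = np.array([extremum_channel_inds[u] for u in spike_labels], dtype=np.int64)
    if channel_from_template:
        return template_chans  # case 1: a dictionary lookup, no traces needed
    # pairwise channel distances, shape (num_channels, num_channels)
    diff = channel_locations[:, None, :] - channel_locations[None, :, :]
    dist = np.linalg.norm(diff, axis=2)
    chans = np.empty(len(spike_samples), dtype=np.int64)
    for i, (s, c) in enumerate(zip(spike_samples, template_chans)):
        neighbours = np.flatnonzero(dist[c] <= radius_um)
        chans[i] = neighbours[np.argmin(traces[s, neighbours])]  # case 2: read the traces
    return chans
```

Case 2 reads the traces at every spike, which is why it needs a pass over the whole recording.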

SpikeRetriever is for case 2, and it does not work alone: you need to create a pipeline, which is tedious but efficient.

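from spikeinterface.core import get_template_extremum_channel
from spikeinterface.core.node_pipeline import SpikeRetriever, run_node_pipeline

# analyzer: a SortingAnalyzer (or WaveformExtractor) built from recording + sorting
extremum_channel_inds = get_template_extremum_channel(analyzer, outputs='index')

node0 = SpikeRetriever(
    recording, sorting,
    channel_from_template=False,  # False = case 2: per-spike max channel within radius_um
    extremum_channel_inds=extremum_channel_inds,
    radius_um=50,
    peak_sign="neg",
)
# node1 = TODO: create a node that uses the max channel and returns it

output = run_node_pipeline(
    recording,
    [node0, node1],
    job_kwargs,  # e.g. dict(n_jobs=-1)
)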

Why do you need the main channel? We have machinery to localize spikes on the probe.

rat-h commented 5 months ago

Why do you need the main channel? We have machinery to localize spikes on the probe.

It is silly, so silly that you may kick me out of the nice SpikeInterface community, but, well, anyhow...

So I have a recording, a very nice recording, with clear spikes and many units. But we have tried almost all spike sorters in SpikeInterface with no luck. I want to take spikes from different spike sorters and overlap them. I hope (and it seems it can work) that many spikes from different sorters land close to each other, so I could use spikes with, say, "3-sorter confidence" or "5-sorter confidence"... Silly, isn't it?

I overlapped the spikes and saw that they really do land close to each other, but there are many cases where more than one unit fires in the same time sample (very bursty recordings). So I need to lump spikes not only by time but also by the channel where the peak was observed. Because some sorters do not separate spikes well, I don't want to use unit amplitudes; I want to do it on a spike-by-spike basis.

So is there any way to get a vector (sample_index,max_channel) for each sorted spike?
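The lumping rat-h describes (agreement by both time and peak channel) can be sketched as follows, assuming each sorter's spikes have already been turned into a (sample_index, max_channel) vector, for instance with a per-spike peak search. The structured dtype, `count_sorter_support` and the tolerances are hypothetical; comparing raw channel indices is a simplification, and a distance on the probe in µm would be more faithful.

```python
import numpy as np

spike_dtype = [("sample_index", "int64"), ("max_channel", "int64")]


def count_sorter_support(spike_vectors, delta_samples, max_channel_dist=0):
    """For each spike of the first sorter, count how many sorters (itself included)
    report a spike within delta_samples and within max_channel_dist channel indices."""
    ref = spike_vectors[0]
    support = np.ones(len(ref), dtype=np.int64)
    for other in spike_vectors[1:]:
        order = np.argsort(other["sample_index"])
        times = other["sample_index"][order]
        chans = other["max_channel"][order]
        for i, (t, ch) in enumerate(zip(ref["sample_index"], ref["max_channel"])):
            # candidate spikes of this sorter inside the time tolerance
            lo = np.searchsorted(times, t - delta_samples, side="left")
            hi = np.searchsorted(times, t + delta_samples, side="right")
            if np.any(np.abs(chans[lo:hi] - ch) <= max_channel_dist):
                support[i] += 1
    return support
```

Thresholding `support` (e.g. keeping spikes with support >= 3) gives the "3-sorter confidence" idea, while the channel check keeps simultaneous spikes on distant channels apart.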

zm711 commented 5 months ago

I think I understand you. Again, @samuelgarcia will know better about this type of low-level machinery and whether it is possible. I'm curious: have you tried the multi-sorter comparison from the comparison module? It has a spiketrain_mode='union', which takes the union of spikes from the two "best" spike sorters in a comparison of multiple sorters. I'm wondering if that could help with your issue. It would be something like:

import spikeinterface.comparison as sc

# s1, s2, s3: sorting objects from three different sorters
mcmp = sc.compare_multiple_sorters(
    sorting_list=[s1, s2, s3],
    name_list=['ks', 'tdc', 'sc2'],
    spiketrain_mode='union',
    # other parameters (e.g. delta_time, match_score, chance_score) tune how
    # consensus spikes are found, so read the docs for these
)

(See the compare_multiple_sorters docstring for details.)

rat-h commented 5 months ago

Hm, @zm711, I did NOT try it and completely forgot about this option. This is an excellent suggestion. Let me try :)

rat-h commented 5 months ago

It kind of works, but with some glitches.

So I have 7 sorting results saved in a directory. The code below reads them, runs the multi-sorter comparison and saves the agreement sorting. Then it extracts waveforms.

import pickle as pkl

import spikeinterface.full as si
import spikeinterface.comparison as sc

# sorters: list of sorter names; rec: the recording; job_kwargs: dict of job settings
names = []
srtes = []
for srt in sorters:
    try:
        sro = si.load_extractor(f'consensus-spikes.tmp/{srt}-saved')
    except Exception:
        # skip sorters whose output could not be loaded
        continue
    names.append(srt)
    srtes.append(sro)
mcmp = sc.compare_multiple_sorters(
    sorting_list    = srtes,
    name_list       = names,
    spiketrain_mode = 'union',
    match_score     = 0.3,  # default 0.5
    chance_score    = 0.1,  # default 0.1
    delta_time      = 5,    # ms; default 0.4
)
with open("consensus-spikes.pkl", 'wb') as fd:
    pkl.dump(mcmp, fd)

asrt = mcmp.get_agreement_sorting(
    minimum_agreement_count=2  # default 1
)
asrt = asrt.save(folder='consensus-spikes.tmp/sorting-saved', overwrite=True)
we = si.extract_waveforms(
    rec, asrt, 'consensus-spikes.tmp/waveforms',
    max_spikes_per_unit=500,
    ms_before=1.5, ms_after=2.5,
    overwrite=True,
    **job_kwargs,
)

Note that delta_time is set to 5 ms; however, I can see spikes on the same channel with the same waveform closer than 0.5 ms.

[Screenshot: Screenshot_20240409_121018]

Any idea what it may be?

alejoe91 commented 5 months ago

The delta time is only used at the pair comparison level. You can remove these duplicate spikes a posteriori with:

import spikeinterface.curation as scur

sorting_rm_duplicates = scur.remove_duplicated_spikes(asrt)
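One way to picture a per-unit duplicate removal: walk through each unit's sorted spike train and drop spikes that fall within a short censored window of the last kept spike. A hypothetical NumPy sketch; the helper name, the "keep first" rule and the window are assumptions for illustration, and the exact rule and defaults of `remove_duplicated_spikes` should be checked in its docstring.

```python
import numpy as np


def remove_duplicates_keep_first(spike_samples, censored_samples):
    """Within ONE unit's spike train, drop any spike that falls within
    censored_samples of the last kept spike."""
    spike_samples = np.sort(spike_samples)
    kept = []
    for s in spike_samples:
        if not kept or s - kept[-1] > censored_samples:
            kept.append(s)
    return np.array(kept, dtype=spike_samples.dtype)
```

For example, a 0.3 ms window at 30 kHz corresponds to 9 samples. Because the function only ever sees one unit's spikes, near-coincident spikes assigned to different units are left alone.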

rat-h commented 5 months ago

It doesn't work!

# same imports and variables as in the previous snippet, plus:
import spikeinterface.curation as scur

names = []
srtes = []
for srt in sorters:
    try:
        sro = si.load_extractor(f'consensus-spikes.tmp/{srt}-saved')
    except Exception:
        continue
    names.append(srt)
    srtes.append(sro)
mcmp = sc.compare_multiple_sorters(
    sorting_list    = srtes,
    name_list       = names,
    spiketrain_mode = 'union',
    match_score     = 0.3,  # default 0.5
    chance_score    = 0.0,  # default 0.1
    delta_time      = 2.5,  # ms; default 0.4
)
with open("consensus-spikes.pkl", 'wb') as fd:
    pkl.dump(mcmp, fd)

asrt = mcmp.get_agreement_sorting(
    minimum_agreement_count=2  # default 1
)
# remove near-duplicate spikes within each unit before saving
asrt = scur.remove_duplicated_spikes(asrt)
asrt = asrt.save(folder='consensus-spikes.tmp/sorting-saved', overwrite=True)
we = si.extract_waveforms(
    rec, asrt, 'consensus-spikes.tmp/waveforms',
    max_spikes_per_unit=500,
    ms_before=1.5, ms_after=2.5,
    overwrite=True,
    **job_kwargs,
)
[Screenshot: Screenshot_20240409_125348]

The width of the screenshot is about 2 ms.

alejoe91 commented 5 months ago

Ah, but these are from 2 different units! remove_duplicated_spikes cleans up within each unit.
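To see why per-unit cleanup misses these cases, a hypothetical NumPy sketch can pool all units and flag near-coincident spikes that carry different unit labels. The function name and the sample tolerance are illustrative assumptions, not part of SpikeInterface.

```python
import numpy as np


def cross_unit_coincidences(spike_samples, spike_labels, delta_samples):
    """Return index pairs (i, j) of spikes from DIFFERENT units lying within
    delta_samples of each other; per-unit cleanup never compares these pairs."""
    order = np.argsort(spike_samples, kind="stable")
    times = spike_samples[order]
    labels = spike_labels[order]
    pairs = []
    for a in range(len(times)):
        b = a + 1
        # scan forward while the next spikes are still inside the tolerance
        while b < len(times) and times[b] - times[a] <= delta_samples:
            if labels[b] != labels[a]:
                pairs.append((int(order[a]), int(order[b])))
            b += 1
    return pairs
```

Same-unit pairs within the tolerance are skipped here on purpose, since those are the ones a per-unit duplicate removal already handles.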

rat-h commented 5 months ago

Why does compare_multiple_sorters put the same spike into two different units? With such a small Δt, shouldn't it merge them?

alejoe91 commented 5 months ago

They might come from different merges