SpikeInterface / spikeinterface

A Python-based module for creating flexible and robust spike sorting pipelines.
https://spikeinterface.readthedocs.io
MIT License
481 stars 183 forks

Integration of decoder for automated noise, mua, and sua curation #2479

Open anoushkajain opened 5 months ago

anoushkajain commented 5 months ago

Hello Spikeinterface team,

We’ve been working on extracting single-neuron activity from Neuropixels data, which often requires extensive manual evaluation of spike clusters.

To streamline this process, we have developed a machine-learning pipeline that uses quality metrics and human labels to classify units into Single-Unit Activity (SUA), Multi-Unit Activity (MUA), and Noise. The goal is to reduce curation time and improve the reproducibility of the results.
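The thread does not describe the actual features or models (a preprint is mentioned later in the thread). As a rough sketch of the general idea only, assuming each unit is summarized by a few quality metrics and carries a human label, one can fit a classifier on the labeled units and apply it to new ones. The example swaps in a plain-Python nearest-centroid classifier; the metric names, values, and helper functions are hypothetical toy choices, not the authors' pipeline.

```python
import math

# Hypothetical toy training set (made-up values): each unit is a dict of
# quality metrics plus the label a human curator assigned to it.
LABELED_UNITS = [
    ({"snr": 8.0, "isi_violations_ratio": 0.01, "presence_ratio": 0.98}, "sua"),
    ({"snr": 7.5, "isi_violations_ratio": 0.02, "presence_ratio": 0.95}, "sua"),
    ({"snr": 4.0, "isi_violations_ratio": 0.40, "presence_ratio": 0.90}, "mua"),
    ({"snr": 3.5, "isi_violations_ratio": 0.55, "presence_ratio": 0.85}, "mua"),
    ({"snr": 1.0, "isi_violations_ratio": 0.90, "presence_ratio": 0.30}, "noise"),
    ({"snr": 1.2, "isi_violations_ratio": 0.80, "presence_ratio": 0.20}, "noise"),
]
METRICS = ["snr", "isi_violations_ratio", "presence_ratio"]


def fit_centroids(labeled_units, metric_names):
    """Average the metric vectors of all units that share a human label."""
    sums, counts = {}, {}
    for metrics, label in labeled_units:
        vec = [metrics[m] for m in metric_names]
        acc = sums.setdefault(label, [0.0] * len(vec))
        for i, v in enumerate(vec):
            acc[i] += v
        counts[label] = counts.get(label, 0) + 1
    return {label: [s / counts[label] for s in acc] for label, acc in sums.items()}


def predict(centroids, metrics, metric_names):
    """Assign the label whose centroid is closest in Euclidean distance."""
    vec = [metrics[m] for m in metric_names]
    return min(centroids, key=lambda label: math.dist(vec, centroids[label]))


centroids = fit_centroids(LABELED_UNITS, METRICS)
new_unit = {"snr": 6.9, "isi_violations_ratio": 0.05, "presence_ratio": 0.97}
print(predict(centroids, new_unit, METRICS))  # sua
```

Nearest-centroid on raw values lets SNR dominate the distance; a real pipeline would at least standardize the metrics and would likely use a stronger model, such as the scikit-learn or XGBoost estimators named in the proposal later in the thread.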

I have attached a sorting view that shows the output of the decoder: https://figurl.org/f?v=gs://figurl/sortingview-11&d=sha1://f709f74add515e23062e3952076602e40d3e86ce

As previously discussed with Alessio, we would like to integrate this approach into the curation module of SpikeInterface. I am opening this issue to start the discussion on the next steps toward this integration.

Antho2422 commented 5 months ago

@anoushkajain Hi, this looks great! My team and I are also working on a similar process based on human labeling of our data. Could you describe your process in a bit more detail?

Cheers, Anthony

anoushkajain commented 5 months ago

@Antho2422 Hi, we are currently writing this up as a preprint, which we will release soon.

alejoe91 commented 5 months ago

One possible solution would be to add a ModelBasedCuration abstract class to spikeinterface.curation, along these lines:

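import importlib.util
from pathlib import Path


class ModelBasedCuration:
    # importable module names that a concrete subclass depends on
    requirements = None

    def __init__(self, required_metrics, model_path):
        self.model_path = Path(model_path)
        self.required_metrics = required_metrics
        self.instantiated_models = dict()

    def load_model(self):
        # implemented by subclasses to populate self.instantiated_models
        raise NotImplementedError

    def apply(self, sorting_analyzer, return_probabilities=True):
        raise NotImplementedError

    def compute_missing_metrics(self, sorting_analyzer):
        raise NotImplementedError

    @classmethod
    def check_installed(cls):
        # raise if any required package cannot be found
        if cls.requirements is not None:
            missing = [req for req in cls.requirements if importlib.util.find_spec(req) is None]
            if missing:
                raise ImportError(f"Missing required packages: {missing}")

    # optional
    @staticmethod
    def train(*args, **kwargs):
        raise NotImplementedError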
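The check_installed stub can be fleshed out on its own. This is a minimal sketch, assuming requirements holds importable module names (these can differ from pip package names: sklearn, for example, is installed as scikit-learn). CurationModelBase and the two subclasses are hypothetical, for illustration only; importlib.util.find_spec locates a module without importing it, so heavy dependencies are not loaded just to check for them.

```python
import importlib.util


class CurationModelBase:
    """Hypothetical stand-in for the proposed base class (installation check only)."""

    requirements = None  # importable module names a subclass depends on

    @classmethod
    def missing_requirements(cls):
        """Return the required modules that cannot be found on this interpreter."""
        if cls.requirements is None:
            return []
        return [name for name in cls.requirements if importlib.util.find_spec(name) is None]

    @classmethod
    def check_installed(cls):
        """Raise ImportError listing every missing module, if any."""
        missing = cls.missing_requirements()
        if missing:
            raise ImportError(f"{cls.__name__} requires missing packages: {missing}")


class StdlibOnlyModel(CurationModelBase):
    requirements = ["json", "pickle"]  # standard-library modules, always present


class NeedsFakePackage(CurationModelBase):
    requirements = ["json", "surely_not_an_installed_package_xyz"]


StdlibOnlyModel.check_installed()  # passes silently
try:
    NeedsFakePackage.check_installed()
except ImportError as err:
    print(err)  # reports the missing module
```

Using a classmethod rather than a staticmethod lets the check read each subclass's own requirements. In the proposal above, this check would live on ModelBasedCuration itself, so every subclass inherits it.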

A pretrained model could then simply inherit from ModelBasedCuration:

import pickle


class PreTrainedSklearnCuration(ModelBasedCuration):

    requirements = ["sklearn", "xgboost"]  # ... plus any other packages the models need

    def __init__(self, required_metrics, model_path):
        super().__init__(required_metrics, model_path)

    def load_model(self):
        # each stage is stored as a pickled estimator inside model_path
        for name in ("noise-neuron", "sua-mua"):
            with open(self.model_path / f"{name}.pkl", "rb") as f:
                self.instantiated_models[name] = pickle.load(f)

    def apply(self, sorting_analyzer, return_probabilities=True):
        self.compute_missing_metrics(sorting_analyzer)
        metrics = ...  # per-unit table of self.required_metrics
        noise_neuron = self.instantiated_models["noise-neuron"].predict(metrics)
        ...  # remaining steps, producing predictions and prediction_probabilities
        if return_probabilities:
            return predictions, prediction_probabilities
        else:
            return predictions
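The apply step in the proposal leaves the combination of the two models elided. One plausible reading of the model names is a cascade: the "noise-neuron" model labels every unit, and the "sua-mua" model runs only on units labeled as neurons. The sketch below assumes that cascade and assumes both models follow the scikit-learn predict / predict_proba convention. ThresholdModel and cascade_predict are hypothetical, and the logistic threshold models stand in for trained estimators so the example runs without scikit-learn or XGBoost.

```python
import math


class ThresholdModel:
    """Hypothetical stand-in for a trained binary classifier.

    Applies a logistic curve to one feature and follows the scikit-learn
    convention: predict_proba() returns [p(classes_[0]), p(classes_[1])] per row.
    """

    def __init__(self, feature_index, threshold, classes, scale=10.0):
        self.feature_index = feature_index
        self.threshold = threshold
        self.classes_ = classes
        self.scale = scale

    def predict_proba(self, rows):
        out = []
        for row in rows:
            z = self.scale * (row[self.feature_index] - self.threshold)
            p_high = 1.0 / (1.0 + math.exp(-z))
            out.append([1.0 - p_high, p_high])
        return out

    def predict(self, rows):
        return [self.classes_[1] if p[1] >= 0.5 else self.classes_[0] for p in self.predict_proba(rows)]


def cascade_predict(models, rows, return_probabilities=True):
    """Label each unit noise/sua/mua with a two-stage cascade (assumed design)."""
    stage1, stage2 = models["noise-neuron"], models["sua-mua"]
    predictions = stage1.predict(rows)
    # confidence of the stage-1 decision for each unit
    probabilities = [max(p) for p in stage1.predict_proba(rows)]

    neuron_idx = [i for i, label in enumerate(predictions) if label == "neuron"]
    if neuron_idx:
        neuron_rows = [rows[i] for i in neuron_idx]
        labels = stage2.predict(neuron_rows)
        probs = stage2.predict_proba(neuron_rows)
        for i, label, p in zip(neuron_idx, labels, probs):
            predictions[i] = label
            # p(neuron) * p(sua or mua | neuron)
            probabilities[i] *= max(p)

    if return_probabilities:
        return predictions, probabilities
    return predictions


# Rows hold [snr, isi_violations_ratio] per unit (made-up toy values).
models = {
    "noise-neuron": ThresholdModel(0, threshold=2.0, classes=("noise", "neuron")),
    "sua-mua": ThresholdModel(1, threshold=0.2, classes=("sua", "mua")),
}
rows = [[8.0, 0.01], [4.0, 0.45], [0.8, 0.90]]
labels, probs = cascade_predict(models, rows)
print(labels)  # ['sua', 'mua', 'noise']
```

Reporting the product of the two stage probabilities for neurons is one design choice; returning each stage's probability separately would be equally reasonable.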

How does that sound? @anoushkajain @Antho2422 @samuelgarcia

anoushkajain commented 5 months ago

Yes, looks good. I will start working on it.

zm711 commented 1 month ago

@anoushkajain do you want to link your PR to this issue?