SpikeInterface / spikeinterface

A Python-based module for creating flexible and robust spike sorting pipelines.
https://spikeinterface.readthedocs.io
MIT License
519 stars 186 forks source link

Test export is failing on full tests #2049

Closed: h-mayorquin closed this issue 1 year ago

h-mayorquin commented 1 year ago
=================================== FAILURES ===================================
______________________________ test_export_report ______________________________

tmp_path = PosixPath('/tmp/pytest-of-runner/pytest-0/test_export_report0')

    def test_export_report(tmp_path):
        repo = "https://gin.g-node.org/NeuralEnsemble/ephy_testing_data"
        remote_path = "mearec/mearec_test_10s.h5"
        local_path = download_dataset(repo=repo, remote_path=remote_path, local_folder=None)
        recording, sorting = se.read_mearec(local_path)

        waveform_folder = tmp_path / "waveforms"
        output_folder = tmp_path / "mearec_GT_report"

        waveform_extractor = extract_waveforms(recording, sorting, waveform_folder)

        # compute_spike_amplitudes(waveform_extractor)
        # compute_quality_metrics(waveform_extractor)

        job_kwargs = dict(n_jobs=1, chunk_size=30000, progress_bar=True)

>       export_report(waveform_extractor, output_folder, force_computation=True, **job_kwargs)

src/spikeinterface/exporters/tests/test_report.py:29: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
src/spikeinterface/exporters/report.py:68: in export_report
    metrics = compute_quality_metrics(we)
src/spikeinterface/qualitymetrics/quality_metric_calculator.py:231: in compute_quality_metrics
    qmc.run(verbose=verbose, **job_kwargs)
src/spikeinterface/core/waveform_extractor.py:1907: in run
    self._run(**kwargs)
src/spikeinterface/qualitymetrics/quality_metric_calculator.py:118: in _run
    res = func(self.waveform_extractor, unit_ids=non_empty_unit_ids, **params)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

waveform_extractor = WaveformExtractor: 32 channels - 10 units - 1 segments
  before:96 after:128 n_per_units:500
synchrony_sizes = (2, 4, 8)
unit_ids = array(['#0', '#1', '#2', '#3', '#4', '#5', '#6', '#7', '#8', '#9'],
      dtype='<U2')
kwargs = {}
sorting = MEArecSortingExtractor: 10 units - 1 segments - 32.0kHz
  file_path: /home/runner/work/spikeinterface/spikeinterface/test_folder/core/dataset_folder/ephy_testing_data/mearec/mearec_test_10s.h5
spikes = [array([(    10, 7, 0), (   381, 8, 0), (   827, 1, 0), (   973, 4, 0),
       (  1333, 6, 0), (  2120, 7, 0), (  2243...18732, 4, 0), (319954, 8, 0)],
      dtype=[('sample_index', '<i8'), ('unit_index', '<i8'), ('segment_index', '<i8')])]
synchrony_size = 2, all_unit_ids = ['#0', '#1', '#2', '#3', '#4', '#5', ...]
segment_index = 0
spikes_in_segment = array([(    10, 7, 0), (   381, 8, 0), (   827, 1, 0), (   973, 4, 0),
       (  1333, 6, 0), (  2120, 7, 0), (  2243,...318732, 4, 0), (319954, 8, 0)],
      dtype=[('sample_index', '<i8'), ('unit_index', '<i8'), ('segment_index', '<i8')])
unique_spike_index = array([    10,    381,    827,    973,   1333,   2120,   2243,   2333,
         3397,   4261,   4576,   5190,   5197, ... 315569, 315880,
       315881, 316011, 316155, 316569, 316815, 317756, 318191, 318324,
       318460, 318732, 319954])

    def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), unit_ids=None, **kwargs):
        """Compute synchrony metrics. Synchrony metrics represent the rate of occurrences of
        "synchrony_size" spikes at the exact same sample index.

        Parameters
        ----------
        waveform_extractor : WaveformExtractor
            The waveform extractor object.
        synchrony_sizes : list or tuple, default: (2, 4, 8)
            The synchrony sizes to compute.
        unit_ids : list or None, default: None
            List of unit ids to compute the synchrony metrics. If None, all units are used.

        Returns
        -------
        sync_spike_{X} : dict
            The synchrony metric for synchrony size X.
            Returns are as many as synchrony_sizes.

        References
        ----------
        Based on concepts described in [Gruen]_
        This code was adapted from `Elephant - Electrophysiology Analysis Toolkit <https://github.com/NeuralEnsemble/elephant/blob/master/elephant/spike_train_synchrony.py#L245>`_
        """
        assert min(synchrony_sizes) > 1, "Synchrony sizes must be greater than 1"
        spike_counts = waveform_extractor.sorting.count_num_spikes_per_unit()
        sorting = waveform_extractor.sorting
        spikes = sorting.to_spike_vector(concatenated=False)

        if unit_ids is None:
            unit_ids = sorting.unit_ids

        # Pre-allocate synchrony counts
        synchrony_counts = {}
        for synchrony_size in synchrony_sizes:
            synchrony_counts[synchrony_size] = np.zeros(len(waveform_extractor.unit_ids), dtype=np.int64)

        all_unit_ids = list(sorting.unit_ids)
        for segment_index in range(sorting.get_num_segments()):
            spikes_in_segment = spikes[segment_index]

            # we compute just by counting the occurrence of each sample_index
            unique_spike_index, complexity = np.unique(spikes_in_segment["sample_index"], return_counts=True)

            # add counts for this segment
            for unit_id in unit_ids:
                unit_index = all_unit_ids.index(unit_id)
                spikes_per_unit = spikes_in_segment[spikes_in_segment["unit_index"] == unit_index]
                # some segments/units might have no spikes
                if len(spikes_per_unit) == 0:
                    continue
                spike_complexity = complexity[np.isin(unique_spike_index, spikes_per_unit["sample_index"])]
                for synchrony_size in synchrony_sizes:
>                   synchrony_counts[synchrony_size][unit_id] += np.count_nonzero(spike_complexity >= synchrony_size)
E                   IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices

I'm not sure whether this is related to the latest changes, but they are the only connection I can think of:

https://github.com/SpikeInterface/spikeinterface/pull/2041

zm711 commented 1 year ago

@h-mayorquin, @alejoe91 added a fix in that PR (the same test failed on PR commit 957a169 during development). I would try re-running the tests (or, at this point, waiting for the cron job) to see whether his fix, which is now merged into main, fixes the full tests.

alejoe91 commented 1 year ago

Thanks, Ramon. The bug was actually introduced in #1981 and has already been fixed in #2041 :)

h-mayorquin commented 1 year ago

All right, then the tests should pass tomorrow when the cron job runs, and we can close this.

zm711 commented 1 year ago
   def fit_collision(
        collision,
        traces_with_margin,
        start_frame,
        end_frame,
        left,
        right,
        nbefore,
        all_templates,
        unit_inds_to_channel_indices,
        cut_out_before,
        cut_out_after,
    ):
        """
        Compute the best fit for a collision between a spike and its overlapping spikes.
        The function first cuts out the traces around the spike and its overlapping spikes, then
        fits a multi-linear regression model to the traces using the centered templates as predictors.

        Parameters
        ----------
        collision: np.ndarray
            A numpy array of shape (n_colliding_spikes, ) containing the colliding spikes (spike_dtype).
        traces_with_margin: np.ndarray
            A numpy array of shape (n_samples, n_channels) containing the traces with a margin.
        start_frame: int
            The start frame of the chunk for traces_with_margin.
        end_frame: int
            The end frame of the chunk for traces_with_margin.
        left: int
            The left margin of the chunk for traces_with_margin.
        right: int
            The right margin of the chunk for traces_with_margin.
        nbefore: int
            The number of samples before the spike to consider for the fit.
        all_templates: np.ndarray
            A numpy array of shape (n_units, n_samples, n_channels) containing the templates.
        unit_inds_to_channel_indices: dict
            A dictionary mapping unit indices to channel indices.
        cut_out_before: int
            The number of samples to cut out before the spike.
        cut_out_after: int
            The number of samples to cut out after the spike.

        Returns
        -------
        np.ndarray
            The fitted scaling factors for the colliding spikes.
        """
        from sklearn.linear_model import LinearRegression

        # make center of the spike externally
        sample_first_centered = np.min(collision["sample_index"]) - (start_frame - left)
        sample_last_centered = np.max(collision["sample_index"]) - (start_frame - left)

        # construct sparsity as union between units' sparsity
        sparse_indices = np.array([], dtype="int")
        for spike in collision:
            sparse_indices_i = unit_inds_to_channel_indices[spike["unit_index"]]
            sparse_indices = np.union1d(sparse_indices, sparse_indices_i)

        local_waveform_start = max(0, sample_first_centered - cut_out_before)
        local_waveform_end = min(traces_with_margin.shape[0], sample_last_centered + cut_out_after)
        local_waveform = traces_with_margin[local_waveform_start:local_waveform_end, sparse_indices]

        y = local_waveform.T.flatten()
        X = np.zeros((len(y), len(collision)))
        for i, spike in enumerate(collision):
            full_template = np.zeros_like(local_waveform)
            # center wrt cutout traces
            sample_centered = spike["sample_index"] - (start_frame - left) - local_waveform_start
            template = all_templates[spike["unit_index"]][:, sparse_indices]
            template_cut = template[nbefore - cut_out_before : nbefore + cut_out_after]
            # deal with borders
            if sample_centered - cut_out_before < 0:
                full_template[: sample_centered + cut_out_after] = template_cut[cut_out_before - sample_centered :]
            elif sample_centered + cut_out_after > end_frame + right:
                full_template[sample_centered - cut_out_before :] = template_cut[: -cut_out_after - (end_frame + right)]
            else:
>               full_template[sample_centered - cut_out_before : sample_centered + cut_out_after] = template_cut
E               ValueError: could not broadcast input array from shape (210,4) into shape (157,4)

And also:

    def main_function(cls, traces, d):
        templates = d["templates"]
        num_templates = d["num_templates"]
        num_channels = d["num_channels"]
        num_samples = d["num_samples"]
        overlaps = d["overlaps"]
        norms = d["norms"]
        nbefore = d["nbefore"]
        nafter = d["nafter"]
        omp_tol = np.finfo(np.float32).eps
        num_samples = d["nafter"] + d["nbefore"]
        neighbor_window = num_samples - 1
        min_amplitude, max_amplitude = d["amplitudes"]
        ignored_ids = d["ignored_ids"]
        stop_criteria = d["stop_criteria"][:, np.newaxis]
        vicinity = d["vicinity"]
        rank = d["rank"]

        num_timesteps = len(traces)

        num_peaks = num_timesteps - num_samples + 1
        conv_shape = (num_templates, num_peaks)
        scalar_products = np.zeros(conv_shape, dtype=np.float32)

        # Filter using overlap-and-add convolution
        if len(ignored_ids) > 0:
            mask = ~np.isin(np.arange(num_templates), ignored_ids)
>           spatially_filtered_data = np.matmul(d["spatial"][:, mask, :], traces.T[np.newaxis, :, :])
E           ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 32 is different from 4)
zm711 commented 1 year ago

@h-mayorquin I think this is done now, right?

h-mayorquin commented 1 year ago

Yes, let's close this.