SpikeInterface / spikeinterface

A Python-based module for creating flexible and robust spike sorting pipelines.
https://spikeinterface.readthedocs.io
MIT License
531 stars, 188 forks

Possible issue in whitening procedure? #3510

Open LeMuellerGuy opened 3 weeks ago

LeMuellerGuy commented 3 weeks ago

Hello everyone,

I am currently running spikeinterface 0.101.2 to process some MaxTwo data. I concatenated some recording segments and centered the data; I assumed this might be relevant to the issue, but it doesn't change the outcome. When running SpykingCircus2 I get the error shown in the trace below, which essentially says that the whitening procedure hits a dtype mismatch inside sklearn. I have checked that my data is properly conditioned (no NaN values, non-singular covariance matrix). I also found that if I step in with the debugging console and change the dtype of the data array to float64 (i.e. the requested double type), it works fine.

I also tried to find the root cause in the sklearn method but didn't have much success. What I did find is that the failing procedure runs fine for what I assume are all but one iteration and then crashes on the last pass; I was unable to figure out what difference makes it crash. Does anyone here have an idea what causes this issue?

The error trace:

sorting = run_sorter(sorter.sorter_name, rec.Extractor, folder = folder, verbose = True, remove_existing_folder=overwrite, **params)    
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^    
  File "C:\ProgramData\mambaforge\envs\maxwelltesting\Lib\site-packages\spikeinterface\sorters\runsorter.py", line 199, in run_sorter       
    return run_sorter_local(**common_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\ProgramData\mambaforge\envs\maxwelltesting\Lib\site-packages\spikeinterface\sorters\runsorter.py", line 261, in run_sorter_local 
    SorterClass.run_from_folder(folder, raise_error, verbose)
  File "C:\ProgramData\mambaforge\envs\maxwelltesting\Lib\site-packages\spikeinterface\sorters\basesorter.py", line 301, in run_from_folder
    raise SpikeSortingError(
spikeinterface.sorters.utils.misc.SpikeSortingError: Spike sorting error trace:
Traceback (most recent call last):
  File "C:\ProgramData\mambaforge\envs\maxwelltesting\Lib\site-packages\spikeinterface\sorters\basesorter.py", line 261, in run_from_folder
    SorterClass._run_from_folder(sorter_output_folder, sorter_params, verbose)
  File "C:\ProgramData\mambaforge\envs\maxwelltesting\Lib\site-packages\spikeinterface\sorters\internal\spyking_circus2.py", line 152, in _run_from_folder
    recording_w = whiten(recording_f, mode="local", radius_um=radius_um, dtype="float32", regularize=True)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\ProgramData\mambaforge\envs\maxwelltesting\Lib\site-packages\spikeinterface\preprocessing\whiten.py", line 84, in __init__
    W, M = compute_whitening_matrix(
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\ProgramData\mambaforge\envs\maxwelltesting\Lib\site-packages\spikeinterface\preprocessing\whiten.py", line 209, in compute_whitening_matrix
    estimator.fit(data)
  File "C:\ProgramData\mambaforge\envs\maxwelltesting\Lib\site-packages\sklearn\base.py", line 1473, in wrapper
    return fit_method(estimator, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\ProgramData\mambaforge\envs\maxwelltesting\Lib\site-packages\sklearn\covariance\_graph_lasso.py", line 1109, in fit
    self.covariance_, self.precision_, self.costs_, self.n_iter_ = _graphical_lasso(
                                                                   ^^^^^^^^^^^^^^^^^
  File "C:\ProgramData\mambaforge\envs\maxwelltesting\Lib\site-packages\sklearn\covariance\_graph_lasso.py", line 139, in _graphical_lasso
    coefs, _, _, _ = cd_fast.enet_coordinate_descent_gram(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "_cd_fast.pyx", line 569, in sklearn.linear_model._cd_fast.enet_coordinate_descent_gram
ValueError: Buffer dtype mismatch, expected 'const double' but got 'float'

zm711 commented 3 weeks ago

We are having an ongoing discussion on the whitening here. But I'll tag @yger since he runs SC2 in case he has other ideas.

LeMuellerGuy commented 3 weeks ago

Thanks for reaching out, but I think the link you posted may be broken, as it links back to this issue for me. Were you referring to this PR? https://github.com/SpikeInterface/spikeinterface/pull/3505. In that case it would still be a good question why my data crashes on float32 but runs fine on float64, even though a mock example like the one below completes fine on float32. Just to add, I'm running sklearn 1.5.2.

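import numpy as np
from sklearn.covariance import GraphicalLassoCV

# Mock data: 20000 samples x 200 channels, cast to float32 like the real traces
arr = np.random.default_rng().normal(size=(20000, 200)).astype(np.float32)
# Same estimator family that appears in the traceback (sklearn/covariance/_graph_lasso.py)
estimator = GraphicalLassoCV(assume_centered=True)
estimator.fit(arr)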
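
For reference, the float64 workaround described in the original report (changing the dtype of the data array before the fit) can be sketched as follows. This is only a minimal sketch of that workaround, not the fix adopted in spikeinterface. `fit_as_float64` is a hypothetical helper name, and the small array shape is an assumption chosen so the demo runs quickly.

```python
import numpy as np


def fit_as_float64(estimator, data):
    """Hypothetical helper: cast the data to float64 before fitting.

    Mirrors the manual workaround from the report, where changing the
    array dtype to float64 (C double) made the fit succeed.
    """
    # Contiguous float64 copy (or view, if the input already qualifies)
    data64 = np.ascontiguousarray(data, dtype=np.float64)
    estimator.fit(data64)
    return estimator


if __name__ == "__main__":
    from sklearn.covariance import GraphicalLassoCV

    # Small mock array (assumed shape), stored as float32 like the real traces
    rng = np.random.default_rng(0)
    arr = rng.normal(size=(500, 8)).astype(np.float32)

    est = fit_as_float64(GraphicalLassoCV(assume_centered=True), arr)
    print(est.covariance_.shape, est.covariance_.dtype)
```

The cast copies the data, so memory use roughly doubles for a float32 input. It also does not explain why only some float32 inputs trigger the mismatch.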

yger commented 3 weeks ago

Please update to main. The regularize option is now turned off by default in SC2, because there are indeed some weird cases of failures. I'll dig into that quickly.