SpikeInterface / spikeinterface

A Python-based module for creating flexible and robust spike sorting pipelines.
https://spikeinterface.readthedocs.io
MIT License

Can't interpolate traces with dtype('int16') #3404

Closed jackrwaters closed 1 day ago

jackrwaters commented 5 days ago

Hello! I'm hitting the same error as in [this thread](https://github.com/SpikeInterface/spikeinterface/issues/3146#issue-2391280072). I was happy to switch to float32, as they did, but that seems to cause problems further down our processing chain with Phy. We run Kilosort4 from SI after performing motion correction. What's the best course of action for us? Should we convert our data back to int16 before running Kilosort?

Traceback (most recent call last):
  File "----------", line 122, in <module>
    rec_corrected, motion_info = si.correct_motion(rec, preset='nonrigid_accurate', interpolate_motion_kwargs={'border_mode':'force_extrapolate'},folder=result_folder, output_motion_info=True)
  File "---------/miniconda3/envs/kilosort/lib/python3.9/site-packages/spikeinterface/preprocessing/motion.py", line 433, in correct_motion
    recording_corrected = InterpolateMotionRecording(recording, motion, **interpolate_motion_kwargs)
  File "-----------/miniconda3/envs/kilosort/lib/python3.9/site-packages/spikeinterface/sortingcomponents/motion/motion_interpolation.py", line 346, in __init__
    raise ValueError(f"Can't interpolate traces of recording with non-floating dtype={recording.dtype=}.")
ValueError: Can't interpolate traces of recording with non-floating dtype=recording.dtype=dtype('int16').
srun: error: gpu-n40: task 0: Exited with exit code 1

zm711 commented 5 days ago

Could you post your full spikeinterface script? Kilosort 1-3 require int16. I think KS4 can handle dtypes other than int16 but still prefers int16, so our wrappers should handle this.

jackrwaters commented 5 days ago

from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import shutil
import spikeinterface.full as si
import spikeinterface.sorters as ss
from spikeinterface.sorters import run_sorter
from spikeinterface.sortingcomponents.peak_localization import localize_peaks

from spikeinterface.sortingcomponents.peak_detection import detect_peaks

folder='----------'

global_job_kwargs = dict(n_jobs=40, chunk_duration='1s', progress_bar=True)
si.set_global_job_kwargs(**global_job_kwargs)

raw_rec = si.read_spikeglx(folder_path=folder, stream_id = "imec0.ap")
raw_rec

#rec = preprocess_chain(raw_rec)
rec = si.bandpass_filter(raw_rec, freq_min=300., freq_max=6000, dtype = 'int16') #previously changed dtype to float32, resolving issue
bad_channels, channel_labels = si.detect_bad_channels(rec)
print('bad_channel_ids', bad_channels)
rec = si.phase_shift(recording=rec)
rec = si.common_reference(rec, reference='global', operator = 'median') #instead of highpass_spatial_filter in IBL

job_kwargs = dict(n_jobs=40, chunk_duration='1s', progress_bar=True)

rec #recording post-filter+phase-shift

# estimate the noise on the scaled traces (microV) or on the raw one (which is in our case int16).
noise_levels_microV = si.get_noise_levels(rec, return_scaled=True)
noise_levels_int16 = si.get_noise_levels(rec, return_scaled=False)

# Detect peaks
job_kwargs = dict(n_jobs=40, chunk_duration='1s', progress_bar=True)
peaks = detect_peaks(rec,  method='locally_exclusive', noise_levels=noise_levels_int16,
                     detect_threshold=5, radius_um=50., **job_kwargs)
peaks

# Localize peaks
peak_locations = localize_peaks(rec, peaks, method='center_of_mass', radius_um=50., **job_kwargs)

rec_corrected, motion_info = si.correct_motion(rec, preset='nonrigid_accurate', interpolate_motion_kwargs={'border_mode':'force_extrapolate'},folder=result_folder, output_motion_info=True)

rec_corrected

preprocess_folder = folder + 'preprocess'
rec_corrected = rec_corrected.save(folder=preprocess_folder, format='binary', dtype='int16', **job_kwargs)

# run kilosort4 without drift correction
params = si.get_default_sorter_params(sorter_name_or_class='kilosort4')
params_kilosort4 = {
    'do_correction': False,
    'bad_channels': None #would need to change if we choose not to delete bad channels
}

sorting = si.run_sorter('kilosort4', rec, output_folder=folder + 'kilosort4_output',
                        docker_image=False, verbose=True, **params_kilosort4)

sorting

alejoe91 commented 5 days ago

As the error suggests, the interpolation step of correct_motion requires the input to be float. You can fix this with one line:

rec = spre.astype(rec, "float")

Honestly I think we should do this by default even if the input is int16 and I find the current behavior pretty annoying! @samuelgarcia what do you think?

alejoe91 commented 4 days ago

@jackrwaters sorry, I read too quickly! You could do this: cast to float, run correct_motion, and then use astype again to cast back to int16.

Can you give it a try?
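
A minimal sketch of that round trip, in case it helps. The function name `correct_motion_keep_int16` is made up for illustration, and `float32` is an assumption (any floating dtype should satisfy the check in the traceback). The `correct_motion` arguments are copied from the script above.

```python
def correct_motion_keep_int16(rec, folder, preset="nonrigid_accurate"):
    """Cast to float32, run motion correction, then cast the result back to int16.

    Hypothetical helper for illustration; not part of SpikeInterface.
    """
    # Imported here so the sketch can be pasted into an existing script.
    import spikeinterface.full as si
    import spikeinterface.preprocessing as spre

    # The interpolation step rejects integer traces, so cast first.
    rec_float = spre.astype(rec, "float32")

    rec_corrected, motion_info = si.correct_motion(
        rec_float,
        preset=preset,
        interpolate_motion_kwargs={"border_mode": "force_extrapolate"},
        folder=folder,
        output_motion_info=True,
    )

    # Cast back for downstream tools that expect int16.
    rec_int16 = spre.astype(rec_corrected, "int16")
    return rec_int16, motion_info
```

All of these steps are lazy, so the casts should only cost something when the traces are actually read or saved.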

zm711 commented 4 days ago

Hey Alessio shouldn't we just convert with astype for the user? If KS1-4 prefer int16 (or require it) we should check the dtype and then just run the astype ourselves, no? I can check the wrappers later and add it if we think we should.

alejoe91 commented 2 days ago

> Hey Alessio shouldn't we just convert with astype for the user? If KS1-4 prefer int16 (or require it) we should check the dtype and then just run the astype ourselves, no? I can check the wrappers later and add it if we think we should.

That is already done for KS1-3, but KS4 accepts floats too, so we shouldn't cast to int16 IMO.
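
To make the rule concrete, here is a rough sketch of the decision. It is not the wrappers' actual code: the helper `dtype_for_sorter` and the set `INT16_ONLY_SORTERS` are hypothetical, and the set's contents simply restate what was said in this thread (Kilosort 1-3 need int16, Kilosort4 also takes floats).

```python
# Hypothetical set, based only on the statements in this thread.
INT16_ONLY_SORTERS = {"kilosort", "kilosort2", "kilosort2_5", "kilosort3"}


def dtype_for_sorter(sorter_name, current_dtype):
    """Return the dtype to cast to before sorting, or None to leave the recording alone."""
    if sorter_name in INT16_ONLY_SORTERS and str(current_dtype) != "int16":
        return "int16"
    # Kilosort4 (and anything not listed) keeps whatever dtype it was given.
    return None
```

With this rule, a float32 recording headed for `kilosort3` would be cast, while the same recording headed for `kilosort4` would be passed through unchanged.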

jiumao2 commented 1 day ago

It might be better to check the dtype at the start of correct_motion? Encountering an error after the time-consuming steps of localize_peaks and estimate_motion can be frustrating, especially when motion_info hasn't been saved to the folder.
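
Something along these lines could run before the expensive steps. This is only a sketch: `require_float_traces` is a hypothetical name, and it assumes the recording exposes `get_dtype()` returning a NumPy dtype, as SpikeInterface recordings do.

```python
def require_float_traces(recording):
    """Raise early if the traces are not floating point.

    Hypothetical helper; meant to be called before peak detection,
    localization and motion estimation so the failure costs nothing.
    """
    dtype = recording.get_dtype()  # assumed to be a NumPy dtype
    if dtype.kind != "f":
        raise ValueError(
            f"Motion interpolation needs floating-point traces, got {dtype}. "
            "Cast first, e.g. spikeinterface.preprocessing.astype(recording, 'float32')."
        )
```

Calling it on the int16 recording from the script above would fail immediately instead of after `localize_peaks` and `estimate_motion` have finished.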