SpikeInterface / spikeinterface

A Python-based module for creating flexible and robust spike sorting pipelines.
https://spikeinterface.readthedocs.io
MIT License

Question about interpolate with motion of a different sampling rate than recording #3328

Open DaohanZhang opened 3 weeks ago

DaohanZhang commented 3 weeks ago

Hey, as the documentation mentions, some drift-correction methods such as dredge_lfp require preprocessing steps, including resampling. I'm wondering whether I need to resample the motion estimate back to match the original recording. If I do, there is a complication: we sometimes use the low-frequency (LF) band of Neuropixels recordings to estimate motion and then apply it to the action-potential (AP) band, but the ratio between their sampling rates isn't an integer (not exactly 12) because of timing differences. If I don't resample the motion back, I get an error when saving the recording to Zarr. The index 208455 in that error corresponds to the time bins of the motion data, not to the samples of the recording itself.

import numpy as np
import matplotlib.pyplot as plt

import spikeinterface.full as si  # imports core plus the main submodules
import spikeinterface.extractors as se
import spikeinterface.preprocessing as spre
import spikeinterface.sorters as ss
import spikeinterface.postprocessing as spost
import spikeinterface.qualitymetrics as sqm
import spikeinterface.comparison as sc
import spikeinterface.exporters as sexp
import spikeinterface.curation as scur
import spikeinterface.widgets as sw

lfpraw = se.read_spikeglx("/home/zhangdaohan20h/public_data/NPX_examples/Pt01/", load_sync_channel=False, 
                          stream_id="imec0.lf")
## the file named Pt02.imec0.ap.bin & Pt02.imec0.ap.meta
lfpraw
# convert to floating point
lfprec = si.astype(lfpraw, np.float32)
# ensure depth order
lfprec = si.depth_order(lfprec)

cutoff_um = 8000
if cutoff_um is not None:
    geom = lfprec.get_channel_locations()
    lfprec = lfprec.remove_channels(lfprec.channel_ids[geom[:, 1] > cutoff_um])

# bandpass filter
# we do an aggressive one since we plan to downsample
lfprec = si.bandpass_filter(
    lfprec,
    freq_min=0.5,
    freq_max=250,
    margin_ms=1000,
    filter_order=3,
    dtype="float32",
    add_reflect_padding=True,
)
# fancy bad channels detection and removal from the International Brain Lab
bad_chans, labels = si.detect_bad_channels(lfprec, psd_hf_threshold=1.4, num_random_chunks=100, seed=0)
print("Found bad channels", bad_chans)
lfprec = lfprec.remove_channels(bad_chans)
# correct for ADC sample shifts
lfprec = si.phase_shift(lfprec)
# common median reference
lfprec = si.common_reference(lfprec)
# downsample to 250Hz
lfprec = si.resample(lfprec, 250, margin_ms=1000)
# spatial filters: second derivative and averaging same-depth channels
lfprec = si.directional_derivative(lfprec, order=2, edge_order=1)
lfprec = si.average_across_direction(lfprec)

from spikeinterface.sortingcomponents.motion import estimate_motion
motion = estimate_motion(lfprec, method='dredge_lfp', rigid=False, progress_bar=True, max_disp_um=1000)

rec = se.read_spikeglx("/home/zhangdaohan20h/public_data/NPX_examples/Pt01", 
                        load_sync_channel=False, stream_id="imec0.ap")

from spikeinterface.sortingcomponents.motion import interpolate_motion
rec = si.astype(rec, np.float32)
print(motion.dim)
rec = interpolate_motion(rec, motion, border_mode='remove_channels', 
                             spatial_interpolation_method='kriging', sigma_um=20.0, p=1, 
                             num_closest=3, interpolation_time_bin_centers_s=None, 
                             interpolation_time_bin_size_s=None, dtype=None)

rec = si.bandpass_filter(rec)
rec = si.common_reference(rec)
# Save the recording  
rec.save(folder='recording', format='zarr',overwrite=True, engine='joblib', engine_kwargs={"n_jobs": 20},)  #binary_folder

write_zarr_recording n_jobs=1 - samples_per_chunk=30,000 - chunk_memory=8.01 MiB - total_memory=8.01 MiB - chunk_duration=1.00s
write_zarr_recording: 100%|##############################################################9| 833/834 [2:19:45<00:10, 10.07s/it]
Traceback (most recent call last):
  File "/share/home/zhangdaohan20h/dredge-main/notebook/prepro.py", line 111, in <module>
    rec.save(folder='recording', format='zarr',overwrite=True, engine='joblib', engine_kwargs={"n_jobs": 20},) #binary_folder
  File "/home/zhangdaohan20h/.conda/envs/kilosort4/lib/python3.9/site-packages/spikeinterface/core/base.py", line 861, in save
    loaded_extractor = self.save_to_zarr(**kwargs)
  File "/home/zhangdaohan20h/.conda/envs/kilosort4/lib/python3.9/site-packages/spikeinterface/core/base.py", line 1065, in save_to_zarr
    cached = self._save(format="zarr", verbose=verbose, **save_kwargs)
  File "/home/zhangdaohan20h/.conda/envs/kilosort4/lib/python3.9/site-packages/spikeinterface/core/baserecording.py", line 576, in _save
    ZarrRecordingExtractor.write_recording(
  File "/home/zhangdaohan20h/.conda/envs/kilosort4/lib/python3.9/site-packages/spikeinterface/core/zarrextractors.py", line 118, in write_recording
    add_recording_to_zarr_group(recording, zarr_root, **kwargs)
  File "/home/zhangdaohan20h/.conda/envs/kilosort4/lib/python3.9/site-packages/spikeinterface/core/zarrextractors.py", line 382, in add_recording_to_zarr_group
    add_traces_to_zarr(
  File "/home/zhangdaohan20h/.conda/envs/kilosort4/lib/python3.9/site-packages/spikeinterface/core/zarrextractors.py", line 509, in add_traces_to_zarr
    executor.run()
  File "/home/zhangdaohan20h/.conda/envs/kilosort4/lib/python3.9/site-packages/spikeinterface/core/job_tools.py", line 405, in run
    res = self.func(segment_index, frame_start, frame_stop, worker_ctx)
  File "/home/zhangdaohan20h/.conda/envs/kilosort4/lib/python3.9/site-packages/spikeinterface/core/zarrextractors.py", line 537, in _write_zarr_chunk
    traces = recording.get_traces(
  File "/home/zhangdaohan20h/.conda/envs/kilosort4/lib/python3.9/site-packages/spikeinterface/core/baserecording.py", line 342, in get_traces
    traces = rs.get_traces(start_frame=start_frame, end_frame=end_frame, channel_indices=channel_indices)
  File "/home/zhangdaohan20h/.conda/envs/kilosort4/lib/python3.9/site-packages/spikeinterface/preprocessing/common_reference.py", line 183, in get_traces
    traces = self.parent_recording_segment.get_traces(start_frame, end_frame, slice(None))
  File "/home/zhangdaohan20h/.conda/envs/kilosort4/lib/python3.9/site-packages/spikeinterface/preprocessing/filter.py", line 137, in get_traces
    traces_chunk, left_margin, right_margin = get_chunk_with_margin(
  File "/home/zhangdaohan20h/.conda/envs/kilosort4/lib/python3.9/site-packages/spikeinterface/core/recording_tools.py", line 737, in get_chunk_with_margin
    traces_chunk = rec_segment.get_traces(
  File "/home/zhangdaohan20h/.conda/envs/kilosort4/lib/python3.9/site-packages/spikeinterface/sortingcomponents/motion/motion_interpolation.py", line 462, in get_traces
    traces = interpolate_motion_on_traces(
  File "/home/zhangdaohan20h/.conda/envs/kilosort4/lib/python3.9/site-packages/spikeinterface/sortingcomponents/motion/motion_interpolation.py", line 139, in interpolate_motion_on_traces
    bin_time = time_bins[bin_ind]
IndexError: index 208455 is out of bounds for axis 0 with size 208455
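The "index N is out of bounds for axis 0 with size N" shape of this error is typical of an off-by-one at the end of the motion's time axis. The NumPy sketch below illustrates one plausible mechanism; it is a hypothetical helper, not spikeinterface's actual code. It assumes each recording frame is mapped to a motion time bin by floor division relative to the first bin edge. Under that assumption, frames recorded after the last motion bin ends get indices at or beyond the number of bins. Clipping to `n_bins - 1` keeps every index valid, while clipping to `n_bins` still leaves one invalid index.

```python
import numpy as np

def frame_to_bin_indices(times_s, bin_centers_s, upper=None):
    """Map sample times to motion time-bin indices (hypothetical helper).

    Assumes uniformly spaced bin centers. `upper` is the largest index
    allowed after clipping; None means no clipping at all.
    """
    bin_s = bin_centers_s[1] - bin_centers_s[0]
    first_edge_s = bin_centers_s[0] - 0.5 * bin_s
    inds = np.floor((times_s - first_edge_s) / bin_s).astype(int)
    if upper is not None:
        inds = np.clip(inds, 0, upper)
    return inds

# Motion estimated from 250 Hz LFP: one 4 ms time bin per LFP sample, 4.0 s total.
n_bins = 1000
bin_centers_s = (np.arange(n_bins) + 0.5) / 250.0

# AP samples at a nominal 30 kHz; this made-up AP stream runs about 10 ms
# longer than the motion bins cover (e.g. slightly different clocks).
ap_times_s = np.arange(120300) / 30000.0

raw = frame_to_bin_indices(ap_times_s, bin_centers_s)
print(raw.max())  # 1002: the last AP samples fall past the final motion bin

# Clipping to n_bins - 1 keeps every index valid...
safe = frame_to_bin_indices(ap_times_s, bin_centers_s, upper=n_bins - 1)
print(safe.max())  # 999

# ...while clipping to n_bins leaves an index equal to the array size,
# the same "index N ... with size N" pattern as in the traceback above.
off_by_one = frame_to_bin_indices(ap_times_s, bin_centers_s, upper=n_bins)
try:
    bin_centers_s[off_by_one]
except IndexError as err:
    print("IndexError:", err)
```

Whether spikeinterface's frame-to-bin mapping actually works this way is an assumption of this sketch; it only shows how a small mismatch in duration between the LF-derived motion and the AP recording can produce this class of error.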

If I do need to resample back, I have to use a non-integer resampling rate so that the motion's time steps match the recording's exactly. That raises the following error:

motion = si.resample(motion, resample_rate = len(rec.get_times(0))/len(motion.displacement[0]), margin_ms=1000)

AssertionError                            Traceback (most recent call last)
Cell In[44], line 1
----> 1 motion = si.resample(motion, resample_rate = len(rec.get_times(0))/len(motion.displacement[0]), margin_ms=1000)

File ~/.conda/envs/kilosort4/lib/python3.9/site-packages/spikeinterface/preprocessing/resample.py:55, in ResampleRecording.__init__(self, recording, resample_rate, margin_ms, dtype, skip_checks)
     45 def __init__(
     46     self,
     47     recording,
   (...)
     52 ):
     53     # Floating point resampling rates can lead to unexpected results, avoid actively
     54     msg = "Non integer resampling rates can lead to unexpected results."
---> 55     assert isinstance(resample_rate, (int, np.integer)), msg
     56     # Original sampling frequency
     57     self._orig_samp_freq = recording.get_sampling_frequency()

AssertionError: Non integer resampling rates can lead to unexpected results.

In the old version, passing me_rigid_upsampled worked with DREDge. In the current version, however, it raises an error because InterpolateMotionRecording now requires a Motion object as input instead of the raw displacement NumPy array. I documented this issue in the DREDge GitHub repository (https://github.com/evarol/dredge/issues/1). I mention it here because resample_to_new_time_bins is an example of a more practical resampling function.

from dredge.dredge_ap import register
from dredge.dredge_lfp import register_online_lfp
import dredge.motion_util as mu
me_rigid, extra_info_rigid = register_online_lfp(lfprec, max_disp_um=1000)
me_rigid_upsampled = mu.resample_to_new_time_bins(me_rigid, rec.get_times(0))
rec = mu.get_interpolated_recording(me_rigid_upsampled, rec)

AttributeError                            Traceback (most recent call last)
Cell In[15], line 1
----> 1 rec = mu.get_interpolated_recording(me_rigid_upsampled, rec)
      2 rec

File /share/home/zhangdaohan20h/dredge-main/python/dredge/motion_util.py:400, in get_interpolated_recording(motion_est, recording, border_mode)
    397 assert displacement.ndim == 2 and displacement.shape[0] == 2
    399 # now we can use correct_motion
--> 400 rec_interpolated = InterpolateMotionRecording(
    401     rec,
    402     displacement.T,
    403     temporal_bins,
    404     spatial_bins,
    405     #border_mode=border_mode, updated by zdh
    406 )
    407 return rec_interpolated

File ~/.conda/envs/kilosort4/lib/python3.9/site-packages/spikeinterface/sortingcomponents/motion/motion_interpolation.py:326, in InterpolateMotionRecording.__init__(self, recording, motion, border_mode, spatial_interpolation_method, sigma_um, p, num_closest, interpolation_time_bin_centers_s, interpolation_time_bin_size_s, dtype, **spatial_interpolation_kwargs)
    301 channel_locations = recording.get_channel_locations()
    302 '''
    303 import json #This is my testing codes
    304 #print(motion.keys())#
   (...)
    324 print(f'Dimension attribute: {motion.dim}')
    325 '''
--> 326 assert channel_locations.ndim >= motion.dim, (
    327     f"'direction' {motion.direction} not available. "
    328     f"Channel locations have {channel_locations.ndim} dimensions."
    329 )
    330 spatial_interpolation_kwargs = dict(
    331     sigma_um=sigma_um, p=p, num_closest=num_closest, **spatial_interpolation_kwargs
    332 )
    333 if border_mode == "remove_channels":

AttributeError: 'numpy.ndarray' object has no attribute 'dim'

zm711 commented 3 weeks ago

I'll loop back in @cwindolf and @samuelgarcia for this motion stuff. @cwindolf, feel free to add a bug tag if you think this is a bug. Since I don't know the motion correction stuff at all, I don't want to use the wrong tag :)

cwindolf commented 3 weeks ago

Hi Daohan, thanks for your comments! And thanks for the ping @zm711

I don't immediately have all of the answers, but let me just organize your points.

  1. There is a bug in interpolate_motion that makes it crash with an IndexError when motion estimated from the LFP band is used to correct the AP band. It needs fixing, but I'm not sure of the cause -- @samuelgarcia, I'm not sure whether this implementation was tested to work, and if so, where that result is?
  2. evarol/dredge is currently broken, because I haven't handled the Motion object stuff there. To fix the "no attribute .dim" error you're getting there, I'll need to add a conversion to Motion, which will be straightforward. I'll address that over there in https://github.com/evarol/dredge/issues/1 -- thanks for opening that issue.
  3. Resampling the motion estimate. si.resample() is designed to work with recordings, not Motion objects, so I would not expect that line of code to work. Maybe spikeinterface should have a resample_motion function that works like the resample_to_new_time_bins() function I have in https://github.com/evarol/dredge. Note that this is independent of (1): interpolate_motion should work regardless of what time bins the Motion object has.
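To make point 3 concrete, here is a minimal NumPy sketch of what such a resampling step could do: linearly interpolate each spatial bin's displacement trace onto new time points. The function name `resample_displacement` is hypothetical; it is not an existing spikeinterface or DREDge function. The (time bins × spatial bins) layout is assumed to match how spikeinterface's Motion object stores the displacement of each segment. Wrapping the result back into a Motion object is left out, since that constructor's exact signature is not shown in this thread.

```python
import numpy as np

def resample_displacement(displacement, temporal_bins_s, new_times_s):
    """Linearly interpolate a motion estimate onto new time points.

    Hypothetical helper in the spirit of DREDge's resample_to_new_time_bins();
    it is not an existing spikeinterface function.

    displacement:    (n_time_bins, n_spatial_bins) array in um, or 1D for rigid
    temporal_bins_s: (n_time_bins,) increasing bin centers in seconds
    new_times_s:     (n_new,) times at which to evaluate the displacement
    Returns an (n_new, n_spatial_bins) array. Outside the original bins,
    np.interp holds the first/last value constant.
    """
    displacement = np.asarray(displacement, dtype=float)
    if displacement.ndim == 1:
        displacement = displacement[:, np.newaxis]  # rigid: a single spatial bin
    new_times_s = np.asarray(new_times_s, dtype=float)
    out = np.empty((new_times_s.size, displacement.shape[1]))
    for j in range(displacement.shape[1]):
        # Interpolate each spatial bin's displacement trace independently in time.
        out[:, j] = np.interp(new_times_s, temporal_bins_s, displacement[:, j])
    return out

# Usage: 2 s of non-rigid motion (3 spatial bins) sampled at 250 Hz,
# evaluated on a 30 kHz time grid.
t_bins = (np.arange(500) + 0.5) / 250.0
disp = np.stack([np.sin(t_bins), 2 * np.sin(t_bins), 3 * np.sin(t_bins)], axis=1)
new_t = np.arange(60000) / 30000.0
up = resample_displacement(disp, t_bins, new_t)
print(up.shape)  # (60000, 3)
```

Because the interpolation is done in time rather than by an index ratio, a non-integer LF/AP rate ratio is not a problem here. Evaluating at every AP sample is memory-hungry, though: an 834 s recording at 30 kHz gives about 25 million rows per spatial bin. A coarser target grid is likely enough, and interpolate_motion is meant to work without this step anyway.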

As an aside, @DaohanZhang when we do get this working you might want to apply your AP band preprocessing (filtering + common_reference in this case) before interpolation, which will likely be helpful and is more standard.

DaohanZhang commented 3 weeks ago

Hi @cwindolf, thanks for organizing this! My previous pre-analysis pipeline used DREDge to estimate the motion displacement, resample the displacement to the recording's time steps, and then interpolate the recording:

from dredge.dredge_ap import register
from dredge.dredge_lfp import register_online_lfp
import dredge.motion_util as mu
me_rigid, extra_info_rigid = register_online_lfp(lfprec, max_disp_um=1000)
me_rigid_upsampled = mu.resample_to_new_time_bins(me_rigid, rec.get_times(0))
rec = mu.get_interpolated_recording(me_rigid_upsampled, rec)

But after I upgraded to v0.101.0, I hit the DREDge bug that prevents me from using this pipeline (https://github.com/evarol/dredge/issues/1). I switched to dredge_lfp in spikeinterface, but I couldn't find any function for resampling motion.

from spikeinterface.sortingcomponents.motion import estimate_motion
motion = estimate_motion(lfprec, method='dredge_lfp', rigid=False, progress_bar=True, max_disp_um=1000)

Later I wondered whether the Motion class could automatically resample itself to fit the recording's sampling rate, since motion.check_properties returns the motion's sampling rate, and running interpolate_motion without resampling back raised no error. However, the IndexError appeared when I tried to save the recording after interpolation, so it seems that resampling the motion is needed after all.

rec = interpolate_motion(rec, motion, border_mode='remove_channels', 
                             spatial_interpolation_method='kriging', sigma_um=20.0, p=1, 
                             num_closest=3, interpolation_time_bin_centers_s=None, 
                             interpolation_time_bin_size_s=None, dtype=None)

rec = si.bandpass_filter(rec)
rec = si.common_reference(rec)
# Save the recording  
rec.save(folder='recording', format='zarr',overwrite=True, engine='joblib', engine_kwargs={"n_jobs": 20},)

For this scenario, the quickest and simplest solution would be to fix the bug in the dredge function. I’ll be on standby for updates! However, if it's necessary to replace the entire dredge function, we might need to implement a motion_resampling function instead.

Also, thanks for your suggestion about the order of preprocessing and interpolation. I'll try it once the bug in DREDge is fixed.