MouseLand / Kilosort

Fast spike sorting with drift correction
https://kilosort.readthedocs.io/en/latest/
GNU General Public License v3.0

BUG: The expanded size of the tensor (60122) must match the existing size (60071) at non-singleton dimension 1. Target sizes: [112, 60122]. Tensor sizes: [112, 60071] #739

Closed: JuanPimientoCaicedo closed this issue 3 months ago

JuanPimientoCaicedo commented 4 months ago

Describe the issue:

Hello, guys.

I am using SpikeInterface to run Kilosort and I am encountering this error. I also tried Kilosort 4.0.13 and got the same result. I opened an issue in the SpikeInterface repo as well: https://github.com/SpikeInterface/spikeinterface/issues/3183.

I suspect it might be related to a compatibility issue between Python and CUDA, but I am following the Kilosort installation instructions with Python 3.9 and CUDA 11.8.

Reproduce the bug:

No response

Error message:

INFO:kilosort.io:========================================
INFO:kilosort.io:Loading recording with SpikeInterface...
INFO:kilosort.io:number of samples: 1800010
INFO:kilosort.io:number of channels: 112
INFO:kilosort.io:numbef of segments: 1
INFO:kilosort.io:sampling rate: 30000.18060200669
INFO:kilosort.io:dtype: int16
INFO:kilosort.io:========================================
INFO:kilosort.run_kilosort: 
INFO:kilosort.run_kilosort:Computing preprocessing variables.
INFO:kilosort.run_kilosort:----------------------------------------
INFO:kilosort.run_kilosort:Preprocessing filters computed in  7.90s; total  7.90s
INFO:kilosort.run_kilosort: 
INFO:kilosort.run_kilosort:Computing drift correction.
INFO:kilosort.run_kilosort:----------------------------------------
INFO:kilosort.datashift:nblocks = 0, skipping drift correction
INFO:kilosort.run_kilosort:drift computed in  0.00s; total  7.90s
INFO:kilosort.run_kilosort: 
INFO:kilosort.run_kilosort:Extracting spikes using templates
INFO:kilosort.run_kilosort:----------------------------------------
INFO:kilosort.spikedetect:Re-computing universal templates from data.
Skipping drift correction.
 94%|████████████████████████████████████████████████████████████████████████████▋     | 29/31 [02:19<00:09,  4.81s/it]
Error running kilosort4
---------------------------------------------------------------------------
SpikeSortingError                         Traceback (most recent call last)
Cell In[19], line 3
      1 job_kwargs = dict(n_jobs=-1, chunk_duration='1s', progress_bar=True)
----> 3 sorting = si.run_sorter(sorter_name = 'kilosort4', recording = cortex_AP_slice, 
      4                         folder = Path('C:/Users/Juan/Desktop/KS4_output'), 
      5                         do_correction = False, verbose = True, docker_image = False)

File ~\.conda\envs\si_env\lib\site-packages\spikeinterface\sorters\runsorter.py:216, in run_sorter(sorter_name, recording, folder, remove_existing_folder, delete_output_folder, verbose, raise_error, docker_image, singularity_image, delete_container_files, with_output, output_folder, **sorter_params)
    205             raise RuntimeError(
    206                 "The python `spython` package must be installed to "
    207                 "run singularity. Install with `pip install spython`"
    208             )
    210     return run_sorter_container(
    211         container_image=container_image,
    212         mode=mode,
    213         **common_kwargs,
    214     )
--> 216 return run_sorter_local(**common_kwargs)

File ~\.conda\envs\si_env\lib\site-packages\spikeinterface\sorters\runsorter.py:276, in run_sorter_local(sorter_name, recording, folder, remove_existing_folder, delete_output_folder, verbose, raise_error, with_output, output_folder, **sorter_params)
    274 SorterClass.set_params_to_folder(recording, folder, sorter_params, verbose)
    275 SorterClass.setup_recording(recording, folder, verbose=verbose)
--> 276 SorterClass.run_from_folder(folder, raise_error, verbose)
    277 if with_output:
    278     sorting = SorterClass.get_result_from_folder(folder, register_recording=True, sorting_info=True)

File ~\.conda\envs\si_env\lib\site-packages\spikeinterface\sorters\basesorter.py:301, in BaseSorter.run_from_folder(cls, output_folder, raise_error, verbose)
    298         print(f"{sorter_name} run time {run_time:0.2f}s")
    300 if has_error and raise_error:
--> 301     raise SpikeSortingError(
    302         f"Spike sorting error trace:\n{error_log_to_display}\n"
    303         f"Spike sorting failed. You can inspect the runtime trace in {output_folder}/spikeinterface_log.json."
    304     )
    306 return run_time

SpikeSortingError: Spike sorting error trace:
Traceback (most recent call last):
  File "C:\Users\Juan\.conda\envs\si_env\lib\site-packages\spikeinterface\sorters\basesorter.py", line 261, in run_from_folder
    SorterClass._run_from_folder(sorter_output_folder, sorter_params, verbose)
  File "C:\Users\Juan\.conda\envs\si_env\lib\site-packages\spikeinterface\sorters\external\kilosort4.py", line 261, in _run_from_folder
    st, tF, _, _ = detect_spikes(ops, device, bfile, tic0=tic0, progress_bar=progress_bar)
  File "C:\Users\Juan\.conda\envs\si_env\lib\site-packages\kilosort\run_kilosort.py", line 481, in detect_spikes
    st0, tF, ops = spikedetect.run(ops, bfile, device=device, progress_bar=progress_bar)
  File "C:\Users\Juan\.conda\envs\si_env\lib\site-packages\kilosort\spikedetect.py", line 250, in run
    X = bfile.padded_batch_to_torch(ibatch, ops)
  File "C:\Users\Juan\.conda\envs\si_env\lib\site-packages\kilosort\io.py", line 776, in padded_batch_to_torch
    X = super().padded_batch_to_torch(ibatch)
  File "C:\Users\Juan\.conda\envs\si_env\lib\site-packages\kilosort\io.py", line 601, in padded_batch_to_torch
    X[:] = torch.from_numpy(data).to(self.device).float()
RuntimeError: The expanded size of the tensor (60122) must match the existing size (60071) at non-singleton dimension 1.  Target sizes: [112, 60122].  Tensor sizes: [112, 60071]

Version information:

Python version: 3.9.19 | packaged by conda-forge | (main, Mar 20 2024, 12:38:46) [MSC v.1929 64 bit (AMD64)]
CUDA is available: True - version 11.8
torch version: 2.3.1
si version: 0.101.0rc0
kilosort version: 4.0.12

jacobpennington commented 4 months ago

Hello, two things: 1) Please try sorting without using SpikeInterface. 2) Can you please upload kilosort4.log from your results directory?

JuanPimientoCaicedo commented 4 months ago
  1. Hi, yes. I tried using Kilosort directly and it worked, in both versions 4.0.12 and 4.0.13.
  2. Here is an example log: kilosort4.log. Thank you!

JuanPimientoCaicedo commented 4 months ago

Hello @jacobpennington, I am reopening this issue because I found out it is indeed a Kilosort problem. I was able to replicate it by using the real sampling rate of the ADC (something that SpikeInterface does). Here is the error:

kilosort.gui.sorter: Kilosort version 4.0.12
kilosort.gui.sorter: Sorting G:\Jon_2024-05-27\imec0.ap.bin
kilosort.gui.sorter: ----------------------------------------
kilosort.run_kilosort:
kilosort.run_kilosort: Computing preprocessing variables.
kilosort.run_kilosort: ----------------------------------------
kilosort.run_kilosort: Preprocessing filters computed in 0.44s; total 0.45s
kilosort.run_kilosort:
kilosort.run_kilosort: Computing drift correction.
kilosort.run_kilosort: ----------------------------------------
kilosort.spikedetect: Re-computing universal templates from data.
C:\Users\Juan\.conda\envs\si_env\lib\site-packages\threadpoolctl.py:1214: RuntimeWarning: Found Intel OpenMP ('libiomp') and LLVM OpenMP ('libomp') loaded at the same time. Both libraries are known to be incompatible and this can cause random crashes or deadlocks on Linux when loaded in the same Python program. Using threadpoolctl may cause crashes or deadlocks. For more information and possible workarounds, please see https://github.com/joblib/threadpoolctl/blob/master/multiple_openmp.md
  warnings.warn(msg, RuntimeWarning)
0%| | 0/51 [00:00<?, ?it/s]
0%| | 0/51 [00:10<?, ?it/s]
96%|##############################################################################7 | 49/51 [00:34<00:01, 1.42it/s]
Traceback (most recent call last):
  File "C:\Users\Juan\.conda\envs\si_env\lib\site-packages\kilosort\gui\sorter.py", line 93, in run
    ops, bfile, st0 = compute_drift_correction(
  File "C:\Users\Juan\.conda\envs\si_env\lib\site-packages\kilosort\run_kilosort.py", line 425, in compute_drift_correction
    ops, st = datashift.run(ops, bfile, device=device, progress_bar=progress_bar)
  File "C:\Users\Juan\.conda\envs\si_env\lib\site-packages\kilosort\datashift.py", line 197, in run
    st, _, ops = spikedetect.run(ops, bfile, device=device, progress_bar=progress_bar)
  File "C:\Users\Juan\.conda\envs\si_env\lib\site-packages\kilosort\spikedetect.py", line 250, in run
    X = bfile.padded_batch_to_torch(ibatch, ops)
  File "C:\Users\Juan\.conda\envs\si_env\lib\site-packages\kilosort\io.py", line 776, in padded_batch_to_torch
    X = super().padded_batch_to_torch(ibatch)
  File "C:\Users\Juan\.conda\envs\si_env\lib\site-packages\kilosort\io.py", line 601, in padded_batch_to_torch
    X[:] = torch.from_numpy(data).to(self.device).float()
RuntimeError: The expanded size of the tensor (60122) must match the existing size (60079) at non-singleton dimension 1. Target sizes: [385, 60122]. Tensor sizes: [385, 60079]

And the log file:

kilosort4.log

Using the real sampling rate of the ADC is important for long recordings, especially when you have to align them with other data streams.
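To put a number on that point, here is a small sketch (the function name is hypothetical, not part of Kilosort or SpikeInterface) that computes how far timestamps drift if samples acquired at the calibrated rate reported above (30000.18060200669 Hz) are converted to time using a nominal 30000 Hz. It is plain arithmetic, assuming the true rate is constant over the recording.

```python
def timing_error_s(duration_s: float, fs_true: float, fs_nominal: float) -> float:
    """Hypothetical helper: seconds of timestamp drift after `duration_s`
    seconds when samples acquired at `fs_true` are converted to time
    using `fs_nominal` instead."""
    n_samples = duration_s * fs_true            # samples actually acquired
    return n_samples / fs_nominal - duration_s  # apparent time minus true time


# Rates from this thread: nominal 30 kHz vs. the calibrated ADC rate.
FS_NOMINAL = 30000.0
FS_TRUE = 30000.18060200669

for minutes in (1, 10, 60):
    err_ms = 1000 * timing_error_s(60 * minutes, FS_TRUE, FS_NOMINAL)
    print(f"{minutes:>3} min recording: {err_ms:.2f} ms timing error")
```

With these rates the error is about 0.36 ms after one minute but about 21.7 ms after an hour, which is large compared with spike widths when aligning to behavioral or other data streams.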

jacobpennington commented 4 months ago

Thanks, looking into it.

jacobpennington commented 4 months ago

Would it be feasible for you to share one recording to help me reproduce this?

JuanPimientoCaicedo commented 4 months ago

Hi again. Please send a message to my email: juanpimientoca@gmail.com

MariosPanayi commented 4 months ago

Not sure if it is useful, but I have also run into this problem with one of my recordings (run from an Anaconda terminal with the run_kilosort function). After getting the same error twice (restarting the computer and making sure I had an up-to-date version of Kilosort between attempts), I moved on to batch sorting other files, intending to come back to this one at the end, since it is the only problem file so far (30+ of my other files have had no issues). Let me know if I can provide additional details to help debug. Brief summary of setup: Windows 11 PC with an RTX 4070 GPU and an AMD Threadripper 3990X CPU. Recording: .bin file (int16, 40000 Hz sampling rate) with a probe layout of 15 channels linearly spaced 100 microns apart. Probe and .bin files generated the same way (extracted from .pl2 files using custom MATLAB code), with similar file sizes and recording durations, have worked multiple times without issue.

MariosPanayi commented 4 months ago

Update: after trying the problem file on a couple of different computers and changing some of the parameters, I found that lowering the batch size fixed the error. The error message pointed to a dimension mismatch, so that seems plausible. Hope that helps.

jacobpennington commented 4 months ago

Yes that's helpful, thanks!

jacobpennington commented 4 months ago

@MariosPanayi Are you able to upload kilosort4.log from your results directory here, so I can see some additional details? Or can you please at least let me know if you were using the tmax parameter?

jacobpennington commented 4 months ago

@JuanPimientoCaicedo Still working on this, but I did notice I can only reproduce the error when I set tmax. If I leave it as the default tmax = np.inf (i.e. sort the full recording), then I don't get an error and the results look reasonable. That still points to a bug I need to fix, but it looks like you should be able to sort your data in the meantime if you don't specifically need to restrict the sorting to the first 100 seconds.

JuanPimientoCaicedo commented 4 months ago

Thank you, @jacobpennington. I didn't think about sorting the whole recording; this bug happened to me while adapting a new pipeline (that's why I was using small segments, just to test everything).

The bug only happens when the sampling rate is not an integer like 30000 and tmax is not np.inf. I continued working with a sampling rate of 30000 and a batch size of 60000, and that way I was able to run Kilosort without problems while defining tmax.

I will try to move forward with a complete session and let you know if a new error pops up.
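The arithmetic behind this observation can be sketched as follows. This hypothetical helper assumes the number of samples sorted is int(tmax * fs) (Kilosort's own rounding may differ) and that the recording is split into consecutive batches of batch_size samples, with the last batch holding whatever is left over.

```python
import math


def batch_layout(fs, tmax, batch_size):
    """Hypothetical helper: return (n_samples, n_batches, last_batch_len)
    for a recording cut at `tmax` seconds and split into batches of
    `batch_size` samples.

    Assumption: the number of samples sorted is int(tmax * fs).
    """
    n_samples = int(tmax * fs)
    n_batches = math.ceil(n_samples / batch_size)
    last_batch_len = n_samples - (n_batches - 1) * batch_size
    return n_samples, n_batches, last_batch_len


# Integer sampling rate: 100 s is exactly 50 batches of 60000 samples.
print(batch_layout(30000.0, 100.0, 60000))            # -> (3000000, 50, 60000)
# Calibrated ADC rate: 18 leftover samples spill into a 51st batch.
print(batch_layout(30000.18060200669, 100.0, 60000))  # -> (3000018, 51, 18)
```

The 51 batches match the 0/51 progress bar in the GUI log above, and the 18-sample final batch is what makes the non-integer rate matter.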

jacobpennington commented 4 months ago

Just an FYI for anyone that encounters this issue before the fix or with an older version:

I've identified the bug, we just need to decide on the best way to fix it. It happens when divvying up the recording into batches results in a very small last batch (less than nt, 61 samples by default). That's pretty unlikely, but when it does happen it means the second-to-last batch is not the expected size, which causes the error. A simple work-around is to change tmax by a small amount. For example, I was able to sort your data @JuanPimientoCaicedo by using tmax = 100.0015 instead of tmax = 100. If this happens for someone sorting a full dataset, just set tmax to something like 1/100 of a second less than the end of the recording.
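The failure mode described here can be modeled in a few lines. This is a hypothetical sketch, not Kilosort's io.py: it assumes each batch is expected to be batch_size + 2*nt samples wide (consistent with 'NTbuff': 800162 for batch_size 800000 and nt 81 in the log further down) and that the right-hand padding of the second-to-last batch can only come from samples that exist in the final batch. Under those assumptions it reproduces the widths in both error messages in this thread, but it is an inference, not the library's actual code.

```python
import math


def padded_widths(n_samples, batch_size, nt):
    """Hypothetical model of batch padding.

    Returns (n_batches, last_batch_len, expected_width, second_to_last_width).
    expected_width is batch_size + 2*nt (NTbuff); the second-to-last batch
    can only borrow min(last_batch_len, nt) samples of right-hand padding.
    """
    n_batches = math.ceil(n_samples / batch_size)
    last_batch_len = n_samples - (n_batches - 1) * batch_size
    expected_width = batch_size + 2 * nt
    second_to_last_width = batch_size + nt + min(last_batch_len, nt)
    return n_batches, last_batch_len, expected_width, second_to_last_width


def tail_is_safe(n_samples, batch_size, nt=61):
    """True if the final batch holds at least nt samples."""
    return padded_widths(n_samples, batch_size, nt)[1] >= nt


FS = 30000.18060200669   # calibrated rate reported in this thread
BATCH, NT = 60000, 61    # batch size used here, default nt

# SpikeInterface run (1800010 samples): expected 60122, got 60071.
print(padded_widths(1800010, BATCH, NT))         # -> (31, 10, 60122, 60071)
# GUI run with tmax = 100 s (assuming int(tmax * fs) samples): got 60079.
print(padded_widths(int(100 * FS), BATCH, NT))   # -> (51, 18, 60122, 60079)
# The suggested work-around, tmax = 100.0015 s, leaves a 63-sample tail.
print(tail_is_safe(int(100.0015 * FS), BATCH, NT))  # -> True
```

For a full recording, trimming about 0.01 s (roughly 300 samples at 30 kHz) removes a tail shorter than nt entirely, so the new last batch is nearly full.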

MariosPanayi commented 4 months ago

> @MariosPanayi Are you able to upload kilosort4.log from your results directory here, so I can see some additional details? Or can you please at least let me know if you were using the tmax parameter?

Sorry for the late reply. It sounds like you have identified the problem, but just in case you still need it, here is the log file. I don't specify tmax in my Python script, so it falls back to the default value of inf:

07-16 16:30 kilosort.run_kilosort INFO Kilosort version 4.0.13
07-16 16:30 kilosort.run_kilosort INFO Sorting D:\MariosSorting\Stage1_TwoOdorNovelAcq\MPNIDA007_Stage1\MP02\MP02_20220304_NoverltyAcq2Odor_Disc9_Disc10_Group_4\MP02_20220304_NoverltyAcq2Odor_Disc9_Disc10_Group_4.bin
07-16 16:30 kilosort.run_kilosort INFO ----------------------------------------
07-16 16:30 kilosort.run_kilosort INFO Using GPU for PyTorch computations. Specify device to change this.
07-16 16:30 kilosort.run_kilosort DEBUG Initial ops: { 'n_chan_bin': 15, 'fs': 40000.0, 'batch_size': 800000, 'nblocks': 0, 'Th_universal': 9.0, 'Th_learned': 8.0, 'tmin': 0.0, 'tmax': inf, 'nt': 81, 'shift': None, 'scale': None, 'artifact_threshold': inf, 'nskip': 25, 'whitening_range': 16, 'highpass_cutoff': 300, 'binning_depth': 5.0, 'sig_interp': 20.0, 'drift_smoothing': [0.5, 0.5, 0.5], 'nt0min': 26, 'dmin': 50, 'dminx': 25.0, 'min_template_size': 10.0, 'template_sizes': 5, 'nearest_chans': 1, 'nearest_templates': 1, 'max_channel_distance': 1.0, 'templates_from_data': True, 'n_templates': 8, 'n_pcs': 6, 'Th_single_ch': 6.0, 'acg_threshold': 0.2, 'ccg_threshold': 0.25, 'cluster_downsampling': 20, 'x_centers': None, 'duplicate_spike_ms': 0.25, 'data_dir': WindowsPath('D:/MariosSorting/Stage1_TwoOdorNovelAcq/MPNIDA007_Stage1/MP02/MP02_20220304_NoverltyAcq2Odor_Disc9_Disc10_Group_4'), 'duplicate_spike_bins': 10, 'dtype_idx': 3, 'probe_idx': 6, 'device_idx': 0, 'filename': WindowsPath('D:/MariosSorting/Stage1_TwoOdorNovelAcq/MPNIDA007_Stage1/MP02/MP02_20220304_NoverltyAcq2Odor_Disc9_Disc10_Group_4/MP02_20220304_NoverltyAcq2Odor_Disc9_Disc10_Group_4.bin'), 'data_dtype': 'int16', 'do_CAR': True, 'invert_sign': False, 'NTbuff': 800162, 'Nchan': 15, 'torch_device': 'cuda', 'save_preprocessed_copy': False, 'chanMap': array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]), 'xc': array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32), 'yc': array([ 0., 100., 200., 300., 400., 500., 600., 700., 800., 900., 1000., 1100., 1200., 1300., 1400.], dtype=float32), 'kcoords': array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32), 'n_chan': 15}
07-16 16:30 kilosort.run_kilosort INFO
07-16 16:30 kilosort.run_kilosort INFO Computing preprocessing variables.
07-16 16:30 kilosort.run_kilosort INFO ----------------------------------------
07-16 16:30 kilosort.run_kilosort INFO N samples: 409044447
07-16 16:30 kilosort.run_kilosort INFO N seconds: 10226.111175
07-16 16:30 kilosort.run_kilosort INFO N batches: 512
07-16 16:30 kilosort.run_kilosort INFO Preprocessing filters computed in 1.27s; total 1.28s
07-16 16:30 kilosort.run_kilosort DEBUG hp_filter shape: torch.Size([30122])
07-16 16:30 kilosort.run_kilosort DEBUG whiten_mat shape: torch.Size([15, 15])
07-16 16:30 kilosort.run_kilosort INFO
07-16 16:30 kilosort.run_kilosort INFO Computing drift correction.
07-16 16:30 kilosort.run_kilosort INFO ----------------------------------------
07-16 16:30 kilosort.datashift INFO nblocks = 0, skipping drift correction
07-16 16:30 kilosort.run_kilosort INFO drift computed in 0.00s; total 1.28s
07-16 16:30 kilosort.run_kilosort DEBUG First batch min, max: (-24.985016, 32.066574)
07-16 16:30 kilosort.run_kilosort INFO
07-16 16:30 kilosort.run_kilosort INFO Extracting spikes using templates
07-16 16:30 kilosort.run_kilosort INFO ----------------------------------------
07-16 16:30 kilosort.spikedetect INFO Re-computing universal templates from data.
07-16 16:30 kilosort.run_kilosort INFO 8916860 spikes extracted in 42.58s; total 43.90s
07-16 16:30 kilosort.run_kilosort DEBUG st0 shape: (8916860, 6)
07-16 16:30 kilosort.run_kilosort DEBUG tF shape: torch.Size([8916860, 1, 6])
07-16 16:30 kilosort.run_kilosort INFO
07-16 16:30 kilosort.run_kilosort INFO First clustering
07-16 16:30 kilosort.run_kilosort INFO ----------------------------------------
07-16 17:52 kilosort.run_kilosort INFO 76 clusters found, in 4880.77s; total 4924.68s
07-16 17:52 kilosort.run_kilosort DEBUG clu shape: (8916860,)
07-16 17:52 kilosort.run_kilosort DEBUG Wall shape: torch.Size([76, 15, 6])
07-16 17:52 kilosort.run_kilosort INFO
07-16 17:52 kilosort.run_kilosort INFO Extracting spikes using cluster waveforms
07-16 17:52 kilosort.run_kilosort INFO ----------------------------------------
07-16 17:53 kilosort.run_kilosort ERROR Encountered error in run_kilosort:
Traceback (most recent call last):
  File "C:\Users\evanh\anaconda3\envs\kilosort\lib\site-packages\kilosort\run_kilosort.py", line 177, in run_kilosort
    st,tF, _, _ = detect_spikes(ops, device, bfile, tic0=tic0,
  File "C:\Users\evanh\anaconda3\envs\kilosort\lib\site-packages\kilosort\run_kilosort.py", line 546, in detect_spikes
    st, tF, ops = template_matching.extract(ops, bfile, Wall3, device=device,
  File "C:\Users\evanh\anaconda3\envs\kilosort\lib\site-packages\kilosort\template_matching.py", line 36, in extract
    stt, amps, Xres = run_matching(ops, X, U, ctc, device=device)
  File "C:\Users\evanh\anaconda3\envs\kilosort\lib\site-packages\kilosort\template_matching.py", line 163, in run_matching
    st[k:k+nsp, 0] = iX[:,0]
RuntimeError: The expanded size of the tensor (309) must match the existing size (535) at non-singleton dimension 0. Target sizes: [309]. Tensor sizes: [535]

kilosort4_failedLog.log