SpikeInterface / spikeinterface

A Python-based module for creating flexible and robust spike sorting pipelines.
https://spikeinterface.readthedocs.io
MIT License
533 stars 187 forks source link

CUDA out of memory error from running kilosort #3321

Open Hobart10 opened 3 months ago

Hobart10 commented 3 months ago

Hello, when running Kilosort4 via SpikeInterface, a CUDA out-of-memory error occurs immediately after the first clustering step. The error still occurs when setting clear_cache = True, but it does not occur when running Kilosort directly from its GUI. Could you look into the error and suggest how to fix it? Thank you!! (See also the related Kilosort issue.)

Kilosort version: 4.0.15 spikeinterface version: 0.100.0.dev0

File ~\.conda\envs\SI\spikeinterface\src\spikeinterface\sorters\launcher.py:306, in run_sorter_by_property(sorter_name, recording, grouping_property, folder, mode_if_folder_exists, engine, engine_kwargs, verbose, docker_image, singularity_image, working_folder, **sorter_params)
    [295](https://file+.vscode-resource.vscode-cdn.net/d%3A/YK_pyfile/SI/SI_Proc/~/.conda/envs/SI/spikeinterface/src/spikeinterface/sorters/launcher.py:295)     job = dict(
    [296](https://file+.vscode-resource.vscode-cdn.net/d%3A/YK_pyfile/SI/SI_Proc/~/.conda/envs/SI/spikeinterface/src/spikeinterface/sorters/launcher.py:296)         sorter_name=sorter_name,
    [297](https://file+.vscode-resource.vscode-cdn.net/d%3A/YK_pyfile/SI/SI_Proc/~/.conda/envs/SI/spikeinterface/src/spikeinterface/sorters/launcher.py:297)         recording=rec,
   (...)
    [302](https://file+.vscode-resource.vscode-cdn.net/d%3A/YK_pyfile/SI/SI_Proc/~/.conda/envs/SI/spikeinterface/src/spikeinterface/sorters/launcher.py:302)         **sorter_params,
    [303](https://file+.vscode-resource.vscode-cdn.net/d%3A/YK_pyfile/SI/SI_Proc/~/.conda/envs/SI/spikeinterface/src/spikeinterface/sorters/launcher.py:303)     )
    [304](https://file+.vscode-resource.vscode-cdn.net/d%3A/YK_pyfile/SI/SI_Proc/~/.conda/envs/SI/spikeinterface/src/spikeinterface/sorters/launcher.py:304)     job_list.append(job)
--> [306](https://file+.vscode-resource.vscode-cdn.net/d%3A/YK_pyfile/SI/SI_Proc/~/.conda/envs/SI/spikeinterface/src/spikeinterface/sorters/launcher.py:306) sorting_list = run_sorter_jobs(job_list, engine=engine, engine_kwargs=engine_kwargs, return_output=True)
    [308](https://file+.vscode-resource.vscode-cdn.net/d%3A/YK_pyfile/SI/SI_Proc/~/.conda/envs/SI/spikeinterface/src/spikeinterface/sorters/launcher.py:308) unit_groups = []
    [309](https://file+.vscode-resource.vscode-cdn.net/d%3A/YK_pyfile/SI/SI_Proc/~/.conda/envs/SI/spikeinterface/src/spikeinterface/sorters/launcher.py:309) for sorting, group in zip(sorting_list, recording_dict.keys()):

File ~\.conda\envs\SI\spikeinterface\src\spikeinterface\sorters\launcher.py:106, in run_sorter_jobs(job_list, engine, engine_kwargs, return_output)
    [103](https://file+.vscode-resource.vscode-cdn.net/d%3A/YK_pyfile/SI/SI_Proc/~/.conda/envs/SI/spikeinterface/src/spikeinterface/sorters/launcher.py:103) if engine == "loop":
    [104](https://file+.vscode-resource.vscode-cdn.net/d%3A/YK_pyfile/SI/SI_Proc/~/.conda/envs/SI/spikeinterface/src/spikeinterface/sorters/launcher.py:104)     # simple loop in main process
    [105](https://file+.vscode-resource.vscode-cdn.net/d%3A/YK_pyfile/SI/SI_Proc/~/.conda/envs/SI/spikeinterface/src/spikeinterface/sorters/launcher.py:105)     for kwargs in job_list:
--> [106](https://file+.vscode-resource.vscode-cdn.net/d%3A/YK_pyfile/SI/SI_Proc/~/.conda/envs/SI/spikeinterface/src/spikeinterface/sorters/launcher.py:106)         sorting = run_sorter(**kwargs)
    [107](https://file+.vscode-resource.vscode-cdn.net/d%3A/YK_pyfile/SI/SI_Proc/~/.conda/envs/SI/spikeinterface/src/spikeinterface/sorters/launcher.py:107)         if return_output:
    [108](https://file+.vscode-resource.vscode-cdn.net/d%3A/YK_pyfile/SI/SI_Proc/~/.conda/envs/SI/spikeinterface/src/spikeinterface/sorters/launcher.py:108)             out.append(sorting)

File ~\.conda\envs\SI\spikeinterface\src\spikeinterface\sorters\runsorter.py:216, in run_sorter(sorter_name, recording, folder, remove_existing_folder, delete_output_folder, verbose, raise_error, docker_image, singularity_image, delete_container_files, with_output, output_folder, **sorter_params)
    [205](https://file+.vscode-resource.vscode-cdn.net/d%3A/YK_pyfile/SI/SI_Proc/~/.conda/envs/SI/spikeinterface/src/spikeinterface/sorters/runsorter.py:205)             raise RuntimeError(
    [206](https://file+.vscode-resource.vscode-cdn.net/d%3A/YK_pyfile/SI/SI_Proc/~/.conda/envs/SI/spikeinterface/src/spikeinterface/sorters/runsorter.py:206)                 "The python `spython` package must be installed to "
    [207](https://file+.vscode-resource.vscode-cdn.net/d%3A/YK_pyfile/SI/SI_Proc/~/.conda/envs/SI/spikeinterface/src/spikeinterface/sorters/runsorter.py:207)                 "run singularity. Install with `pip install spython`"
    [208](https://file+.vscode-resource.vscode-cdn.net/d%3A/YK_pyfile/SI/SI_Proc/~/.conda/envs/SI/spikeinterface/src/spikeinterface/sorters/runsorter.py:208)             )
    [210](https://file+.vscode-resource.vscode-cdn.net/d%3A/YK_pyfile/SI/SI_Proc/~/.conda/envs/SI/spikeinterface/src/spikeinterface/sorters/runsorter.py:210)     return run_sorter_container(
    [211](https://file+.vscode-resource.vscode-cdn.net/d%3A/YK_pyfile/SI/SI_Proc/~/.conda/envs/SI/spikeinterface/src/spikeinterface/sorters/runsorter.py:211)         container_image=container_image,
    [212](https://file+.vscode-resource.vscode-cdn.net/d%3A/YK_pyfile/SI/SI_Proc/~/.conda/envs/SI/spikeinterface/src/spikeinterface/sorters/runsorter.py:212)         mode=mode,
    [213](https://file+.vscode-resource.vscode-cdn.net/d%3A/YK_pyfile/SI/SI_Proc/~/.conda/envs/SI/spikeinterface/src/spikeinterface/sorters/runsorter.py:213)         **common_kwargs,
    [214](https://file+.vscode-resource.vscode-cdn.net/d%3A/YK_pyfile/SI/SI_Proc/~/.conda/envs/SI/spikeinterface/src/spikeinterface/sorters/runsorter.py:214)     )
--> [216](https://file+.vscode-resource.vscode-cdn.net/d%3A/YK_pyfile/SI/SI_Proc/~/.conda/envs/SI/spikeinterface/src/spikeinterface/sorters/runsorter.py:216) return run_sorter_local(**common_kwargs)

File ~\.conda\envs\SI\spikeinterface\src\spikeinterface\sorters\runsorter.py:276, in run_sorter_local(sorter_name, recording, folder, remove_existing_folder, delete_output_folder, verbose, raise_error, with_output, output_folder, **sorter_params)
    [274](https://file+.vscode-resource.vscode-cdn.net/d%3A/YK_pyfile/SI/SI_Proc/~/.conda/envs/SI/spikeinterface/src/spikeinterface/sorters/runsorter.py:274) SorterClass.set_params_to_folder(recording, folder, sorter_params, verbose)
    [275](https://file+.vscode-resource.vscode-cdn.net/d%3A/YK_pyfile/SI/SI_Proc/~/.conda/envs/SI/spikeinterface/src/spikeinterface/sorters/runsorter.py:275) SorterClass.setup_recording(recording, folder, verbose=verbose)
--> [276](https://file+.vscode-resource.vscode-cdn.net/d%3A/YK_pyfile/SI/SI_Proc/~/.conda/envs/SI/spikeinterface/src/spikeinterface/sorters/runsorter.py:276) SorterClass.run_from_folder(folder, raise_error, verbose)
    [277](https://file+.vscode-resource.vscode-cdn.net/d%3A/YK_pyfile/SI/SI_Proc/~/.conda/envs/SI/spikeinterface/src/spikeinterface/sorters/runsorter.py:277) if with_output:
    [278](https://file+.vscode-resource.vscode-cdn.net/d%3A/YK_pyfile/SI/SI_Proc/~/.conda/envs/SI/spikeinterface/src/spikeinterface/sorters/runsorter.py:278)     sorting = SorterClass.get_result_from_folder(folder, register_recording=True, sorting_info=True)

File ~\.conda\envs\SI\spikeinterface\src\spikeinterface\sorters\basesorter.py:301, in BaseSorter.run_from_folder(cls, output_folder, raise_error, verbose)
    [298](https://file+.vscode-resource.vscode-cdn.net/d%3A/YK_pyfile/SI/SI_Proc/~/.conda/envs/SI/spikeinterface/src/spikeinterface/sorters/basesorter.py:298)         print(f"{sorter_name} run time {run_time:0.2f}s")
    [300](https://file+.vscode-resource.vscode-cdn.net/d%3A/YK_pyfile/SI/SI_Proc/~/.conda/envs/SI/spikeinterface/src/spikeinterface/sorters/basesorter.py:300) if has_error and raise_error:
--> [301](https://file+.vscode-resource.vscode-cdn.net/d%3A/YK_pyfile/SI/SI_Proc/~/.conda/envs/SI/spikeinterface/src/spikeinterface/sorters/basesorter.py:301)     raise SpikeSortingError(
    [302](https://file+.vscode-resource.vscode-cdn.net/d%3A/YK_pyfile/SI/SI_Proc/~/.conda/envs/SI/spikeinterface/src/spikeinterface/sorters/basesorter.py:302)         f"Spike sorting error trace:\n{error_log_to_display}\n"
    [303](https://file+.vscode-resource.vscode-cdn.net/d%3A/YK_pyfile/SI/SI_Proc/~/.conda/envs/SI/spikeinterface/src/spikeinterface/sorters/basesorter.py:303)         f"Spike sorting failed. You can inspect the runtime trace in {output_folder}/spikeinterface_log.json."
    [304](https://file+.vscode-resource.vscode-cdn.net/d%3A/YK_pyfile/SI/SI_Proc/~/.conda/envs/SI/spikeinterface/src/spikeinterface/sorters/basesorter.py:304)     )
    [306](https://file+.vscode-resource.vscode-cdn.net/d%3A/YK_pyfile/SI/SI_Proc/~/.conda/envs/SI/spikeinterface/src/spikeinterface/sorters/basesorter.py:306) return run_time

SpikeSortingError: Spike sorting error trace:
Traceback (most recent call last):
  File "C:\Users\Lenovo\.conda\envs\SI\spikeinterface\src\spikeinterface\sorters\basesorter.py", line 261, in run_from_folder
    SorterClass._run_from_folder(sorter_output_folder, sorter_params, verbose)
  File "C:\Users\Lenovo\.conda\envs\SI\spikeinterface\src\spikeinterface\sorters\external\kilosort4.py", line 273, in _run_from_folder
    st, tF, _, _ = detect_spikes(ops, device, bfile, tic0=tic0, progress_bar=progress_bar)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\Lenovo\.conda\envs\SI\Lib\site-packages\kilosort\run_kilosort.py", line 611, in detect_spikes
    st, tF, ops = template_matching.extract(ops, bfile, Wall3, device=device,
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\Lenovo\.conda\envs\SI\Lib\site-packages\kilosort\template_matching.py", line 26, in extract
    ctc = prepare_matching(ops, U)
          ^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\Lenovo\.conda\envs\SI\Lib\site-packages\kilosort\template_matching.py", line 108, in prepare_matching
    ctc = torch.einsum('ijkm, kml -> ijl', UtU, WtW)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\Lenovo\.conda\envs\SI\Lib\site-packages\torch\functional.py", line 380, in einsum
    return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 53.24 GiB. GPU 0 has a total capacity of 10.00 GiB of which 0 bytes is free. Of the allocated memory 31.18 GiB is allocated by PyTorch, and 43.14 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
alejoe91 commented 3 months ago

The difference with the spikeinterface version is that we use the kilosort RecordingExtractorAsArray wrapper, so I suspect it's a problem with it.

What we can do is support an option to run from a binary file (providing the path to the binary file) when the recording is already binary.

Hobart10 commented 3 months ago

Hi Alessio, thank you for the reply! So with that option, would the preprocessing and property-setting steps be done by Kilosort instead of SpikeInterface?

alejoe91 commented 3 months ago

No the difference is that Kilosort has two modes of parsing the recording to its pipeline:

  1. from a binary file
  2. using an object that directly wraps a SpikeInterface Recording (RecordingExtractorAsArray)

The KS GUI uses option 1, so it needs a binary file. The SpikeInterface wrapper uses option 2, but we could switch to option 1 when the recording is already binary (so it would behave exactly like running from the KS GUI).

@JoeZiminski @zm711 @chrishalcrow what do you think?

zm711 commented 3 months ago

I think the issue is that although KS4 accepts SpikeInterface objects (method 2), it doesn't really have optimizations for using our objects, so the experience is probably less than ideal. Their docs note that they accept our objects but are not optimized for them, so they know they are not using them as efficiently as they could.

This makes me lean toward just using the binary format (or we could add a flag, use_binary_file), which is what the KS4 code is optimized for. So I would vote for option 1, or for giving the user the choice between 1 and 2 via a flag that defaults to 1.

chrishalcrow commented 3 months ago

If the user chooses use_binary_file and no binary file is available, should we generate one (à la MountainSort)?

zm711 commented 3 months ago

Yep, that's what I was thinking. I believe KS2, 2.5, and 3 also do that.

alejoe91 commented 3 months ago

Deal! I actually always wanted to do that, but have been too lazy :P

samuelgarcia commented 3 months ago

Cool, but I think I would keep the default with no binary copy, i.e. using the object API in KS4.

JoeZiminski commented 2 months ago

Sorry for the slow reply, this looks good! I think I'd also lean towards using the binary file by default for now, even if it means copying the recording, so that everyone is funnelled down the standard KS4 pipeline by default.

Using an SI wrapper to avoid this copy is infinitely nicer, but until we have the test framework to ensure it is performing well across CPU and GPU, it seems safer to use their default. It's great of them to provide this wrapper but it is not explicitly supported, so I would feel more comfortable funnelling spikeinterface users through their explicitly supported pipeline by default.

alejoe91 commented 2 months ago

@Hobart10 could you test again using the main branch and KS 4.0.16?

When you run the sorter, you should add this argument: use_binary_file=True

Hobart10 commented 2 months ago

Thank you for the updates! However, I still get the same error, which does not appear to be a PC or data-file issue. Also, running from SI is much slower than running from the KS GUI. Is KS still not running from the binary file?

Execution log: KS4.log, SI_OOM.log Example data (~15GB): bin file link, probe.zip

JoeZiminski commented 2 months ago

Hi @Hobart10, thanks for helping to get to the bottom of this.

Can you double-check: 1) that you definitely don't get the error when running Kilosort 4 through the GUI on your raw data file, i.e. that the sorting completes successfully; 2) save the SpikeInterface pre-processed recording to binary (recording.save(..., format="binary")), then load the saved binary (the .dat file in the saved recording folder) either through the KS4 GUI or their run_kilosort function (e.g. here). Basically, skip all SI wrapping and see whether it breaks on the SI-saved binary.

I have a colleague who is experiencing a similar error even when running the raw data directly in KS4. It would be good to find out whether this is a) a problem with the SpikeInterface wrapper, b) a native KS4 problem (in this test, make sure KS4 is definitely getting the binary and not defaulting to the SpikeInterface wrapper), or c) some strange interaction between the SpikeInterface preprocessing and KS4 (i.e. the error does not occur when running from the raw binary through KS4 but does occur on a SpikeInterface preprocessed/saved binary). Could you also post the script you are running?

Hobart10 commented 2 months ago

Thank you for the instructions! @alejoe91 @JoeZiminski

  1. Both recording.dat (generated when setting use_binary_file=True) and traces_cached_seg0.raw (no .bin file; saved with recording.save) contain only noise-like signals. Running either file in the KS GUI is extremely slow (>200 s/it) when extracting spikes using cluster waveforms. Could this be the actual reason for the OOM error in SI? Somehow setting the probe turns the recording into noise in my case.
  2. Running the original binary file directly in the KS GUI gives no errors. Running in SI with the probe set, via run_sorter or run_sorter_by_property, is slow and still hits the OOM error.

KS gui view of recording.dat: Kilosort4_dat

KS gui view of preprocessed recording: traces_cached_seg0.raw: Kilosort4_raw

KS gui view of original bin file: Kilosort4_oriBin

Necessary script:

# Reproduction script posted in this issue: preprocess a raw .bin recording with
# SpikeInterface, save it, then run Kilosort4 through the SpikeInterface wrapper.
# NOTE(review): this is a fragment -- folderPath, fullPath_bin, sr, nChan_logger and
# brainArea are defined elsewhere and are not shown here.
import spikeinterface.full as si
import probeinterface as pi
import numpy as np
import os
import matplotlib
# Select the Tk backend before pyplot is imported.
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
import warnings
# NOTE(review): this silences *all* warnings, including any SpikeInterface/probeinterface
# warnings that might point at the channel-layout problem suspected below.
warnings.simplefilter("ignore")
from pathlib import Path

# Params definition
# Leave 4 CPU cores free for other work.
# NOTE(review): os.cpu_count() may return None, and this value is <= 0 on machines with
# 4 or fewer cores -- consider max(1, (os.cpu_count() or 1) - 4).
n_jobs = os.cpu_count() - 4
global_job_kwargs = dict(n_jobs=n_jobs, chunk_duration="1s", progress_bar=True)
si.set_global_job_kwargs(**global_job_kwargs)
# Output folder for everything SpikeInterface writes (preprocessed data, sorter results).
# NOTE(review): os.mkdir fails if folderPath itself does not exist;
# Path(path_SIout).mkdir(parents=True, exist_ok=True) would be more robust.
path_SIout = folderPath / 'SIproc'
if not os.path.exists(path_SIout):
    os.mkdir(path_SIout)

## SI Recording object:
# NOTE(review): per the read_binary docs, time_axis=1 means the file stores all samples of
# one channel contiguously (channels x samples), while time_axis=0 means sample-interleaved
# (samples x channels). Kilosort's GUI reads this same file correctly, which suggests it is
# interleaved, i.e. time_axis=0. A wrong time_axis scrambles samples across channels, which
# would be consistent with the noise-like traces reported in this issue -- verify against
# the acquisition system's file format.
rec_raw = si.read_binary(file_paths=fullPath_bin, sampling_frequency=sr, num_channels=nChan_logger, dtype='int16', time_axis=1)
# Number of channels in the raw file (all logger channels, before the probe is attached).
nChan_raw = len(rec_raw.get_channel_ids())

# probe info
nChan_probe = 64
# Every probe channel is put in group 0, so run_sorter_by_property below runs a single job.
grouping = np.repeat([0], nChan_probe)
manufacturer = 'NeuroNexus'
probeName = 'V1x64-Edge-10mm-60-177-VC64'
# Presumably entry i is the raw-file channel index wired to probe contact i+1 (see the
# "#prb" comments) -- confirm against the headstage/adapter wiring.
chanMap = [49, 17, 51, 19, 53, 21, 55, 23,  #prb 1-8
           57, 25, 59, 27, 61, 29, 63, 31,  #prb 9-16
           46, 14, 44, 12, 42, 10, 40, 8 ,  #prb 17-24
           38, 6 , 36, 4 , 34, 2 , 32, 0 ,  #prb 25-32
           30, 62, 28, 60, 26, 58, 24, 56,  #prb 33-40
           22, 54, 20, 52, 18, 50, 16, 48,  #prb 41-48
           1 , 33, 3 , 35, 5 , 37, 7 , 39,  #prb 49-56
           9 , 41, 11, 43, 13, 45, 15, 47]  #prb 57-64
probe = pi.get_probe(manufacturer, probeName) # Probe object
probe.set_device_channel_indices(chanMap)
# NOTE(review): assumes nChan_logger == nChan_probe (64); the two property arrays below are
# sized with nChan_raw and nChan_probe respectively -- confirm these agree.
rec_raw.set_probe(probe, in_place=True)
# NOTE(review): the property *key* here is the value of brainArea, not the literal string
# "brainArea" -- confirm this is intended.
rec_raw.set_property(key=brainArea, values=[brainArea]*(nChan_raw))
rec_raw.set_property("group", grouping)

# NOTE: Test with non-grouping preproc
# Save preprocessed recording
folder_preproc = Path(path_SIout) / "prepRec_bin"
if (folder_preproc).is_dir():
    print('Loading pre-processed data from path.')
    # NOTE(review): folder_preproc is the *folder* written by rec_prep.save() in the else
    # branch, not a raw binary file, and this call reuses time_axis=1 and nChan_logger.
    # The commented-out si.load_extractor call is presumably the intended way to reload
    # the saved recording together with its probe and properties.
    rec_saved = si.read_binary(file_paths=folder_preproc, sampling_frequency=sr, num_channels=nChan_logger, dtype='int16', time_axis=1)
    # rec_saved = si.load_extractor(folder_preproc) #if already exist, load
else:
    # 300 Hz high-pass filter, then global median common reference.
    rec_filter   = si.highpass_filter(rec_raw, freq_min=300)
    # rec_filter = si.bandpass_filter(rec_raw, freq_min=300, freq_max=6000)
    rec_prep = si.common_reference(rec_filter, reference='global', operator='median')
    print('Saving pre-processed data.')
    rec_prep.save(format = "binary", folder=folder_preproc, **global_job_kwargs)
    # NOTE(review): this keeps the lazy (unsaved) preprocessing chain, so downstream steps
    # presumably recompute the filtering instead of reading the saved copy. The
    # commented-out line below uses the recording returned by save() -- verify which is
    # intended.
    rec_saved = rec_prep
    # rec_saved = rec_prep.save(folder=folder_preproc, **global_job_kwargs)

nChan = len(rec_saved.get_channel_ids())
# Integer sampling rate (re-binds sr, which was used above when reading the raw file).
sr = int(rec_saved.sampling_frequency)
# Kilosort4 parameter overrides passed through the SpikeInterface wrapper.
ks4_params_modSet = {'do_correction': False, 'nearest_templates': nChan,
                        # SI already applied high-pass + CAR above, so Kilosort's own CAR is off;
                        # batch_size = sr*4 is 4 seconds' worth of samples.
                        'do_CAR':False, 'nblocks': 0, 'batch_size': sr*4, 
                        'Th_universal': 7, 'Th_learned': 6, 'dmin':10, 'dminx': 20, 
                        # use_binary_file: have the wrapper give Kilosort a binary file rather
                        # than the in-memory RecordingExtractorAsArray wrapper.
                        'skip_kilosort_preprocessing': True, 'use_binary_file': True}
# run spike sorting on entire recording
# Splits the recording by the "group" property (all channels are group 0 here) and runs
# Kilosort4 once per group, writing results under SIproc/results_KS4.
sorting_KS4 = si.run_sorter_by_property(sorter_name='kilosort4', recording=rec_saved,  
                                        grouping_property='group', folder=Path(path_SIout) / 'results_KS4', 
                                        verbose=True, **ks4_params_modSet)
# print(sorting_KS4)
JoeZiminski commented 2 months ago

Hi @Hobart10, thanks for this. At a guess, something in the SI pipeline is getting messed up. It might be worth using si.plot_traces to inspect the data at each step of the pipeline and check that it looks okay. I am not sure where the problem is occurring, but it seems that the input data quality is causing the CUDA problems in KS4. I guess it is just not expecting such data.

alejoe91 commented 2 months ago

@Hobart10 can you try with read_binary(..., time_axis=0)? Why are you using time_axis=1 in the first place?