SpikeInterface / spikeinterface

A Python-based module for creating flexible and robust spike sorting pipelines.
https://spikeinterface.readthedocs.io
MIT License

Sorting takes extremely long when sorting a four shank probe by property #2625

Closed guidomeijer closed 6 months ago

guidomeijer commented 7 months ago

Hi there! I'm trying to sort a four-shank Neuropixels recording with Kilosort4, but when I use run_sorter_by_property it takes 8 hours to sort one of the four shanks and an estimated 160 hours to recompute the spike templates. When I split the recording by shank and run run_sorter on a single shank it runs normally. I'm using a local installation of Kilosort4 which has access to the GPU (I checked). Any idea what might be going on? I had a related issue on the Kilosort4 GitHub, but it doesn't seem to be a Kilosort issue: https://github.com/MouseLand/Kilosort/issues/631

Setup: NVIDIA RTX 4080, 64 GB RAM, Ubuntu 20.04, kilosort 4.0.2, spikeinterface 0.100.2

Code:
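
For readability, the snippet below assumes imports along these lines (a best-guess reconstruction from the aliases used in the excerpt, not part of the original post):

import os
from os.path import join, split
from datetime import datetime
import numpy as np
import spikeinterface.full as si
import spikeinterface.preprocessing as spre
from spikeinterface.sorters import run_sorter_by_property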

            # Apply high-pass filter
            print('Applying high-pass filter.. ')
            rec_filtered = spre.highpass_filter(rec)

            # Correct for inter-sample phase shift
            print('Correcting for phase shift.. ')
            rec_shifted = spre.phase_shift(rec_filtered)

            # Detect and interpolate over bad channels
            print('Detecting and interpolating over bad channels.. ')
            bad_channel_ids, all_channels = spre.detect_bad_channels(rec_shifted)

            # If there are too many bad channels, skip the interpolation step
            prec_bad_ch = np.sum(all_channels == 'noise') / all_channels.shape[0]
            if prec_bad_ch < (1/3):
                rec_interpolated = spre.interpolate_bad_channels(rec_shifted, bad_channel_ids)
                print(f'{np.sum(all_channels == "noise")} ({prec_bad_ch*100:.0f}%) bad channels')
            else:
                rec_interpolated = rec_shifted
                print(f'{np.sum(all_channels == "noise")} ({prec_bad_ch*100:.0f}%) bad channels,',
                      'skipping the interpolation step')

            # If there are multiple shanks, do destriping per shank
            print('Destriping.. ')
            if np.unique(rec_interpolated.get_property('group')).shape[0] > 1:

                # Loop over shanks and do preprocessing per shank
                rec_split = rec_interpolated.split_by(property='group')
                rec_destriped = []
                for sh in range(len(rec_split)):
                    rec_destriped.append(spre.highpass_spatial_filter(rec_split[sh]))

                # Merge back together
                rec_final = si.aggregate_channels(rec_destriped)

                # Run spike sorting per shank
                try:
                    print(f'\nStarting {split(probe_path)[-1]} spike sorting at {datetime.now().strftime("%H:%M")}')
                    sort = run_sorter_by_property(
                        sorter_name=settings_dict['SPIKE_SORTER'],
                        recording=rec_final,
                        grouping_property='group',
                        working_folder=join(probe_path, settings_dict['SPIKE_SORTER'] + id_str),
                        verbose=True,
                        docker_image=settings_dict['USE_DOCKER'],
                        **sorter_params)

                except Exception as err:
                    # Log error to disk
                    print(err)
                    logf = open(os.path.join(probe_path, 'error_log.txt'), 'w')
                    logf.write(str(err))
                    logf.close()

                    # Continue with next recording
                    continue
zm711 commented 7 months ago

This is good info to have. I thought maybe this was a Windows-only issue, but maybe there is a global problem with the KS4 wrapper @alejoe91?

@guidomeijer (for background also see #2569), I was doing some KS4 testing and finding similar problems. But since all my data is multishank I had just set this aside. I'm busy this week, but maybe I'll pick this back up and work on it some more unless Alessio figures it out first.

alejoe91 commented 7 months ago

Hi @guidomeijer

I'll take a look! Maybe the GPU capability is not propagated correctly when running by property, which makes KS4 run on CPU. That's the only thing I can think of.
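
One generic way to rule out a silent CPU fallback (a hedged check, independent of the SpikeInterface wrapper) is to confirm that torch, which Kilosort 4 runs on, can see the GPU in the environment the sorter uses:

import torch

# If this prints False, Kilosort 4 will fall back to CPU
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device:", torch.cuda.get_device_name(0))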

alejoe91 commented 7 months ago

@zm711 run_sorter_by_property is basically a map onto run_sorter_jobs, so we probably need to focus there. I'll try to reproduce the issue locally over the next few days.
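
Conceptually, running by property boils down to something like the sketch below (a simplified illustration, not the actual implementation; the real code builds a job list, dispatches it through the jobs machinery, and aggregates the per-group sortings):

from spikeinterface.sorters import run_sorter

# "recording" is assumed to carry a 'group' property per shank, as in the original post
sub_recordings = recording.split_by("group")
sortings = {}
for group, sub_rec in sub_recordings.items():
    sortings[group] = run_sorter(
        sorter_name="kilosort4",
        recording=sub_rec,
        output_folder=f"working_folder/{group}",
    )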

guidomeijer commented 7 months ago

I noticed that it does allocate some GPU memory (500-700 MB), but it's much lower than when you use run_sorter.

zm711 commented 7 months ago

How much memory do you have? Is it maybe doing something where it is dividing the memory by the number of shanks?

Or how many n_jobs (maybe the memory is divided by n_jobs)?

guidomeijer commented 7 months ago

The GPU has 16 GB of memory. I didn't specify the n_jobs parameter anywhere.

zm711 commented 7 months ago

The default is to use all available cores, so if you don't change it, it will use all CPUs. So how many CPU cores do you have?
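
For reference, a minimal example of pinning the parallelization explicitly instead of relying on the default (set_global_job_kwargs applies to all chunked operations; the values below are placeholders):

import spikeinterface.full as si

# Limit chunked operations to 8 workers with 1-second chunks (placeholder values)
si.set_global_job_kwargs(n_jobs=8, chunk_duration="1s")
print(si.get_global_job_kwargs())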

guidomeijer commented 7 months ago

It's an AMD Ryzen Threadripper with 16 cores. But when I look at the usage, it's only using one core at a time.

zm711 commented 7 months ago

Maybe there is some sort of interaction: 16 GB of GPU RAM divided over 16 cores would give 1 GB or less per core, and if it is only using one core then maybe that is part of the issue. I'll try to read into it a bit more. When you do run_sorter, how many cores and how much GPU memory are engaged?

guidomeijer commented 7 months ago

Still only one core at a time and 2 GB of GPU memory. But it processes a recording very fast (~2 hours).

zm711 commented 7 months ago

I've been digging back into how Kilosort uses get_traces, and I can't see why it would make a difference whether the ChannelSliceRecording is created before run_sorter or inside run_sorter_by_property. I've run a bunch of tests, and at each step inside our wrapper the reported device is 'cuda', so it is not switching to 'cpu' between steps. @alejoe91, I'm happy to test anything out, but I'm really not sure for this one.

zm711 commented 7 months ago

@alejoe91,

Actually, thinking about this a bit more, I think this could be related to the Mountainsort5 caching issue. When I do run_sorter_by_property, Mountainsort5 still writes the whole recording rather than the sub-recording for its caching. I think the same thing could be happening here: even though we give it a sub-recording, the RecordingExtractorAsArray in Kilosort might be looking at the whole recording rather than the sub-recording. Any ideas why that might be? (Again, this makes sense to me in that 2 hours/shank = ~8 hours for 4 shanks.)

Just to elaborate a bit more: let's say I have a 10 GB file and I'm splitting it in half. When I do run_sorter_by_property for KS2 or KS3, the recording.dat is ~5 GB. Doing the exact same thing with MS5 leads to a recording.dat of 10 GB. So if we stay in Python, we either write the whole recording erroneously or we write the half-sized data twice.
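
A quick way to sanity-check whether each job really only sees its own channels (a diagnostic sketch added here for illustration, reusing rec_final and the 'group' property from the earlier snippet):

# Each sub-recording should report only its own channel count, with the full duration
rec_split = rec_final.split_by("group")
for group, sub_rec in rec_split.items():
    print(group, sub_rec.get_num_channels(), "channels,",
          sub_rec.get_total_duration(), "s")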

alejoe91 commented 7 months ago

Thanks @zm711

This could be a lead. Each run sorter job should get a channel slice object, so it shouldn't "know" that there are more channels. Maybe there is indeed something wrong in the splitting and job distribution! I'll take a look

zm711 commented 7 months ago

False start on my part. I added a bunch of prints to check the status, and the reason MS5 doubles in size is that the data is cast to float from the recording's uint dtype. I'll keep searching in KS4 when I have free moments, but I'm still not sure.
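
The doubling is consistent with a cast from a 2-byte integer dtype to 4-byte float32, e.g.:

import numpy as np

# int16/uint16 samples take 2 bytes, float32 takes 4, so caching as float32 doubles the file size
print(np.dtype("uint16").itemsize)   # 2
print(np.dtype("float32").itemsize)  # 4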

BovenE commented 7 months ago

Hi everyone, does anyone have any new information regarding this issue? I have been having similar issues using Kilosort4 through SpikeInterface on a multi-shank probe. Thank you!

guidomeijer commented 7 months ago

Kilosort 4.0.4 should now handle multiple shanks (https://github.com/MouseLand/Kilosort/issues/641#issuecomment-2053769305), so you can also just sort the whole recording in one go instead of sorting by property. I'm testing it now.

guidomeijer commented 7 months ago

Update: using Kilosort 4.0.4 still does not work for me when sorting 4-shank probes. It took 16 hours just to complete the first step, and there are many more steps to go.

zm711 commented 7 months ago

A couple of questions. When you run KS4 directly, do you use a binary file? If so, how was that binary file written?

When sorting are you doing it locally or some sort of network/server mount?

guidomeijer commented 7 months ago

I use the binary file that comes directly out of SpikeGLX. I'm doing the sorting locally on my computer.

zm711 commented 7 months ago

The part that I'm struggling with is that at the beginning of the issue it's said that splitting the recording yourself and running run_sorter on the pre-split recordings is about the same speed as running KS4 natively, while run_sorter_by_property is slow. So I was wondering whether using a straight binary file vs. a SpikeInterface object is the slow part. My logic is that with a straight binary file KS4 just reads from the binary file, but with a SpikeInterface array it calls a class they made, which calls a function that calls our class, which calls another function, so we are running through an extra 3-4 layers of Python in that case. But if that were really slowing things down, I would expect a similar slowdown for run_sorter and run_sorter_by_property. Could you re-confirm the speeds of your test conditions? Something like:
KS4 native: ~1 hour/shank
KS4 run_sorter (pre-split): ...
KS4 run_sorter_by_property: ...

gkBCCN commented 7 months ago

I found this thread because I am having the same issue. If I run kilosort directly on a single probe's .dat file, the run time is about the same as the recording time. However, if I have both probes loaded in SpikeInterface, each probe's sorting takes about 4x longer. When I open a single probe's results in Phy, I can actually see 2 probes in ProbeView.
(screenshot: Phy ProbeView showing two probes)

So I'm inclined to think that both probes are somehow present. Maybe this has a weird interaction with Kilosort's native handling of multiple probes as @guidomeijer mentioned.

I see this issue on both Ubuntu 22.04 and Windows 10.

I've also had problems with running multiple jobs on Windows (e.g., when running the analyzer). Not sure if that plays a role here, but this is in reference to @alejoe91's comment:

Thanks @zm711

This could be a lead. Each run sorter job should get a channel slice object, so it shouldn't "know" that there are more channels. Maybe there is indeed something wrong in the splitting and job distribution! I'll take a look

In the meantime, am I to understand that using si.run_sorter is a working alternative?

zm711 commented 7 months ago

@gkBCCN, what recording are you using? What probe type? Seeing 2 probes is weird and we could look into that more.

gkBCCN commented 7 months ago

Hey @zm711. I'm using the SpikeGadgets .rec format that I added recently, and my file has 2 Neuropixels 1.0 probes.

guidomeijer commented 7 months ago

I think I found the issue in my case. If I split the probe into shanks to run the destriping per shank and then merge the result back together into one recording, everything after that takes an insane amount of time. Even just plotting the traces doesn't work. If I skip this step, or do the destriping on the whole recording without splitting it up, it runs fine.

This is the code that causes the slow-down:

            # If there are multiple shanks, do destriping per shank
            print('Destriping.. ')
            if np.unique(rec_interpolated.get_property('group')).shape[0] > 1:

                # Loop over shanks and do preprocessing per shank
                rec_split = rec_interpolated.split_by(property='group')
                rec_destriped = []
                for sh in range(len(rec_split)):
                    rec_destriped.append(spre.highpass_spatial_filter(rec_split[sh]))

                # Merge back together
                rec_final = si.aggregate_channels(rec_destriped)

Bear in mind that when I say it runs fine, I am talking about using run_sorter. I think run_sorter_by_property is still very slow, but I haven't checked this recently because with Kilosort 4.0.4 it's not necessary anymore.
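
For reference, the variant that runs fine in this report is simply applying the spatial filter to the whole recording in one pass (a minimal sketch, reusing rec_interpolated from the snippet above):

import spikeinterface.preprocessing as spre

# Destripe the full recording in one pass instead of per shank + aggregate_channels
rec_destriped = spre.highpass_spatial_filter(rec_interpolated)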

alejoe91 commented 7 months ago

Hi guys,

I'm also looking into this and trying to reproduce the performance issue. I'm quite convinced, as @guidomeijer said, that this is mainly due to preprocessing.

Here's a simple benchmark using simulated data (artificially split into 4 groups):

import time
import numpy as np
import spikeinterface.full as si
import spikeinterface.sorters as ss

# create simulated recording
rec, sort = si.generate_ground_truth_recording(num_channels=128, durations=[150], num_units=60)
num_groups = 4
channels_per_group = rec.get_num_channels() // num_groups
channel_groups = []
for i in range(num_groups):
    channel_groups.extend([i] * channels_per_group)
rec.set_channel_groups(channel_groups)

print("\n\nRUNNING ALL\n\n")
t_start = time.perf_counter()
sorting_ks4_all = ss.run_sorter("kilosort4", rec, output_folder="ks4_all")
t_stop = time.perf_counter()
elapsed_all = np.round(t_stop - t_start, 2)
print(f"Elapsed time all: {elapsed_all} s")

print("\n\nRUNNING LOOP\n\n")
t_start = time.perf_counter()
rec_dict = rec.split_by("group")
for g, rec_g in rec_dict.items():
    sorting_ks4_loop = ss.run_sorter("kilosort4", rec_g, output_folder=f"ks4_loop{g}")
t_stop = time.perf_counter()
elapsed_loop = np.round(t_stop - t_start, 2)
print(f"Elapsed time loop: {elapsed_loop} s")

print("\n\nRUNNING BY PROPERTY\n\n")
t_start = time.perf_counter()
sorting_ks4_prop = ss.run_sorter_by_property("kilosort4", rec, grouping_property="group", working_folder="ks4_prop")
t_stop = time.perf_counter()
elapsed_prop = np.round(t_stop - t_start, 2)
print(f"Elapsed time by property: {elapsed_prop} s")

And these are the printed elapsed times:

RUNNING ALL
Elapsed time all: 51.15 s

RUNNING LOOP
Elapsed time loop: 132.78 s

RUNNING BY PROPERTY
Elapsed time by property: 130.43 s

So there is an overhead in running by group, but I think it's due to the overhead of running KS four times rather than once.

@guidomeijer I'll look into why the highpass spatial filter behaves so differently if you split by group!

alejoe91 commented 7 months ago

@gkBCCN can you share your code?

gkBCCN commented 7 months ago

Do you mean the preprocessing by KS4 or by SpikeInterface? All steps were slower in my case (templates, clustering, etc.).

alejoe91 commented 7 months ago

Do you mean the preprocessing by KS4 or by SpikeInterface? All steps were slower in my case (templates, clustering, etc.).

The SpikeInterface code up to the sorting run

gkBCCN commented 7 months ago

I just followed the tutorial:

sorting = si.run_sorter_by_property(
    sorter_name=sorter_algorithm,
    recording=recording,
    grouping_property='group',
    working_folder=sorting_folder
)

where "recording" is raw_dat = read_spikegadgets(rec_file), which is then high pass filtered and common referenced.

BTW: there's a typo in the example on https://github.com/SpikeInterface/spikeinterface/blob/main/doc/how_to/process_by_channel_group.rst, Option 1: Manual splitting:

split_preprocessed_recording = preprocessed_recording.split_by("group")

sortings = {}
for group, sub_recording in split_preprocessed_recording.items():
    sorting = run_sorter(
        sorter_name='kilosort2',
        recording=split_preprocessed_recording,
        output_folder=f"folder_KS2_group{group}"
    )
    sortings[group] = sorting

It should be sub_recording instead of split_preprocessed_recording.
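
For clarity, the corrected version of that documentation loop (with sub_recording passed to run_sorter, as noted above; names are the ones used in the docs example) would read:

split_preprocessed_recording = preprocessed_recording.split_by("group")

sortings = {}
for group, sub_recording in split_preprocessed_recording.items():
    # Pass the per-group sub-recording, not the full split dictionary
    sorting = run_sorter(
        sorter_name='kilosort2',
        recording=sub_recording,
        output_folder=f"folder_KS2_group{group}"
    )
    sortings[group] = sorting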

gkBCCN commented 7 months ago

OK, if I use Option 1 - Manual Splitting (as described in my previous comment), which is the si.run_sorter option mentioned in the original post, the sorting completes in roughly the same time as the GUI, albeit a bit slower. Thanks for that, @guidomeijer!

gkBCCN commented 7 months ago

Hold the phone. Even if I use run_sorter on each probe separately, the binary file that is created by export_to_phy is double the size, as @zm711 noted. That could explain why phy sees two probes. After the sorting is finished, I loop over both probe numbers and load the results, then create an analyzer using the entire recording.

for probe_num in range(1, len(recording.get_probes())+1):
    sorting = si.read_sorter_folder(sorting_folder / f'{probe_num-1}')
    analyzer = si.create_sorting_analyzer(sorting, recording, sparse=True, format="memory")

My bad. I guess I should use sub_recording as I did in the sorting loop, correct?

alejoe91 commented 7 months ago

Yes, if you want one Phy folder for each probe. It might also be a good approach to export everything into one Phy folder.
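
A hedged sketch of the per-probe analyzer loop (assuming the 'group' property separates the two probes here; recording and sorting_folder are the objects from the comment above):

import spikeinterface.full as si

# Pair each per-probe sorting with its own sub-recording instead of the full recording
sub_recordings = recording.split_by("group")
for group, sub_rec in sub_recordings.items():
    sorting = si.read_sorter_folder(sorting_folder / f"{group}")
    analyzer = si.create_sorting_analyzer(sorting, sub_rec, sparse=True, format="memory")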

alejoe91 commented 7 months ago

Hi guys, I found and fixed the issue!!!

@guidomeijer the problem was in aggregate_channels (see PR #2736). get_traces was grabbing one channel at a time; since the highpass spatial filter uses all channels for processing, this made it extremely slow.
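
An illustrative way to see this outside the sorter (a hedged micro-benchmark reusing rec_interpolated and rec_final from the earlier snippet; the frame counts are placeholders):

import time
import spikeinterface.preprocessing as spre

# Compare pulling one chunk of traces through the split + aggregate_channels chain (rec_final)
# vs. the same filter applied to the whole recording without aggregation
rec_whole = spre.highpass_spatial_filter(rec_interpolated)

t0 = time.perf_counter()
_ = rec_whole.get_traces(start_frame=0, end_frame=30_000)
t1 = time.perf_counter()
_ = rec_final.get_traces(start_frame=0, end_frame=30_000)
t2 = time.perf_counter()
print(f"whole-recording filter: {t1 - t0:.3f} s, split + aggregate_channels: {t2 - t1:.3f} s")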

The new implementation is 10 times faster.

@guidomeijer @gkBCCN Can you try the aggregate_channels + run_sorter_by_property from the PR?

gkBCCN commented 7 months ago

Hi @alejoe91, I created a new conda environment with your changes, ran all three versions, and this is what I got:

Recording time: 00:06:26.6 h:m:s

Kilosort GUI: Total = 324.52 s = 00:05:25 h:m:s
    preprocessing = 0.64 s
    drift = 98.42 s
    extracting spikes using templates = 91.09 s
    1st clustering = 47.34 s
    extracting spikes using cluster waveforms = 30.10 s
    final clustering = 51.01 s

Single probe (run_sorter): Total = 469.16 s = 00:07:49 h:m:s
    preprocessing = 3.21 s
    drift = 119.10 s
    extracting spikes using templates = 106.68 s
    1st clustering = 105.24 s
    extracting spikes using cluster waveforms = 50.54 s
    final clustering = 78.57 s

Aggregate (run_sorter_by_property): Total = 1265.55 s = 00:21:06 h:m:s
    preprocessing = 19.15 s
    drift = 398.12 s
    extracting spikes using templates = 372.54 s
    1st clustering = 94.45 s
    extracting spikes using cluster waveforms = 299.04 s
    final clustering = 76.56 s

zm711 commented 7 months ago

Based on my reading, it looks like single probe vs. aggregate is now scaling appropriately. I think the overhead of using the wrapper means that running this through SI is expected to be slightly slower, but you get the benefit of sorting shanks individually rather than relying on KS4 to make multiple shanks work (like with the stacking trick that was previously necessary). That's my interpretation at least.

gkBCCN commented 7 months ago

But the aggregate time is not for both probes. These are all processing times for a single probe.

zm711 commented 7 months ago

Sorry I misunderstood that. I thought it was for multiple probes!

zm711 commented 7 months ago

Do you know what your time for run_sorter_by_property is without the PR? Is the ~1200 seconds a 10x speedup from before? There might be something else that is still slowing things down, but did the PR cause an improvement over the baseline?

alejoe91 commented 7 months ago

@gkBCCN how are you using the aggregate function to run with a single probe?

gkBCCN commented 7 months ago

@alejoe91 I just interrupted the kernel after one probe finished. But KS reports the times as it's running, which I wrote above. @zm711 I'm currently running the non-PR run_sorter_by_property. It should be done soon-ish...

gkBCCN commented 7 months ago

@zm711 non-PR Aggregate (run_sorter_by_property): Total = 1227.03 s = 00:20:27 h:m:s
    preprocessing = 13.31 s
    drift = 381.12 s
    extracting spikes using templates = 359.76 s
    1st clustering = 103.20 s
    extracting spikes using cluster waveforms = 299.05 s
    final clustering = 78.12 s

alejoe91 commented 7 months ago

So it's the same. I guess that's because you're using CMR, which is way faster than highpass_spatial_filter. I'll run some more tests on my side.

gkBCCN commented 7 months ago

Just to be clear, I'm using SpikeGadgets .rec files as input.

alejoe91 commented 7 months ago

Yep, but it's a memmap file so I don't think that makes a difference

zm711 commented 7 months ago

Honestly, at this point it might make sense for us to try a profiler so we can see which step is taking so long. If @gkBCCN knows how to use a profiler, you could try it. I'm working on some other analysis, but next week I can try profiling a call to Kilosort4 on some of my data that I use with run_sorter_by_property. Then hopefully we can see which class is causing the bottleneck on our side.
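
For anyone who wants to try, a minimal cProfile sketch around a single sorter call (the output path and rec_g are placeholders, not from the thread):

import cProfile
import pstats

import spikeinterface.sorters as ss

# rec_g is assumed to be one already-split sub-recording
cProfile.runctx(
    "ss.run_sorter('kilosort4', rec_g, output_folder='ks4_profile')",
    globals(), locals(),
    "ks4_profile.prof",
)
# Show the 30 slowest calls by cumulative time
pstats.Stats("ks4_profile.prof").sort_stats("cumulative").print_stats(30)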

alejoe91 commented 7 months ago

@gkBCCN how long does it take to run both probes in KS directly or with the run_sorter_by_property?

Note that an overhead is expected for sure, because of all the machinery in place to initialize and transfer data back and forth to the GPU...

gkBCCN commented 7 months ago

In my experience so far it always takes twice as long for a second probe when running run_sorter_by_property. That will necessarily be the case when running KS directly, because I run it twice, each time on a separate .dat file, but both have the same size.

alejoe91 commented 7 months ago

Anyway, @guidomeijer I tested again and your problem with destriping + aggregate_channels should be solved.

Here are some run times on a 384-channel recording:

Again, the overhead of running multiple groups is expected, but IMO it will give better results, especially when applying probe-specific preprocessing as in @guidomeijer's example, and for drift correction!

guidomeijer commented 6 months ago

Yes it's solved, thanks!