MouseLand / Kilosort

Fast spike sorting with drift correction for up to a thousand channels
https://kilosort.readthedocs.io/en/latest/
GNU General Public License v3.0

Poor performance in kilosort4 for rectangular and hexagonal arrays. #663

Closed mikemanookin closed 6 months ago

mikemanookin commented 7 months ago

Describe the issue:

My lab is using multielectrode arrays to record from the retina. These arrays (Litke arrays from UC Santa Cruz) are either rectangular (1 x 2 mm; 60 um pitch) or hexagonal (30 um or 120 um pitch). The recordings I've been testing produced ~1000 good units with kilosort2.5 (a good experiment for us), but I'm only getting around 130 good units with kilosort4. I have read your latest paper and have played with several different parameters, but nothing seems to improve things. I am particularly confused about how to set the 'dmin', 'dminx', and 'min_template_size' parameters for these rectangular or hexagonal arrays. If I set 'dminx' to 60, for example, I get 135 good units, but if I set it to 120, I run out of memory on my GPU (48 GB)... Any guidance you could give would be greatly appreciated.

This is the python code snippet I've been using to test kilosort4:

import torch
from kilosort import run_kilosort, io

settings = {'n_chan_bin': 512,
            'fs': 20000,
            'batch_size': 60000,
            'Th_universal': 9,
            'Th_learned': 8,
            'whitening_range': 32, # Number of nearby channels to include in the whitening
            'dmin': 90, # Check this...
            'dminx': 90, # Check this...
            'min_template_size': 10, # Check this...
            'n_pcs': 3, # default is 6, we used 3 before
            'Th_single_ch': 4.5} # Default is 6... we used 4.5 before

probe = io.load_probe('/home/mike/Documents/git_repos/manookin-lab/MEA/src/pipeline_utilities/kilosort/LITKE_512_ARRAY.mat')

ops, st, clu, tF, Wall, similar_templates, is_ref, est_contam_rate = \
    run_kilosort(settings=settings,
                 probe=probe,
                 filename='/data/data/sorted/20240401C/chunk1.bin',
                 results_dir='/data/data/sorted/20240401C/chunk1/kilosort4/',
                 data_dtype='int16',  # Check the data type
                 device=torch.device('cuda:1'),
                 invert_sign=False,  # Check this; invert the sign of the data as expected by kilosort4 (was False)
                 do_CAR=True)

jacobpennington commented 7 months ago

@mikemanookin I'm guessing this is related to some other issues we've found with multi-shank and 2D MEA probes. Am I reading your description correctly, that each channel you're recording from is 30 to 120um away from all other channels?

mikemanookin commented 7 months ago

Yes, that is correct. I am happy to send the probe file with the geometry, but for the rectangular array each electrode is 60 um apart from its nearest neighbor on a grid that is 1x2 mm.

jacobpennington commented 7 months ago

Okay, then yeah this is related to a known issue with how we've been determining template placement and grouping. The fix is working well on multi-shank probes so far, but still needs more testing for 2D grids. If you wouldn't mind sharing the probe file(s), that would help me test more cases.

mikemanookin commented 7 months ago

Of course. Here are the probe files for the three arrays I use... Thanks so much for your help.

Best, Mike

Archive.zip

jacobpennington commented 6 months ago

Thanks, this was helpful! The changes I made to address this are merged now and live on PyPI as version 4.0.4. Handling 2D arrays automatically is still a work in progress, but you should be able to sort effectively now by setting the new x_centers parameter. I would try x_centers = 10, with the goal of grouping templates into ~200 micron sections horizontally. If performance seems exceptionally slow, try increasing that.

Please let us know if this works for you!

mikemanookin commented 6 months ago

Thank you for doing that, Jacob. I tried running kilosort with the following parameters, but it ran out of memory. Do you have any advice on the parameters I should use for this (60 um pitch) array? I will test it again with different parameters, but some guidance on this would be very helpful. Thanks in advance.

These are the parameters I used.

settings = {'n_chan_bin': 512,
            'fs': 20000,
            'batch_size': 60000,
            'Th_universal': 9,
            'Th_learned': 8,
            'whitening_range': 32, # Number of nearby channels to include in the whitening
            'dmin': 120, # Check this...
            'dminx': 120, # Check this...
            'min_template_size': 30, # Check this...
            'n_pcs': 6, # default is 6, we used 3 before
            'Th_single_ch': 4.5, # Default is 6... we used 4.5 before
            'x_centers': 10 # Number of x-positions to use when determining center points
            } 

This is the stack trace:

Final clustering
  0%|                                                                                                                                 | 0/4 [4:49:39<?, ?it/s]
Traceback (most recent call last):
  File "<stdin>", line 2, in <module>
  File "/home/mike/Documents/git_repos/Kilosort4/Kilosort/kilosort/run_kilosort.py", line 143, in run_kilosort
    clu, Wall = cluster_spikes(st, tF, ops, device, bfile, tic0=tic0,
  File "/home/mike/Documents/git_repos/Kilosort4/Kilosort/kilosort/run_kilosort.py", line 419, in cluster_spikes
    clu, Wall = clustering_qr.run(ops, st, tF,  mode = 'template', device=device,
  File "/home/mike/Documents/git_repos/Kilosort4/Kilosort/kilosort/clustering_qr.py", line 349, in run
    iclust, iclust0, M, iclust_init = cluster(Xd, nskip=nskip, lam=1,
  File "/home/mike/Documents/git_repos/Kilosort4/Kilosort/kilosort/clustering_qr.py", line 131, in cluster
    iclust_init =  kmeans_plusplus(Xg, niter = nclust, seed = seed, device=device)
  File "/home/mike/Documents/git_repos/Kilosort4/Kilosort/kilosort/clustering_qr.py", line 155, in kmeans_plusplus
    vtot = (Xg**2).sum(1)
  File "/home/mike/anaconda3/envs/kilosort/lib/python3.9/site-packages/torch/_tensor.py", line 40, in wrapped
    return f(*args, **kwargs)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 19.06 GiB. GPU 0 has a total capacity of 47.51 GiB of which 14.57 GiB is free. Process 32564 has 12.66 GiB memory in use. Including non-PyTorch memory, this process has 20.26 GiB memory in use. Of the allocated memory 19.40 GiB is allocated by PyTorch, and 319.15 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
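[Editor's note: the error message itself suggests one mitigation for allocator fragmentation, the expandable-segments mode of PyTorch's CUDA caching allocator. A minimal sketch of setting it from Python; the variable name comes straight from the error text above, and it must be set before torch makes its first CUDA allocation:]

```python
import os

# Must be set before torch initializes the CUDA caching allocator,
# i.e. before the first CUDA tensor is created in this process.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Only import torch (and run kilosort) after the variable is in place.
```

Equivalently, it can be exported in the shell before launching Python. This only reduces fragmentation; it does not shrink the ~19 GiB allocation itself.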
jacobpennington commented 6 months ago

Hmm okay, thanks. As for parameter settings: you could try reducing the batch size to 40000 (2 s of data), though I doubt that will be enough to handle the memory issue here. min_template_size should be set back to the default value of 10; some past suggestions about increasing it were workarounds that shouldn't be necessary with the new changes. You should also start with Th_single_ch at the default value, since the thresholds don't have quite the same scale as in past versions of Kilosort. As a starting point, dmin and dminx should be set to the median vertical and horizontal distances, respectively, between contacts; it sounds like that would be 60 in your case.

Another parameter you could try adjusting is max_channel_distance, to exclude additional template positions (in place of using those larger values for dmin and dminx). For example, here is a zoomed-in screenshot of where universal templates are placed using dmin and dminx = 120, as you have now:

[screenshot: universal template positions with dmin = dminx = 120]

It might be hard to see, but on every other row the channels don't have any templates directly on them (only offset by 30 um). If you change dmin and dminx to 60, it instead looks like this:

[screenshot: template positions with dmin = dminx = 60]

which places more templates. But if you then set max_channel_distance to 10, those are reduced to just the templates directly on the channels:

[screenshot: template positions with max_channel_distance = 10]

which might help with the memory issue.
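[Editor's note: putting these suggestions together, a sketch of the revised settings dict, keeping the values from the user's snippet where the advice leaves them unchanged; max_channel_distance = 10 is the optional memory-saving variant, not a requirement:]

```python
settings = {'n_chan_bin': 512,
            'fs': 20000,
            'batch_size': 40000,          # 2 s of data, to ease GPU memory pressure
            'Th_universal': 9,
            'Th_learned': 8,
            'Th_single_ch': 6,            # back to the default; KS4 thresholds scale differently
            'dmin': 60,                   # median vertical contact spacing for this array
            'dminx': 60,                  # median horizontal contact spacing
            'min_template_size': 10,      # back to the default
            'max_channel_distance': 10,   # optional: keep only templates directly on channels
            'x_centers': 10}              # ~200 um horizontal template groups
```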

jacobpennington commented 6 months ago

Also, would you be able to paste in the rest of the output you got from sorting, before the error happened? I'd like to see how many spikes and clusters were found, for example, to check whether something else looks off; allocating that much memory during that step is surprising. I would also recommend trying to sort without drift correction by setting nblocks = 0. Drift correction can introduce artifacts, and the kind of drift we're detecting shouldn't happen for an array like that anyway.

mikemanookin commented 6 months ago

Thank you, Jacob! This is really helpful. I will change the settings as you suggested and re-run the algorithm. Regarding the errors: unfortunately, I had to reboot my server, but I will be sure to pass those on to you if I run into the issue again. I really appreciate all of your help!!

mikemanookin commented 6 months ago

Hi Jacob. The updated parameters that you sent seem to be running well. I'm not running out of GPU memory and I'm getting a lot of good clusters from the sorting. Thank you so much for your help!

One quick follow-up question: in the GUI, I can see gaps in the detected spikes in the areas between the electrodes. If I wanted to try to get rid of those gaps, would I increase 'dmin' and 'dminx'? I've attached a picture of the output. Thanks!

IMG_1852

jacobpennington commented 6 months ago

That's great! Thanks for letting us know.

As for the spike gaps, I don't think changing dmin or dminx would affect that. If you set max_channel_distance = 10 as I suggested above, try sorting with max_channel_distance = None instead (as long as it doesn't cause you to run out of memory), so that spike detection templates are still placed in between contacts. So try changing that first, if applicable.

Otherwise, it's possible the spikes in those gaps are far enough from the contacts that the amplitude is too low for the detection thresholds, so maybe try decreasing Th_universal and Th_learned by 1 or 2 each? @marius10p might have more insight on how to figure out what the issue is for that.
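[Editor's note: the two tweaks above amount to something like the following sketch; the starting values are taken from the working settings earlier in the thread, and the lowered thresholds are illustrative, not prescriptive:]

```python
# Starting from the settings that sorted successfully:
settings = {'Th_universal': 9, 'Th_learned': 8, 'max_channel_distance': 10}

# 1) Place templates between contacts again (memory permitting).
settings['max_channel_distance'] = None

# 2) Lower the detection thresholds by 1 each (or 2 if gaps persist).
settings['Th_universal'] -= 1
settings['Th_learned'] -= 1
```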

mikemanookin commented 6 months ago

@jacobpennington Thank you for the tips. I tried both techniques: setting max_channel_distance = None and dropping the thresholds. Neither approach seemed to make a difference (see image), and the processing time for each run was over 42 hours. I'm wondering if the code that estimates the x/y location of the spikes for our arrays is just biased toward the electrodes. Screenshot from 2024-04-25 06-01-33

jacobpennington commented 6 months ago

Hmm okay. I'm not sure then, any ideas @marius10p?

jacobpennington commented 6 months ago

Sorry for the delay. I looked at this again with Marius, and we determined that the gaps between contacts are not surprising given the spacing. I.e. any spikes originating in those gaps are far enough from the nearest contact that they're not likely to be detected. I'm going to close this since the original issue was addressed, but if you run into more problems please let us know!

mikemanookin commented 6 months ago

Thank you!