MouseLand / Kilosort

Fast spike sorting with drift correction
https://kilosort.readthedocs.io/en/latest/
GNU General Public License v3.0
478 stars 248 forks source link

BUG: error when using a custom probe dictionary - cannot cast float64 to int32 when subtracting shifts from yc during template extraction #681

Closed DanEgert closed 6 months ago

DanEgert commented 6 months ago

Describe the issue:

My sort aborted during template extraction. I'm using a custom probe dictionary, and there might be an issue with the data type I used to define the xc and yc coordinates.

---> 59 yp[:,1] -= shifts
     61 xp = torch.from_numpy(xp).to(device)
     62 yp = torch.from_numpy(yp).to(device)

UFuncTypeError: Cannot cast ufunc 'subtract' output from dtype('float64') to dtype('int32') with casting rule 'same_kind'

Reproduce the bug:

# defining my channel map coordinates:
# y coords
# Create a pattern
# Define the first sequence starting from 240 and decreasing by 30
array1 = np.arange(240, 29, -30)

# Define the second sequence starting from 240-15 and decreasing by 30
array2 = np.arange(240 - 15, 29 - 15, -30)

#repeat
repeated_array1 = np.tile(array1, 16)
repeated_array2 = np.tile(array2, 16)

# Reshape arrays into shape (16, 8)
reshaped_array1 = repeated_array1.reshape(16, 8)
reshaped_array2 = repeated_array2.reshape(16, 8)

# Interdigitate both sequences
yc = np.empty((16*2, 8), dtype=int)
yc[::2] = reshaped_array1
yc[1::2] = reshaped_array2

yc = yc.flatten()

# x coords
numbers=[]
# Iterate over the range from 100 to 3200 with a step of 100
for i in range(100, 3201, 100):
    # Repeat each value 8 times before increasing
    numbers.extend([i] * 8)

# Convert the list to a NumPy array
xc = np.array(numbers)

...

#launching kilosort
ops, st, clu, tF, Wall, similar_templates, is_ref, est_contam_rate = \
    run_kilosort(settings=settings, probe=probe, filename=recording_file)

Error message:

UFuncTypeError                            Traceback (most recent call last)
Cell In[9], line 4
      1 from kilosort import run_kilosort
      3 ops, st, clu, tF, Wall, similar_templates, is_ref, est_contam_rate = \
----> 4     run_kilosort(settings=settings, probe=probe, filename=recording_file)

File H:\Anaconda\envs\kilosort\lib\site-packages\kilosort\run_kilosort.py:147, in run_kilosort(settings, probe, probe_name, filename, data_dir, file_object, results_dir, data_dtype, do_CAR, invert_sign, device, progress_bar, save_extra_vars)
    144 io.save_ops(ops, results_dir)
    146 # Sort spikes and save results
--> 147 st, tF, _, _ = detect_spikes(ops, device, bfile, tic0=tic0,
    148                              progress_bar=progress_bar)
    149 clu, Wall = cluster_spikes(st, tF, ops, device, bfile, tic0=tic0,
    150                            progress_bar=progress_bar)
    151 ops, similar_templates, is_ref, est_contam_rate, kept_spikes = \
    152     save_sorting(ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0,
    153                  save_extra_vars=save_extra_vars)

File H:\Anaconda\envs\kilosort\lib\site-packages\kilosort\run_kilosort.py:398, in detect_spikes(ops, device, bfile, tic0, progress_bar)
    396 tic = time.time()
    397 print(f'\nExtracting spikes using templates')
--> 398 st0, tF, ops = spikedetect.run(ops, bfile, device=device, progress_bar=progress_bar)
    399 tF = torch.from_numpy(tF)
    400 print(f'{len(st0)} spikes extracted in {time.time()-tic : .2f}s; ' + 
    401         f'total {time.time()-tic0 : .2f}s')

File H:\Anaconda\envs\kilosort\lib\site-packages\kilosort\spikedetect.py:188, in run(ops, bfile, device, progress_bar)
    186     print('Re-computing universal templates from data.')
    187     # Determine templates and PC features from data.
--> 188     ops['wPCA'], ops['wTEMP'] = extract_wPCA_wTEMP(
    189         ops, bfile, nt=ops['nt'], twav_min=ops['nt0min'], 
    190         Th_single_ch=ops['settings']['Th_single_ch'], nskip=25,
    191         device=device
    192         )
    193 else:
    194     print('Using built-in universal templates.')

File H:\Anaconda\envs\kilosort\lib\site-packages\kilosort\spikedetect.py:55, in extract_wPCA_wTEMP(ops, bfile, nt, twav_min, Th_single_ch, nskip, device)
     53 i = 0
     54 for j in range(0, bfile.n_batches, nskip):
---> 55     X = bfile.padded_batch_to_torch(j, ops)
     57     clips_new = extract_snippets(X, nt=nt, twav_min=twav_min,
     58                                  Th_single_ch=Th_single_ch, device=device)
     60     nnew = len(clips_new)

File H:\Anaconda\envs\kilosort\lib\site-packages\kilosort\io.py:708, in BinaryFiltered.padded_batch_to_torch(self, ibatch, ops, return_inds)
    706 else:
    707     X = super().padded_batch_to_torch(ibatch)
--> 708     return self.filter(X, ops, ibatch)

File H:\Anaconda\envs\kilosort\lib\site-packages\kilosort\io.py:686, in BinaryFiltered.filter(self, X, ops, ibatch)
    684 if self.whiten_mat is not None:
    685     if self.dshift is not None and ops is not None and ibatch is not None:
--> 686         M = get_drift_matrix(ops, self.dshift[ibatch], device=self.device)
    687         #print(M.dtype, X.dtype, self.whiten_mat.dtype)
    688         X = (M @ self.whiten_mat) @ X

File H:\Anaconda\envs\kilosort\lib\site-packages\kilosort\preprocessing.py:59, in get_drift_matrix(ops, dshift, device)
     57 xp = np.vstack((ops['probe']['xc'],ops['probe']['yc'])).T
     58 yp = xp.copy()
---> 59 yp[:,1] -= shifts
     61 xp = torch.from_numpy(xp).to(device)
     62 yp = torch.from_numpy(yp).to(device)

UFuncTypeError: Cannot cast ufunc 'subtract' output from dtype('float64') to dtype('int32') with casting rule 'same_kind'

Version information:

kilosort4 v4.0.6

Context for the issue:

No response

Experiment information:

No response

DanEgert commented 6 months ago

This error disappeared when I first saved the probe dictionary to .json and loaded it, before passing it to run_kilosort.

probe = {
    'chanMap': chanMap,
    'xc': xc,
    'yc': yc,
    'kcoords': kcoords,
    'n_chan': n_chan
}
save_probe(probe, 'dan_probe_ks4.json')
load_probe = load_probe('dan_probe_ks4.json')
ops, st, clu, tF, Wall, similar_templates, is_ref, est_contam_rate = \
    run_kilosort(settings=settings, probe=load_probe, filename=recording_file)

jacobpennington commented 6 months ago

Thanks, that's helpful.

jacobpennington commented 6 months ago

This is fixed now with the latest code changes; I'll have it added to pip soon.