MouseLand / Kilosort

Fast spike sorting with drift correction
https://kilosort.readthedocs.io/en/latest/
GNU General Public License v3.0
484 stars 248 forks source link

about Kilosort4 #778

Closed bxy666666 closed 1 month ago

bxy666666 commented 2 months ago

Hello, I have some new questions about the Kilosort4. Questions below are all based on Kilosort4.

The number of neurons detected across different organoids with default parameters is consistent across chips. We suspect this might be due to overly high sensitivity. Will discontinuous data have a significant impact on neuron classification?

jacobpennington commented 2 months ago

@bxy666666 Can you please attach some screenshots of Phy results or log statements to show examples of the number of neurons detected (total and good units)? Are you saying it's the exact same number every time?

bxy666666 commented 2 months ago

您能否附上一些 Phy 结果或 log 语句的屏幕截图,以展示检测到的神经元数量(总单位和良好单位)的示例?你是说每次都是完全相同的数字吗?

This is my code. Imet a serious problem. I hope your help. When dealing different files, the number of units are the same. That's not right.

import os from pprint import pprint import time import spikeinterface import spikeinterface as si # import core only import spikeinterface.extractors as se import spikeinterface.preprocessing as spre import spikeinterface.sorters as ss import spikeinterface.postprocessing as spost import spikeinterface.qualitymetrics as sqm import spikeinterface.comparison as sc import spikeinterface.core as score import spikeinterface.exporters as sexp import spikeinterface.curation as scur import spikeinterface.widgets as sw import spikeinterface.full as si import matplotlib.pyplot as plt import numpy as np import os import pandas as pd from spikeinterface.widgets.isi_distribution import ISIDistributionWidget from spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.core.sortinganalyzer import create_sorting_analyzer from spikeinterface.exporters import export_to_phy from datetime import datetime, timedelta global_job_kwargs = dict(n_jobs=1, chunk_duration="1s") si.set_global_job_kwargs(**global_job_kwargs)

def process_recording(recording, data_location, suffix):

Preprocessing

recording_f = si.bandpass_filter(recording, freq_min=300, freq_max=6000) recording_cmr = si.common_reference(recording_f, reference='global', operator='median') recording_preprocessed = recording_cmr.save(format='binary') kilosort2_params = ss.get_default_sorter_params('kilosort4') print("Updated Kilosort4 params:", kilosort2_params)

Run Kilosort4

sorting_KS2 = ss.run_sorter(sorter_name="kilosort2", recording=recording_preprocessed, extra_requirements=["numpy==1.26"], docker_image=True, verbose=True)

Extract waveforms

we_KS2 = si.extract_waveforms(recording_preprocessed, sorting_KS2, os.path.join(data_location, f'waveformsfolder{suffix}'), overwrite=None)

Save waveforms as .npy files

npy_waveform_folder = os.path.join(data_location, f'waveformnpy{suffix}') os.makedirs(npy_waveform_folder, exist_ok=True) for unit_id in we_KS2.unit_ids: waveforms = we_KS2.get_waveforms(unit_id) npyfilename = f"unit{unit_id}_waveforms.npy" npy_filepath = os.path.join(npy_waveform_folder, npy_filename) np.save(npy_filepath, waveforms) print(f"Unit {unit_id} waveforms saved to {npy_filepath}")

Compute metrics

amplitudes = spost.compute_spike_amplitudes(we_KS2) unit_locations = spost.compute_unit_locations(we_KS2) spike_locations = spost.compute_spike_locations(we_KS2) correlograms, bins = spost.compute_correlograms(we_KS2) similarity = spost.compute_template_similarity(we_KS2) ISI = spost.compute_isi_histograms(we_KS2, window_ms=100.0, bin_ms=2.0, method="auto") metric = spost.compute_template_metrics(we_KS2, include_multi_channel_metrics=True) metric_names = spost.get_template_metric_names()

print(we_KS2.get_available_extension_names()) waveform_folder = os.path.join(data_location, f'waveformsfolder{suffix}') if not os.path.isdir(waveform_folder): print(f"Waveform folder does not exist: {waveform_folder}") return

we_loaded = si.load_waveforms(waveform_folder) print(we_loaded.get_available_extension_names()) print(we_loaded.get_available_extension_names())

Compute quality metrics

qm_params = sqm.get_default_qm_params() qm_params["presence_ratio"]["bin_duration_s"] = 1 qm_params["amplitude_cutoff"]["num_histogram_bins"] = 5 qm_params["drift"]["interval_s"] = 2 qm_params["drift"]["min_spikes_per_interval"] = 2 qm = sqm.compute_quality_metrics(we_KS2, qm_params=qm_params) print(f"Quality metrics for {suffix}:", qm)

Save spike trains

spike_trains = {} for unit_id in sorting_KS2.unit_ids: spike_train = sorting_KS2.get_unit_spike_train(unit_id, start_frame=None, end_frame=None) spike_trains[unit_id] = spike_train

np.save(os.path.join(data_location, f'aligned_spiketrains{suffix}.npy'), spike_trains)

Load and check spike trains

loaded_spike_trains = np.load(os.path.join(data_location, f'aligned_spiketrains{suffix}.npy'), allow_pickle=True).item() print(f"Loaded spike train data type for {suffix}:", type(loaded_spike_trains)) print(f"Loaded spike train dimensions for {suffix}:", {k: np.shape(v) for k, v in loaded_spike_trains.items()})

Save to CSV

data = [] for unit_id, spike_train in spike_trains.items(): for spike in spike_train: data.append([unit_id, spike]) df = pd.DataFrame(data, columns=['unit_id', 'spike_time']) df.to_csv(os.path.join(data_location, f'aligned_spiketrains{suffix}.csv'), index=False)

Export to phy

sorting_analyzer = si.create_sorting_analyzer(sorting=sorting_KS2, recording=recording_preprocessed, format="memory") sorting_analyzer.compute(['random_spikes', 'waveforms', 'templates', 'noiselevels']) = sortinganalyzer.compute('correlograms') = sorting_analyzer.compute('spikeamplitudes') = sorting_analyzer.compute('principal_components', n_components=5, mode="by_channel_local")

phy_folder = os.path.join(data_location, f'phyfolder{suffix}')

os.makedirs(phy_folder, exist_ok=True)

si.export_to_phy(sorting_analyzer=sorting_analyzer, output_folder=os.path.join(data_location, f'phyfolder{suffix}'))

Plot and save images

sorting_analyzer.compute('unit_locations') w2 = sw.plot_sorting_summary(sorting_analyzer, display=False, curation=True, backend="sortingview") plt.savefig(os.path.join(data_location, f'sortingsummary{suffix}.png'))

查看 SortingSummaryWidget 对象的所有属性和方法

print(dir(w2))

url = w2.url # 确保你能从 w2 中获取 URL

url_file_path = os.path.join(data_location, f'ploturl{suffix}.txt') with open(url_file_path, 'w') as f: f.write(f"URL: {url}\n")

print(f"URL saved to {url_file_path}") print(f"Processing complete for {suffix} in:", data_location) w_rs = sw.plot_rasters(sorting_KS2) plt.savefig(os.path.join(datalocation, f'rasters{suffix}.png'))

w_pr = sw.plot_unit_presence(sorting_KS2) plt.savefig(os.path.join(data_location, f'unitpresence{suffix}.png')) def process_with_artifact_removal(recording, stim_times, recording_start_time, ms_after=1000):

list_of_artifacts = []

for stim_time in stim_times: stim_start_time = stim_time stim_end_time = stim_time + timedelta(milliseconds=ms_after) start_time_sec = (stim_start_time - recording_start_time).total_seconds()

list_of_artifacts.append(start_time_sec)

rec_segment = recording.time_slice(start_time=start_time_sec, end_time=start_time_sec + ms_after / 1000.0)
cleaned_segment = spre.remove_artifacts(rec_segment,
                                        list_triggers=[start_time_sec],
                                        ms_before=0, ms_after=ms_after)

cleaned_recording = spre.remove_artifacts(recording, list_triggers=list_of_artifacts, ms_before=0, ms_after=ms_after)

return cleaned_recording def split_recording_at_midpoint(recording, recording_start_time): sampling_rate = recording.get_sampling_frequency() total_duration_sec = recording.get_duration()

midpoint_sec = total_duration_sec / 2

midpoint_frame = int(midpoint_sec * sampling_rate)

total_frames = int(total_duration_sec * sampling_rate)

if not (0 <= midpoint_frame < total_frames): raise ValueError("Midpoint frame is out of bounds.")

rec_first_half = recording.frame_slice(start_frame=0, end_frame=midpoint_frame) rec_second_half = recording.frame_slice(start_frame=midpoint_frame, end_frame=total_frames)

return rec_first_half, rec_second_half def process_files(data_location, csv_file, recording_start_time):

Load stim times

stim_times_df = pd.read_csv(csv_file) stim_times_df['End Time'] = pd.to_datetime(stim_times_df['End Time']) stim_times = stim_times_df['End Time'].tolist() print("Stim Times:", stim_times)

data_name = "data.raw.h5" recording = si.read_maxwell(os.path.join(data_location, data_name))

cleaned_recording = process_with_artifact_removal(recording, stim_times, recording_start_time)

cleaned_recording_first_half, cleaned_recording_second_half = split_recording_at_midpoint(cleaned_recording, recording_start_time) process_recording(cleaned_recording_first_half, data_location, 'KS2-first_half') process_recording(cleaned_recording_second_half, data_location, 'KS2-second_half') List of file sets with paths and start times file_sets = [ { "data_location": , "csv_file": , "recording_start_time": },

]

Process each file set for file_set in file_sets: process_files( data_location=file_set["data_location"], csv_file=file_set["csv_file"], recording_start_time=file_set["recording_start_time"] ) cleaned_recording_first_half and cleaned_recording_second_half are the different part of my files, so it is different recordings right? I also have file_sets = [

{ "data_location": "D:\sjwlab\bxy\0828shang\24395\1card-3interval-near\game\", "csv_file": r"D:\sjwlab\bxy\0828shang\24395\1card-3interval-near\game\stim_region_times_20240828_105130.csv", "recording_start_time": datetime.strptime("2024/8/28 10:51:32", "%Y/%m/%d %H:%M:%S") }, { "data_location": "D:\sjwlab\bxy\game\24395\01\01-1\", "csv_file": r"D:\sjwlab\bxy\game\24395\01\01-1\stim_region_times_20240830_112443.csv", "recording_start_time": datetime.strptime("2024/8/30 11:24:46", "%Y/%m/%d %H:%M:%S") },

{ "data_location": "D:\sjwlab\bxy\game\24395\01\4.4-1\", "csv_file": r"D:\sjwlab\bxy\game\24395\01\01-2\stim_region_times_20240830_113653.csv", "recording_start_time": datetime.strptime("2024/8/30 11:36:56", "%Y/%m/%d %H:%M:%S") }, ]

jacobpennington commented 2 months ago

That code looks like it's using Kilosort2, not Kilosort4. You're also using SpikeInterface, which I can't really help with. Please try sorting the data with Kilosort4, without using SpikeInterface, then let us know if you still encounter problems.