SpikeInterface / spikeinterface

A Python-based module for creating flexible and robust spike sorting pipelines.
https://spikeinterface.readthedocs.io
MIT License

How to use GPU #3386

Closed bxy666666 closed 2 months ago

bxy666666 commented 2 months ago

I am using Docker. How do I use the GPU from Docker?

import os
os.environ['HDF5_PLUGIN_PATH'] = r'D:\sjwlab\bxy'  # raw string so backslashes are not treated as escapes
from pprint import pprint
import time
import spikeinterface
import spikeinterface as si  # import core only
import spikeinterface.extractors as se
import spikeinterface.preprocessing as spre
import spikeinterface.sorters as ss
import spikeinterface.postprocessing as spost
import spikeinterface.qualitymetrics as sqm
import spikeinterface.comparison as sc
import spikeinterface.core as score
import spikeinterface.exporters as sexp
import spikeinterface.curation as scur
import spikeinterface.widgets as sw
import spikeinterface.full as si
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
from spikeinterface.widgets.isi_distribution import ISIDistributionWidget
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.core.sortinganalyzer import create_sorting_analyzer
from spikeinterface.exporters import export_to_phy
import os
import shutil
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from spikeinterface import (extractors as se, preprocessing as spre, sorters as ss,
                            postprocessing as spost, qualitymetrics as sqm, core as score,
                            exporters as sexp, full as si)
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.core.sortinganalyzer import create_sorting_analyzer
from spikeinterface.exporters import export_to_phy

global_job_kwargs = dict(n_jobs=32, chunk_duration="1s")
si.set_global_job_kwargs(**global_job_kwargs)

def process_data(data_location, data_name="data.raw.h5"):

    base_folder_name = os.path.basename(os.path.normpath(data_location))

    waveforms_folder_name = "waveforms_folder"
    phy_folder_name = "phy_folder"
    kilosort2_output_folder_name = "kilosort2_output"
    npy_waveform_folder = os.path.join(data_location, 'waveform_npy')
    os.makedirs(npy_waveform_folder, exist_ok=True)

    waveforms_folder_path = os.path.join(data_location, waveforms_folder_name)
    phy_folder_path = os.path.join(data_location, phy_folder_name)
    kilosort2_output_folder_path = os.path.join(data_location, kilosort2_output_folder_name)

    recording = se.read_maxwell(os.path.join(data_location, data_name))
    print("Recordings")
    print(recording)

    recording_f = si.bandpass_filter(recording, freq_min=300, freq_max=6000)
    recording_cmr = si.common_reference(recording_f, reference='global', operator='median')

    # Save the preprocessed recording
    recording_preprocessed = recording_cmr.save(format='binary')

    # Run the sorter (note: the variable is named KS2 but sorter_name is "kilosort4")
    sorting_KS2 = ss.run_sorter(
        sorter_name="kilosort4",
        recording=recording_preprocessed,
        extra_requirements=["numpy==1.26"],
        docker_image=True,
        verbose=True,
        output_folder=kilosort2_output_folder_path
    )

    # Extract waveforms
    we_KS2 = si.extract_waveforms(
        recording=recording_preprocessed,
        sorting=sorting_KS2,
        folder=waveforms_folder_path,
        overwrite=None
    )

    # Save each unit's waveforms to a .npy file
    for unit_id in we_KS2.unit_ids:
        # Get the waveform data for the current unit
        waveforms = we_KS2.get_waveforms(unit_id)

        # Build the file name, e.g. "unit_1_waveforms.npy"
        npy_filename = f"unit_{unit_id}_waveforms.npy"
        npy_filepath = os.path.join(npy_waveform_folder, npy_filename)

        # Save the waveforms as a .npy file
        np.save(npy_filepath, waveforms)
        print(f"Waveforms for unit {unit_id} saved to {npy_filepath}")

    # Compute the various postprocessing extensions
    amplitudes = spost.compute_spike_amplitudes(we_KS2)
    unit_locations = spost.compute_unit_locations(we_KS2)
    spike_locations = spost.compute_spike_locations(we_KS2)
    correlograms, bins = spost.compute_correlograms(we_KS2)
    similarity = spost.compute_template_similarity(we_KS2)
    ISI = spost.compute_isi_histograms(we_KS2, window_ms=100.0, bin_ms=2.0, method="auto")
    metric = spost.compute_template_metrics(we_KS2, include_multi_channel_metrics=True)

    # Compute quality metrics
    qm_params = sqm.get_default_qm_params()
    qm_params["presence_ratio"]["bin_duration_s"] = 1
    qm_params["amplitude_cutoff"]["num_histogram_bins"] = 5
    qm_params["drift"]["interval_s"] = 2
    qm_params["drift"]["min_spikes_per_interval"] = 2
    qm = sqm.compute_quality_metrics(we_KS2, qm_params=qm_params)

    # Get the spike train of every unit and save them
    spike_trains = {}
    for unit_id in sorting_KS2.unit_ids:
        spike_train = sorting_KS2.get_unit_spike_train(unit_id, start_frame=None, end_frame=None)
        spike_trains[unit_id] = spike_train

    # Save to a NumPy file
    np.save(os.path.join(data_location, 'aligned_spike_trains.npy'), spike_trains)

    # Load the data back
    loaded_spike_trains = np.load(os.path.join(data_location, 'aligned_spike_trains.npy'), allow_pickle=True).item()

    # Save to a CSV file
    data = []
    for unit_id, spike_train in spike_trains.items():
        for spike in spike_train:
            data.append([unit_id, spike])

    df = pd.DataFrame(data, columns=['unit_id', 'spike_time'])
    df.to_csv(os.path.join(data_location, 'aligned_spike_trains.csv'), index=False)

    # Create the sorting analyzer
    sorting_analyzer = create_sorting_analyzer(sorting=sorting_KS2, recording=recording_preprocessed)

    # Compute the required extensions
    sorting_analyzer.compute(['random_spikes', 'waveforms', 'templates', 'noise_levels'])
    _ = sorting_analyzer.compute('correlograms')
    _ = sorting_analyzer.compute('spike_amplitudes')
    _ = sorting_analyzer.compute('principal_components', n_components=5, mode="by_channel_local")

    # Export to the phy format
    export_to_phy(sorting_analyzer=sorting_analyzer, output_folder=phy_folder_path)

    # Plot and save images
    suffix = base_folder_name  # 'suffix' was undefined in the original script; use the data folder name
    sorting_analyzer.compute('unit_locations')
    # Note: with backend="sortingview" the summary is rendered as a URL, not a matplotlib figure
    w2 = sw.plot_sorting_summary(sorting_analyzer, display=False, curation=True, backend="sortingview")
    plt.savefig(os.path.join(data_location, f'sorting_summary_{suffix}.png'))
    # Inspect all attributes and methods of the SortingSummaryWidget object
    print(dir(w2))

    # Assume handle_display_and_url is the function that generates and handles the URL
    # Generate the URL (here assuming handle_display_and_url is implemented correctly)
    url = w2.url   # make sure the URL can be read from w2
    # Save the URL to a text file
    url_file_path = os.path.join(data_location, f'plot_url_{suffix}.txt')
    with open(url_file_path, 'w') as f:
        f.write(f"URL: {url}\n")

    print(f"URL saved to {url_file_path}")
    print(f"Processing complete for {suffix} in:", data_location)
    # Compute the keep mask for automatic curation
    keep_mask = (qm['isi_violations_ratio'] < 0.01) & \
                (qm['firing_rate'] >= 0.05) & \
                (qm['snr'] >= 5)
    sorting_curated_auto = sorting_KS2.select_units(sorting_KS2.unit_ids[keep_mask])

    # Delete the cache folder
    cache_folder_path = 'C:\\Users\\baoxueying\\AppData\\Local\\Temp\\spikeinterface_cache'
    delete_all_cache_files(cache_folder_path)

def delete_all_cache_files(base_cache_folder):
    if os.path.exists(base_cache_folder):
        try:
            shutil.rmtree(base_cache_folder)
            print(f"Deleted cache folder and all of its contents: {base_cache_folder}")
        except Exception as e:
            print(f"Failed to delete cache folder: {e}")
    else:
        print(f"Cache folder does not exist: {base_cache_folder}")

if __name__ == '__main__':
    # Process multiple directories
    data_locations = [
    ]

    for data_location in data_locations:
        process_data(data_location)
zm711 commented 2 months ago

@bxy666666,

your code mixes a lot of terms and re-imports the same things. You run KS4 but call the variable KS2. You also do some WaveformExtractor steps and some SortingAnalyzer steps, so it is a bit hard to follow the code you're running. My first piece of advice would be to consider a little cleanup. For example you could just do:

import spikeinterface as si # import core only
import spikeinterface.extractors as se
import spikeinterface.preprocessing as spre
import spikeinterface.sorters as ss
import spikeinterface.postprocessing as spost
import spikeinterface.qualitymetrics as sqm
import spikeinterface.comparison as sc
import spikeinterface.exporters as sexp
import spikeinterface.curation as scur
import spikeinterface.widgets as sw

and greatly clean up your spikeinterface imports. You import the same thing many times. Next I would recommend factoring out the WaveformExtractor portions. If you are using the sorting_analyzer, it is better to stick with it: it has some great improvements under the hood, so you'll have an easier time moving forward. See the sketch below.
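For instance, the extract_waveforms / spost / sqm calls in your script could collapse into analyzer-only calls along these lines (a minimal sketch, assuming a SpikeInterface version recent enough to have create_sorting_analyzer and the quality_metrics extension; variable names are taken from your script):

import spikeinterface as si  # core only

# Build one analyzer from the sorting and the preprocessed recording
sorting_analyzer = si.create_sorting_analyzer(sorting=sorting_KS2, recording=recording_preprocessed)

# Compute the extensions that extract_waveforms/spost gave you before
sorting_analyzer.compute(["random_spikes", "waveforms", "templates", "noise_levels"])
sorting_analyzer.compute(["spike_amplitudes", "correlograms", "unit_locations"])

# Quality metrics come from the same analyzer, as a pandas DataFrame
qm = sorting_analyzer.compute("quality_metrics").get_data()

Everything then lives in one object, so the later export_to_phy and plotting steps can reuse the same sorting_analyzer instead of a separate WaveformExtractor.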

Finally, how are you assessing GPU usage? What OS are you using? What CPU and GPU? KS4 will fall back to the CPU if it cannot access a GPU, so we need to know a bit more about your setup before we can advise you further.

JoeZiminski commented 2 months ago

Also, you can try the functions here to check if torch can see your GPU. If so, I believe KS4 will be using the GPU for sure as it uses torch under the hood.
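Something like this (a sketch using the standard torch API) will tell you whether torch can see a CUDA device:

import torch

print(torch.cuda.is_available())          # True if torch can reach a CUDA GPU
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))  # name of the first visible GPU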

Note that in your script all variable names refer to KS2, but KS4 is being run. KS2 itself requires an NVIDIA GPU and will error out if one is not available, so if KS2 runs at all you know you are using the GPU.

bxy666666 commented 2 months ago

OK, thanks!