
An open toolbox of several machine learning approaches for sharp-wave ripple detection

rippl-AI

rippl-AI is an open toolbox of Artificial Intelligence (AI) resources for the detection of hippocampal neurophysiological signals, in particular sharp-wave ripples (SWR). This toolbox offers multiple successful plug-and-play machine learning (ML) models from 5 different architectures (1D-CNN, 2D-CNN, LSTM, SVM and XGBoost) that are ready to use to detect SWRs in hippocampal recordings. Moreover, an additional package allows easy re-training, so that models can be updated to better detect particular features of your own recordings. More details in Navas-Olive, Rubio, et al. Commun Biol 7, 211 (2024)!

Description

Sharp-wave ripples

Sharp-wave ripples (SWRs) are transient fast oscillatory events (100-250 Hz) of around 50 ms that appear in the hippocampus and have been associated with memory consolidation. During SWRs, the sequential firing of neuronal ensembles is replayed, reactivating memory traces of previously encoded experiences. SWR-related interventions can influence hippocampal-dependent cognitive function, making their detection crucial for understanding the underlying mechanisms. However, existing SWR identification tools mostly rely on spectral methods, which remain suboptimal.

Because of the micro-circuit properties of the hippocampus, CA1 SWRs share a common profile, consisting of a ripple in the stratum pyramidale (SP) and a sharp-wave deflection in the stratum radiatum that reflects the large excitatory input coming from CA3. Yet, SWRs can differ substantially depending on the underlying reactivated circuit. This continuous recording shows this variability:

Example of several SWRs

Artificial intelligence architectures

In this project, we take advantage of supervised machine learning to train different AI architectures so that they learn, without bias, to identify signature SWR features in raw Local Field Potential (LFP) recordings. These are the explored architectures:

Convolutional Neural Networks (CNNs)

Convolutional Neural Networks

Support Vector Machine (SVM)

Support Vector Machine

Long-Short Term Memory Recurrent Neural Networks (LSTM)

Long-Short Term Memory Recurrent Neural Networks

Extreme-Gradient Boosting (XGBoost)

Extreme-Gradient Boosting

The toolbox

This toolbox contains three main blocks: detection, re-training and exploration. These three packages can be used jointly or separately. Below we describe the purpose and usage of each.

Detection

In previous work (Navas-Olive, Amaducci et al., 2022), we demonstrated that using feature-based algorithms to detect electrophysiological events such as SWRs has several advantages:

In this toolbox, we widen the machine learning spectrum by offering multiple plug-and-play models from very different AI architectures: 1D-CNN, 2D-CNN, LSTM, SVM and XGBoost. We performed an exhaustive parametric search to find different architecture solutions (i.e. models) that achieve:

This repository contains the best five models from each of these five architectures. These models have already been trained on mice data, and can be found in the optimized_models/ folder.

The rippl_AI python module contains all the necessary functions to easily use any model to detect SWRs. Additionally, we provide auxiliary functions in the aux_fcn module, which contains useful code to process LFP and evaluate detection performance.

Moreover, several usage examples of all functions can be found in the examples_detection.ipynb python notebook.

rippl_AI.predict()

The python function predict(LFP, sf, arch='CNN1D', model_number=1, channels=np.arange(8), d_sf=1250) of the rippl_AI module computes the SWR probability for a given LFP.
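For orientation, a minimal call might look like the sketch below. The file name, sampling rate and channel indices are placeholders for your own recording, and the exact return value(s) are documented in examples_detection.ipynb; here the output is assumed to be the per-sample SWR probability.

```python
import numpy as np
import rippl_AI

# Hypothetical raw recording: (n_samples x n_channels) array sampled at 30 kHz
LFP = np.load('my_recording.npy')   # placeholder for your own data
sf = 30000

# SWR probability for every (downsampled) LFP sample, using the best 1D-CNN
# and 8 channels spanning the pyramidal layer (channel indices are illustrative)
SWR_prob = rippl_AI.predict(LFP, sf, arch='CNN1D', model_number=1,
                            channels=np.arange(8), d_sf=1250)
```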

In the figure below, you can see an example of a high-density LFP recording (top) with manually labeled data (gray). The objective of these models is to generate an output signal that most closely matches the manually labeled signal. The output of the uploaded optimized models can be seen at the bottom, where outputs range from 0 (low probability of SWR) to 1 (high probability of SWR) for each LFP sample.

Detection method

The rippl_AI.predict() input and output variables are:

rippl_AI.get_intervals()

The python function get_intervals(SWR_prob, LFP_norm=None, sf=1250, win_size=100, threshold=None, file_path=None, merge_win=0) of the rippl_AI module takes the output of rippl_AI.predict() (i.e. the SWR probability) and identifies SWR beginnings and ends by establishing a threshold. In the figure below, you can see how the threshold can decisively determine which events are detected. For example, lowering the threshold to 0.5 would have resulted in XGBoost correctly detecting the first SWR, and in the 1D-CNN detecting the sharp-wave that has no ripple.

Detection method
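Continuing the predict() sketch above, detection intervals could then be obtained along these lines; the threshold value is an arbitrary example, not a recommended setting.

```python
# Turn the per-sample SWR probability into event start/end intervals by
# thresholding (threshold=0.7 is only an example; see the figure above for
# how this choice affects which events are detected)
SWR_intervals = rippl_AI.get_intervals(SWR_prob, sf=1250, win_size=100,
                                       threshold=0.7, merge_win=0)
```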

aux_fcn.manual_curation()

The python function aux_fcn.manual_curation(events, data, file_path, win_size=100, gt_events=None, sf=1250) of the aux_fcn module allows manual curation of the detected events. It displays an interactive GUI to manually select or discard events.

Example of manual curation function

Use cases:

  1. If no ground truth (GT) events are provided, the detected events will be displayed, and you can select which ones to keep (highlighted in green) and which ones to discard (in red).
  2. If GT events are provided, true positive detections (TP) will be displayed in green. If for any reason you want to discard correct detections, they will be displayed in yellow.
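A rough sketch of such a call is shown below, continuing from the detection examples; the processed-LFP argument and the output file name are assumptions.

```python
import aux_fcn

# LFP_norm: the downsampled, z-scored LFP (e.g. from aux_fcn.process_LFP,
# described below); SWR_intervals comes from the get_intervals() sketch above.
# The GUI lets you keep or discard each event; the file path is a placeholder.
curated_events = aux_fcn.manual_curation(SWR_intervals, LFP_norm,
                                         'curated_events.txt',
                                         win_size=100, sf=1250)
```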

aux_fcn.plot_all_events()

The python function aux_fcn.plot_all_events(t_events, lfp, sf, win=0.1, title='', savefig='') of the aux_fcn module plots all events in a single plot. It can be used as a fast summary/check after detection and/or curation.
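For example, continuing from the curation sketch above (the window length, title and output file name are illustrative):

```python
# Quick visual summary: plot every curated event in a single figure and save it
aux_fcn.plot_all_events(curated_events, LFP_norm, sf=1250, win=0.1,
                        title='Curated SWRs', savefig='all_events.png')
```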

aux_fcn.process_LFP()

The python function process_LFP(LFP, sf, d_sf, channels) of the aux_fcn module processes the LFP before it is input to the algorithm. It downsamples the LFP to d_sf and normalizes each channel separately by z-scoring it.
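A minimal sketch, assuming a raw recording sampled at 30 kHz and 8 selected channels (both values are placeholders):

```python
import numpy as np
import aux_fcn

# Downsample the raw LFP to 1250 Hz and z-score each selected channel
# independently, producing the processed LFP used in the sketches above
LFP_norm = aux_fcn.process_LFP(LFP, 30000, d_sf=1250, channels=np.arange(8))
```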

aux_fcn.interpolate_channels()

The python function interpolate_channels(LFP, channels) of the aux_fcn module allows creating more intermediate channels using interpolation.

Because these models perform best with a richer spatial profile, all combinations of architectures and model_numbers work with 8 channels. The only exception is the 2D-CNN architecture with models {3, 4, 5}, which needs 3 channels. However, sometimes it is not possible to get that many channels in the pyramidal layer, for example when using linear probes (only 2 or 3 channels fit in the pyramidal layer) or tetrodes (there are only 4 recording channels). For these cases, we developed this interpolation function, which creates new channels between any pair of your recording channels. Using this approach, the already built algorithms can be used with an equally high performance.
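As a sketch of the idea, a tetrode recording could be expanded as follows; the file name is a placeholder and the exact semantics of the channels argument are an assumption (see the examples notebook for the real usage).

```python
import numpy as np
import aux_fcn

# Hypothetical (n_samples x 4) recording from a tetrode
LFP_tetrode = np.load('tetrode_recording.npy')   # placeholder file

# Interpolate new channels between the 4 available ones so that the
# 8-channel models can be applied to this recording
LFP_interp = aux_fcn.interpolate_channels(LFP_tetrode, channels=[0, 1, 2, 3])
```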

aux_fcn.get_performance()

The python function get_performance(predictions, true_events, threshold=0, exclude_matched_trues=False, verbose=True) of the aux_fcn module computes several performance metrics:

Therefore, this function can only be used when some ground truth (i.e. events that we consider to be the truth) is given. In order to check whether a true event has been predicted, it computes the Intersection over Union (IoU). This metric measures how much two intervals intersect relative to the size of their union. So if pred_events = [[2,3], [6,7]] and true_events = [[2,4], [8,9]], then we would expect IoU(pred_events[0], true_events[0]) > 0, while the rest will be zero.
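To make the IoU criterion concrete, here is a small illustrative helper (not part of aux_fcn) that reproduces the example above, followed by a hedged call to get_performance; passing events as lists of [start, end] pairs is an assumption.

```python
import aux_fcn

def interval_iou(a, b):
    """Intersection over Union of two [start, end] intervals (illustrative only)."""
    inter = max(0.0, min(a[1], b[1]) - max(a[0], b[0]))
    union = (a[1] - a[0]) + (b[1] - b[0]) - inter
    return inter / union if union > 0 else 0.0

pred_events = [[2, 3], [6, 7]]
true_events = [[2, 4], [8, 9]]
print(interval_iou(pred_events[0], true_events[0]))  # 0.5 -> overlapping pair
print(interval_iou(pred_events[1], true_events[1]))  # 0.0 -> no overlap

# Performance metrics of the toy predictions against the toy ground truth
metrics = aux_fcn.get_performance(pred_events, true_events,
                                  threshold=0, verbose=True)
```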

Re-training

Here, we provide a unique toolbox to easily re-train models and adapt them to new datasets. These models have been selected because their architectural parameters are best suited to looking for electrophysiological high-frequency events. So, whether you are interested in finding SWRs or other electrophysiological events, this toolbox offers you the possibility to skip the parametric search and parameter tuning just by running these scripts. The advantages of the re-training module are:

rippl_AI.retrain_model()

The python function rippl_AI.retrain_model(train_data, train_GT, test_data, test_GT, arch, parameters=None, save_path=None, d_sf=1250, merge_win=0) of the rippl_AI module re-trains the best model of a given architecture so that it re-learns the optimal features for detecting the events annotated in the new ground truth.

Usage examples can be found in the examples_retraining.ipynb python notebook.
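For orientation only, a re-training call might look like the sketch below; the file names, data formats, architecture choice and save path are all assumptions, and the exact expected inputs are documented in examples_retraining.ipynb.

```python
import numpy as np
import rippl_AI

# Placeholder data: processed LFP arrays and annotated event times for the
# training and test sessions (file names and formats are assumptions)
train_data, test_data = np.load('train_LFP.npy'), np.load('test_LFP.npy')
train_GT, test_GT = np.load('train_events.npy'), np.load('test_events.npy')

# Re-train the best model of the chosen architecture on the new ground truth
retrained = rippl_AI.retrain_model(train_data, train_GT, test_data, test_GT,
                                   arch='CNN1D', save_path='retrained_models/',
                                   d_sf=1250, merge_win=0)
```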

Exploration

Finally, to further exploit this toolbox, we also offer an exploration module in which you can create your own model. In the examples_explore folder, you can see how the different architectures can be modified through multiple parameters to create countless other models, better adjusted to the events you need to detect. For example, if you are interested in lower-frequency events, such as theta cycles, this exploratory module will be very convenient for finding an AI architecture that better adapts to the needs of your research. Here, we specify the most common parameters to explore for each architecture:

1D-CNN

2D-CNN

LSTM

SVM

XGBoost

Environment setup

  1. Install miniconda, following the tutorial: https://docs.conda.io/en/latest/miniconda.html
  2. Launch the anaconda console by typing anaconda prompt in the Windows/Linux search bar.
  3. In the anaconda prompt, create a conda environment (e.g. rippl_AI_env):
    conda create -n rippl_AI_env python=3.9.15
  4. This will create an environment in your miniconda3 environments folder, usually: C:\Users\<your_user>\miniconda3\envs
  5. Check that the environment rippl_AI_env has been created by typing:
    conda env list
  6. Activate the environment with conda activate rippl_AI_env if you want to launch the scripts from the command prompt. If you are using Visual Studio Code, you need to select the rippl_AI_env python interpreter instead.
  7. After activating the environment, install every necessary python package:
    conda install pip
    pip install tensorflow==2.11 keras==2.11 xgboost==1.6.1 imblearn numpy matplotlib pandas scipy
    pip install -U scikit-learn==1.1.2

    To download the lab data from figshare (not normalized, sampled at the original frequency of 30 000 Hz):

    git clone https://github.com/cognoma/figshare.git
    cd figshare
    python setup.py install

    The package versions compatible with the toolbox are: