HayoonSong / Transfer-learning-for-EEG-MI-classification-cross-subjects


Transfer-learning-for-EEG-MI-classification-cross-subjects

This study uses cross-subject transfer learning to improve the classification accuracy of EEGNet on motor imagery (MI) data. Subject-specific data is often scarce, so transfer learning-based approaches are increasingly integrated into MI systems. They use pre-existing information from other subjects (source domain) to ease calibration for a new subject (target domain) through a set of features shared among individuals (Collazos-Huertas, 2021).

Paper

System dependencies

Project Architecture

.
├── data_generator              # dataset generator
|   └── data_preprocessing.py   # data generator for target and source data
├── model                       # TensorFlow model files
|   └── EEGNet.py               # EEGNet
├── trainer                     # TensorFlow trainer files
|   ├── Train.py                # super trainer class
|   ├── baseline_train.py       # baseline trainer class with EEGNet
|   ├── pretraining_train.py    # pre-train trainer class with source data
|   └── finetuning_train.py     # finetuning trainer class with pre-trained EEGNet
├── visualizer.py               # bar chart and confusion matrix
└── utils.py                    # a series of tools used in this repo
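The trainer package pairs one super trainer class with three subclasses. The minimal sketch below shows that base-class/subclass pattern. It is not the code in trainer/Train.py. The class and method names (Trainer, plan, training_data, initial_weights) and the checkpoint path are hypothetical. It is also an assumption that the baseline trains on target data only. In this sketch, the three stages differ only in which data they train on and where their weights start.

```python
from abc import ABC, abstractmethod


class Trainer(ABC):
    """Hypothetical stand-in for the super trainer class (not the repo's API)."""

    def __init__(self, data_dir, ckpt_dir):
        self.data_dir = data_dir
        self.ckpt_dir = ckpt_dir

    @abstractmethod
    def training_data(self):
        """Which subset this stage trains on: 'source' or 'target'."""

    def initial_weights(self):
        # Default: EEGNet starts from a random initialisation.
        return "random"

    def plan(self):
        # Shared template: every stage loads its data, initialises EEGNet, then trains.
        return (type(self).__name__, self.training_data(), self.initial_weights())


class BaselineTrainer(Trainer):
    def training_data(self):
        return "target"  # assumption: baseline sees the target subject only


class PretrainingTrainer(Trainer):
    def training_data(self):
        return "source"


class FinetuningTrainer(Trainer):
    def training_data(self):
        return "target"

    def initial_weights(self):
        # Hypothetical path to the checkpoint written by the pre-training stage.
        return f"{self.ckpt_dir}/pretrained"


for cls in (BaselineTrainer, PretrainingTrainer, FinetuningTrainer):
    print(cls("../data", "../ckpt").plan())
```

Keeping the shared loop in the base class means each stage only overrides what actually differs.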

Installation

To use this codebase, clone the GitHub repository and install the requirements:

git clone https://github.com/HayoonSong/Transfer-learning-for-EEG-MI-classification-across-subjects
cd Transfer-learning-for-EEG-MI-classification-across-subjects/src
pip install -r requirements.txt

Dataset

We evaluated our model on the BCI Competition IV-2a dataset, published in 2008.
Cross-subject transfer learning separates the total data into two subsets: the target data from the new subject and the source data from the other subjects.

To combine the training and evaluation data and separate them into target data and source data, run:

python data_generator/data_preprocessing.py --data_dir ../data/
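The separation can be pictured as a leave-one-subject-out split. The sketch below assumes that data_preprocessing.py concatenates each subject's training and evaluation sessions. It then treats one subject as the target and the other eight as the source. The function split_target_source and the dict-of-arrays layout are hypothetical, not the script's actual interface. The random arrays only mimic the shape of BCI IV-2a, which has nine subjects, 22 EEG channels and four MI classes (left hand, right hand, feet, tongue).

```python
import numpy as np


def split_target_source(trials_by_subject, labels_by_subject, target_subject):
    """Leave-one-subject-out split: one target subject, all remaining subjects as source.

    trials_by_subject maps a subject ID to an array of shape (n_trials, n_channels,
    n_samples) with the training and evaluation sessions already concatenated;
    labels_by_subject maps the same IDs to label arrays of shape (n_trials,).
    """
    if target_subject not in trials_by_subject:
        raise KeyError(f"unknown subject: {target_subject}")
    source_ids = [s for s in sorted(trials_by_subject) if s != target_subject]
    target = (trials_by_subject[target_subject], labels_by_subject[target_subject])
    source = (
        np.concatenate([trials_by_subject[s] for s in source_ids], axis=0),
        np.concatenate([labels_by_subject[s] for s in source_ids], axis=0),
    )
    return target, source


# Synthetic stand-in for BCI IV-2a: 9 subjects, 22 EEG channels, 4 classes.
# Each subject gets only 10 random trials of 250 samples to keep the demo small.
rng = np.random.default_rng(0)
subjects = [f"A{i:02d}" for i in range(1, 10)]
trials = {s: rng.standard_normal((10, 22, 250)) for s in subjects}
labels = {s: rng.integers(0, 4, size=10) for s in subjects}
(x_tgt, y_tgt), (x_src, y_src) = split_target_source(trials, labels, "A01")
print(x_tgt.shape, x_src.shape)  # (10, 22, 250) (80, 22, 250)
```

Repeating the split with each subject as the target in turn gives one cross-subject evaluation per subject.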

Model

We use EEGNet as the classification model.
The original authors have published their code at https://github.com/vlawhern/arl-eegmodels.
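For orientation, the sketch below traces activation shapes through an EEGNet-style network. It follows the layer layout of the reference implementation in arl-eegmodels: a temporal convolution, a depthwise spatial convolution across all electrodes, a separable convolution, and average pooling by 4 and then by 8. It only computes shapes, so it runs without TensorFlow. The defaults F1=8, D=2, F2=16 and kernLength=64 follow the reference implementation. The 22-channel, 1000-sample (4 s at 250 Hz), 4-class input is an assumption about how this repository windows BCI IV-2a. eegnet_shapes is a hypothetical helper.

```python
def eegnet_shapes(chans=22, samples=1000, n_classes=4, f1=8, d=2, f2=16, kern_length=64):
    """Trace (height, width, feature maps) through an EEGNet-style network.

    The 22 x 1000 input and 4 classes are assumed values for BCI IV-2a.
    """
    shapes = [("input", (chans, samples, 1))]
    # Block 1: temporal convolution with 'same' padding keeps the input size.
    shapes.append((f"conv2d {f1}x(1,{kern_length})", (chans, samples, f1)))
    # Depthwise spatial convolution spans all electrodes, collapsing the channel axis.
    shapes.append((f"depthwise ({chans},1) x{d}", (1, samples, f1 * d)))
    t = samples // 4  # average pooling (1, 4) with 'valid' padding floors the length
    shapes.append(("avgpool (1,4)", (1, t, f1 * d)))
    # Block 2: separable convolution with 'same' padding, then pooling by 8.
    shapes.append((f"separable {f2}x(1,16)", (1, t, f2)))
    t //= 8
    shapes.append(("avgpool (1,8)", (1, t, f2)))
    shapes.append(("flatten", (f2 * t,)))
    shapes.append(("dense + softmax", (n_classes,)))
    return shapes


for name, shape in eegnet_shapes():
    print(f"{name:24s} {shape}")
```

Both pooling stages floor the length, so the flattened size is F2 * floor(floor(Samples / 4) / 8). With the assumed 1000 samples, that gives 16 * 31 = 496 features feeding the dense layer.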

Baseline

To compare the transfer learning model with a baseline EEGNet trained without transfer learning,
run the baseline_train.py script like this:

python trainer/baseline_train.py \
    --data_dir ../data \
    --ckpt_dir ../ckpt \
    --result_dir ../result
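A fair comparison needs the same per-subject metrics for the baseline and the fine-tuned model. The NumPy sketch below computes three of them: accuracy, a confusion matrix (the kind visualizer.py plots), and Cohen's kappa, which is commonly reported for BCI IV-2a. It is an assumption that this repository reports kappa. The function names are hypothetical.

```python
import numpy as np


def confusion_matrix(y_true, y_pred, n_classes):
    """Counts with true classes as rows and predicted classes as columns."""
    cm = np.zeros((n_classes, n_classes), dtype=int)
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm


def accuracy_and_kappa(y_true, y_pred, n_classes=4):
    cm = confusion_matrix(y_true, y_pred, n_classes)
    n = cm.sum()
    p_observed = np.trace(cm) / n  # plain accuracy
    # Agreement expected by chance, from the true (row) and predicted (column) marginals.
    p_chance = np.dot(cm.sum(axis=1), cm.sum(axis=0)) / n**2
    kappa = (p_observed - p_chance) / (1 - p_chance)
    return float(p_observed), float(kappa), cm


# Toy labels for four classes (e.g. left hand, right hand, feet, tongue as 0..3).
y_true = [0, 1, 2, 3, 0, 1, 2, 3]
y_pred = [0, 1, 2, 3, 0, 1, 3, 2]
acc, kappa, cm = accuracy_and_kappa(y_true, y_pred)
print(f"accuracy={acc:.3f} kappa={kappa:.3f}")  # accuracy=0.750 kappa=0.667
print(cm)
```

With four balanced classes, chance accuracy is 0.25. Kappa rescales accuracy so that chance maps to 0 and perfect agreement maps to 1.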

Pre-training

To pre-train EEGNet on the source data, run the pretraining_train.py script like this:

python trainer/pretraining_train.py \
    --data_dir ../data \
    --ckpt_dir ../ckpt

Fine-tuning

To fine-tune the pre-trained EEGNet on the target data, run the finetuning_train.py script like this:

python trainer/finetuning_train.py \
    --data_dir ../data \
    --ckpt_dir ../ckpt \
    --result_dir ../result
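Pre-training followed by fine-tuning amounts to parameter transfer. A model is trained on the pooled source subjects, and the target model starts from those weights and keeps training on the new subject's few labelled trials. The sketch below illustrates only that workflow, under loud assumptions. EEGNet is swapped for a linear softmax classifier, and the EEG trials are replaced by synthetic 8-dimensional features with a per-subject shift. The helpers (softmax_gd, make_subject) and the fine-tuning schedule (lr=0.05, 50 epochs, 5 calibration trials per class) are hypothetical and not taken from finetuning_train.py.

```python
import numpy as np


def softmax_gd(x, y, n_classes, w_init=None, lr=0.1, epochs=200):
    """Full-batch gradient descent on a linear softmax classifier (cross-entropy loss).

    w_init=None trains from zeros; passing pre-trained weights fine-tunes a copy of them.
    """
    n, d = x.shape
    w = np.zeros((d, n_classes)) if w_init is None else w_init.copy()
    onehot = np.eye(n_classes)[y]
    for _ in range(epochs):
        logits = x @ w
        logits -= logits.max(axis=1, keepdims=True)  # numerical stability
        p = np.exp(logits)
        p /= p.sum(axis=1, keepdims=True)
        w -= lr * x.T @ (p - onehot) / n  # gradient of the mean cross-entropy
    return w


def accuracy(w, x, y):
    return float(np.mean(np.argmax(x @ w, axis=1) == y))


def make_subject(rng, n_per_class, n_classes=4, dim=8, sep=3.0, shift_sd=0.5):
    """Synthetic stand-in for one subject: class means sep*e_k plus a subject-specific shift."""
    shift = rng.normal(0.0, shift_sd, size=dim)
    means = sep * np.eye(n_classes, dim)
    x = np.concatenate([rng.normal(means[k] + shift, 1.0, size=(n_per_class, dim))
                        for k in range(n_classes)])
    x = np.hstack([x, np.ones((len(x), 1))])  # constant column acts as the bias term
    y = np.repeat(np.arange(n_classes), n_per_class)
    return x, y


rng = np.random.default_rng(42)
# Source domain: eight other subjects with plenty of trials each.
source = [make_subject(rng, n_per_class=50) for _ in range(8)]
x_src = np.concatenate([x for x, _ in source])
y_src = np.concatenate([y for _, y in source])
# Target domain: a new subject; only the first 5 trials of each class are labelled.
x_tgt, y_tgt = make_subject(rng, n_per_class=55)
calib = np.concatenate([np.arange(k * 55, k * 55 + 5) for k in range(4)])
test = np.setdiff1d(np.arange(len(y_tgt)), calib)

w_pre = softmax_gd(x_src, y_src, 4)  # pre-training on source data
w_ft = softmax_gd(x_tgt[calib], y_tgt[calib], 4, w_init=w_pre, lr=0.05, epochs=50)  # fine-tuning
w_base = softmax_gd(x_tgt[calib], y_tgt[calib], 4)  # baseline: target calibration data only
for name, w in [("baseline", w_base), ("pre-trained", w_pre), ("fine-tuned", w_ft)]:
    print(f"{name:12s} target accuracy: {accuracy(w, x_tgt[test], y_tgt[test]):.2f}")
```

On this toy data, whether fine-tuning beats the baseline depends on the seed. The point is the order of the three steps, which mirrors the baseline, pre-training and fine-tuning scripts.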

Results