This work is published in Nature Communications (https://doi.org/10.1038/s41467-022-31037-5).
This repository contains the implementation of a deep learning framework that performs two diagnostic steps to identify persons with normal cognition (NC), mild cognitive impairment (MCI), Alzheimer’s disease (AD) dementia, and dementia due to other etiologies (nADD).
We demonstrated that the framework compares favorably with the diagnostic performance of neurologists and neuroradiologists. To interpret the model, we conducted SHAP (SHapley Additive exPlanations) analysis on brain MRI and other features to reveal disease-specific patterns that correspond with expert-driven ratings and neuropathological findings.
The tool was developed using the following dependencies:
Please note that the dependencies may require Python 3.6 or greater. We recommend installing and maintaining all packages with conda or pip. Installing GPU-accelerated PyTorch may require additional steps; please check the official PyTorch and CUDA websites for detailed instructions.
We recommend cloning only the latest version to avoid downloading the full commit history from the development stage:
git clone --depth 1 https://github.com/vkola-lab/ncomms2022.git
model_wrappers.py contains the interfaces for initializing, training, testing, saving, and loading the model, as well as for creating SHAP-interpretable heatmaps. See below for basic example usage.
from model_wrappers import Multask_Wrapper
from utils import read_json
model = Multask_Wrapper(
    tasks=['ADD', 'COG'],                       # a list of tasks to predict
    device=1,                                   # GPU device to use
    main_config=read_json('config.json'),       # general configuration for the experiment
    task_config=read_json('task_config.json'),  # task-specific configurations
    seed=1000                                   # random seed for reproducibility
)
model.train()
thres = model.get_optimal_thres() # get optimal threshold using validation dataset
model.gen_score(['test'], thres) # apply optimal threshold on test dataset and cache predictions
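After training, the same wrapper can be reused to score external cohorts with the tuned threshold. A minimal sketch, assuming that stage names map to the cohort CSVs under lookupcsv/CrossValid/ (e.g., OASIS.csv and exter_test.csv):

model.gen_score(['OASIS', 'exter_test'], thres)  # assumed stage names; cache predictions for external cohorts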
The interface for training a fusion model or a non-imaging model is similar to that of the CNN model. See below for basic example usage.
from nonImg_model_wrappers import NonImg_Model_Wrapper, Fusion_Model_Wrapper
from utils import read_json
model = NonImg_Model_Wrapper(
    tasks=['ADD', 'COG'],                       # a list of tasks to predict
    main_config=read_json('config.json'),       # general configuration for the experiment
    task_config=read_json('task_config.json'),  # task-specific configurations
    seed=1000
)
model.train()
thres = model.get_optimal_thres() # get optimal threshold using validation dataset
model.gen_score(['test'], thres) # apply optimal threshold on test dataset and cache predictions
model = Fusion_Model_Wrapper(
    tasks=['ADD', 'COG'],                       # a list of tasks to predict
    main_config=read_json('config.json'),       # general configuration for the experiment
    task_config=read_json('task_config.json'),  # task-specific configurations
    seed=1000
)
model.train()
thres = model.get_optimal_thres() # get optimal threshold using validation dataset
model.gen_score(['test'], thres) # apply optimal threshold on test dataset and cache predictions
Note: since the gen_score method has already saved the raw predictions as CSV files, the evaluation pipeline only needs to read that information from the corresponding experiment folders under tb_log. The mean, standard deviation, or 95% confidence intervals are estimated from multiple independent experiments, for instance, from five-fold cross-validation.
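As a minimal sketch of how such fold-level statistics can be aggregated (the values below are placeholders, not results from the paper):

import numpy as np

fold_aucs = np.array([0.91, 0.93, 0.90, 0.92, 0.94])  # placeholder per-fold metric values
mean = fold_aucs.mean()
std = fold_aucs.std(ddof=1)                           # sample standard deviation across folds
ci95 = 1.96 * std / np.sqrt(len(fold_aucs))           # normal-approximation 95% confidence interval
print(f'{mean:.3f} +/- {std:.3f} (95% CI: {mean - ci95:.3f} to {mean + ci95:.3f})')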
from performance_eval import generate_roc, generate_pr
generate_roc(
    csv_files,       # list of CSV files produced by "gen_score"; mean and std are estimated across cross-validation experiments
    positive_label,  # label of the class treated as positive
    color,           # color used to draw the curve
    out_file         # path of the output figure
)
generate_pr(
    csv_files,       # list of CSV files produced by "gen_score"; mean and std are estimated across cross-validation experiments
    positive_label,  # label of the class treated as positive
    color,           # color used to draw the curve
    out_file         # path of the output figure
)
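For example, the per-fold prediction files can be passed in directly. The file paths below are hypothetical; use the paths that gen_score actually created under tb_log in your setup:

from performance_eval import generate_roc, generate_pr

csv_files = [f'tb_log/CNN_cross{i}/test_eval.csv' for i in range(5)]  # hypothetical naming scheme
generate_roc(csv_files, 'ADD', 'tab:blue', 'roc_ADD.png')
generate_pr(csv_files, 'ADD', 'tab:blue', 'pr_ADD.png')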
The performance table contains accuracy, sensitivity, specificity, F-1 score, and MCC for the different tasks.
from performance_eval import perform_table
perform_table(
    csv_files,   # list of CSV files produced by "gen_score"; mean and std are estimated across cross-validation experiments
    output_name  # any name for the output CSV file that will contain the metric information
)
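For reference, these metrics follow the standard confusion-matrix definitions, where TP, TN, FP, and FN are the counts for the class treated as positive:

$$\text{accuracy} = \frac{TP + TN}{TP + TN + FP + FN}, \qquad \text{sensitivity} = \frac{TP}{TP + FN}, \qquad \text{specificity} = \frac{TN}{TN + FP}$$

$$F_1 = \frac{2\,TP}{2\,TP + FP + FN}, \qquad \text{MCC} = \frac{TP \cdot TN - FP \cdot FN}{\sqrt{(TP + FP)(TP + FN)(TN + FP)(TN + FN)}}$$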
from performance_eval import crossValid_cm
crossValid_cm(
    csv_files,  # list of CSV files produced by "gen_score"; mean and std are estimated across cross-validation experiments
    stage       # if stage='test', the confusion matrix for the test dataset will be generated
)
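If you need the raw confusion matrix outside of this helper, it can also be recomputed from any cached prediction file. A minimal sketch, assuming hypothetical path and column names for the true and predicted labels:

import pandas as pd
from sklearn.metrics import confusion_matrix

df = pd.read_csv('tb_log/CNN_cross0/test_eval.csv')      # hypothetical path cached by gen_score
print(confusion_matrix(df['COG_true'], df['COG_pred']))  # hypothetical column names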
The full evaluation package compiles the ROC/PR curves, the performance table, and the confusion matrix all together.
from performance_eval import whole_eval_package
whole_eval_package(model_name, 'test') # evaluate on NACC testing set
whole_eval_package(model_name, 'OASIS') # evaluate on OASIS dataset
The shap_mid method first loads the pretrained weights and then generates the SHAP-interpretable saliency map for a specific intermediate layer over all instances.
model = Multask_Wrapper(                        # instantiate an already-trained model
    tasks=['ADD', 'COG'],
    device=1,
    main_config=read_json('config.json'),
    task_config=read_json('task_config.json'),
    seed=1000
)
model.shap_mid(
    task_idx=0,         # if task_idx == 0, the SHAP analysis targets the ADD task (tasks[task_idx])
    path='somewhere/',  # where to save the generated SHAP numpy arrays
    file='test.csv',    # SHAP values will be generated for each case in this file
    layer='block2conv'  # which layer of the model to interpret
)
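The arrays saved by shap_mid can be inspected with standard tools. A minimal sketch for viewing one slice, assuming a hypothetical filename and a 3D saliency volume per case:

import numpy as np
import matplotlib.pyplot as plt

heat = np.load('somewhere/case_0001.npy')  # hypothetical filename under the `path` given above
mid = heat.shape[-1] // 2                  # middle slice along the last axis
plt.imshow(heat[..., mid], cmap='bwr')     # red vs. blue = positive vs. negative SHAP contribution
plt.colorbar()
plt.savefig('shap_slice.png')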
For more details, please see the SHAP documentation.
The shap method will initialize the corresponding SHAP explainer for various models, including XGBoost, CatBoost, Random Forest, Decision Tree, Support Vector Machine, Nearest Neighbors, and Multi-layer Perceptron. See below for an example.
model = NonImg_Model_Wrapper(
    tasks=['ADD', 'COG'],                       # a list of tasks to predict
    main_config=read_json('config.json'),       # general configuration for the experiment
    task_config=read_json('task_config.json'),  # task-specific configurations
    seed=1000
)
model.train()
thres = model.get_optimal_thres() # get optimal threshold using validation dataset
model.gen_score(['test'], thres) # apply optimal threshold on test dataset and cache predictions
shap_values, _ = model.shap("test_shap") # get shap values for all features over instances from test dataset
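The returned SHAP values can then be summarized into a global feature ranking. A minimal sketch with stand-in data (the real feature names come from the non-imaging table described in the paper):

import numpy as np

feature_names = ['age', 'mmse', 'moca', 'gds']        # hypothetical subset of the non-imaging features
shap_vals = np.random.randn(100, len(feature_names))  # stand-in for the array returned by model.shap
importance = np.abs(shap_vals).mean(axis=0)           # mean |SHAP| per feature = global importance
for name, score in sorted(zip(feature_names, importance), key=lambda t: -t[1]):
    print(f'{name}: {score:.3f}')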
The scripts used for plotting can be found in the FigureTable/ folder.
To comply with the data distribution policies of the different study centers, we provide guidance on accessing and processing the meta information instead of sharing the data within this repo. The metadata contain demographic information, medical history, neuropsychological tests, and functional questionnaires. Please refer to our paper for a complete list of the included features.
We collected and organized metadata from 8 cohorts using the folder structure below:
lookupcsv
│
├── raw_tables # inside raw_tables, you should save the directly-downloaded tables.
│ ├── NACC_ALL # within each folder, there is a readme file to guide the user to access and dowload data
│ │ ├── readme.txt
│ ├── ADNI
│ ├── OASIS
│ ├── AIBL
│ ├── FHS
│ └── ...
│
├── derived_tables # inside derived_tables, you should save the intermediate tables produced by our scripts.
│ ├── NACC_ALL # within each folder, there is a readme file to guide the user through running the processing scripts that we provide
│ │ ├── readme.txt
│ ├── ADNI
│ ├── OASIS
│ ├── AIBL
│ ├── FHS
│ └── ...
│
├── dataset_table # this is where the final meta table is saved
│ ├── NACC_ALL
│ ├── ADNI
│ ├── OASIS
│ ├── AIBL
│ ├── FHS
│ └── ...
│
├── CrossValid # concatenate meta tables from dataset_table/ and then split NACC into train, valid, test
│ ├── cross0 # each cross folder contains a different split; see our paper for more details on how the splits were done
│ │ ├── train.csv
│ │ ├── valid.csv
│ │ ├── test.csv
│ │ ├── OASIS.csv
│ │ ├── exter_test.csv
│ ├── cross1
│ ├── cross2
│ ├── cross3
│ └── cross4
└── ...
To prepare the metadata: (1) download data from the official data portals using the guidance in each readme and save it in raw_tables; (2) use the scripts provided in derived_tables to produce the intermediate tables; (3) use the scripts provided in dataset_table to produce the final meta table from the information in both raw_tables and derived_tables; (4) concatenate and split the data for cross-validation, as sketched below.
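A minimal sketch of step (4) with pandas and scikit-learn; the filename and the split procedure here are illustrative, and the actual five-fold split is described in the paper:

import pandas as pd
from sklearn.model_selection import train_test_split

nacc = pd.read_csv('lookupcsv/dataset_table/NACC_ALL/NACC.csv')  # hypothetical filename
train, rest = train_test_split(nacc, test_size=0.4, random_state=0)
valid, test = train_test_split(rest, test_size=0.5, random_state=0)
train.to_csv('lookupcsv/CrossValid/cross0/train.csv', index=False)
valid.to_csv('lookupcsv/CrossValid/cross0/valid.csv', index=False)
test.to_csv('lookupcsv/CrossValid/cross0/test.csv', index=False)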
The pipeline for MRI processing is available in MRI_process/pipeline.sh. Four de-identified and processed sample MRI scans are provided in demo/mri/.
We also provide a demo script (demo_inference.py) that demonstrates how to run inference on other data instances using pretrained CNN weights.
python demo_inference.py
Running the command above will produce a CSV table under the demo/ folder containing the model's predictions for the 4 MRI scans in demo/mri/.
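To inspect the result programmatically (the output filename below is illustrative; check the demo/ folder for the file the script actually writes):

import pandas as pd

print(pd.read_csv('demo/predictions.csv'))  # hypothetical filename written by demo_inference.py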