This is the original code implementation of the AXIAL framework proposed in the manuscript "AXIAL: Attention-based eXplainability for Interpretable Alzheimer's Localized Diagnosis using 2D CNNs on 3D MRI brain scans" by Gabriele Lozupone et al. [Paper]
This study presents an innovative method for Alzheimer's disease diagnosis using 3D MRI, designed to enhance the explainability of model decisions. Our approach adopts a soft attention mechanism, enabling 2D CNNs to extract volumetric representations. At the same time, the importance of each slice in decision-making is learned, allowing the generation of a voxel-level attention map that produces an explainable MRI. To test our method and ensure the reproducibility of our results, we chose a standardized collection of MRI data from the Alzheimer's Disease Neuroimaging Initiative (ADNI). On this dataset, our method significantly outperforms state-of-the-art methods in (i) distinguishing AD from cognitively normal (CN) subjects, with an accuracy of 0.856 and a Matthews correlation coefficient (MCC) of 0.712, improvements of 2.4% and 5.3% respectively over the second-best method, and (ii) the prognostic task of discerning stable from progressive mild cognitive impairment (MCI), with an accuracy of 0.725 and an MCC of 0.443, improvements of 10.2% and 20.5% respectively over the second-best method. We achieved this prognostic result by adopting a double transfer learning strategy, which enhanced sensitivity to morphological changes and facilitated early-stage AD detection. With voxel-level precision, our method identified the specific areas the model attends to, with the following brain regions predominating: the hippocampus, the amygdala, the parahippocampal gyrus, and the inferior lateral ventricles. All these areas are clinically associated with AD development. Furthermore, our approach consistently found the same AD-related areas across different cross-validation folds, proving its robustness and precision in highlighting areas that align closely with known pathological markers of the disease.
This repository contains code for preprocessing structural Magnetic Resonance Imaging (sMRI) data from the Alzheimer's Disease Neuroimaging Initiative (ADNI) dataset. The code converts the ADNI dataset into the Brain Imaging Data Structure (BIDS) format and applies preprocessing algorithms, including N4 bias field correction, MNI152 registration, and brain extraction. The data preparation is performed using the Clinica software.
Additionally, this repository provides the necessary code to train, validate, and test various deep learning models using PyTorch. Furthermore, the repository includes two explainability approaches: one based on GradCAM and the other based on the attention mechanism.
3D Models: These approaches utilize a 2D backbone to extract feature maps from slices, together with attention mechanisms that capture slice-level features and their spatial relationships (a minimal sketch follows below).
2D Models: This approach directly classifies each slice by attaching a classifier to the backbone. The final label for the entire image is determined using a majority voting approach based on the slice predictions.
Explainability: The first proposed explainability approach generates attention activation maps at the voxel level, highlighting the brain regions that are most important for the model's decision-making process. The second approach utilizes GradCAM to generate 3D heatmaps. The 3D maps are then used to produce XAI metrics that help identify the most important brain regions for the model's decision-making process.
The repository aims to provide a comprehensive framework for sMRI preprocessing and deep learning analysis, enabling researchers to efficiently analyze ADNI data and develop advanced models for Alzheimer's disease detection and classification.
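To make the slice-attention idea above concrete, the following is a minimal sketch of how a 2D backbone plus soft attention can aggregate per-slice features into a volumetric representation. It is illustrative only, not the repository's Axial3D implementation: the class name, attention head, and dimensions are assumptions.

```python
import torch
import torch.nn as nn
from torchvision import models

class SliceAttentionNet(nn.Module):
    """Sketch: a 2D backbone embeds each slice, a soft attention head learns
    one weight per slice, and the weighted sum forms a volumetric embedding."""

    def __init__(self, embed_dim=512, num_classes=2):
        super().__init__()
        self.backbone = models.vgg16(weights=None).features  # per-slice feature maps
        self.pool = nn.AdaptiveAvgPool2d(1)                  # -> (B*S, 512, 1, 1)
        self.attention = nn.Sequential(                      # scores each slice
            nn.Linear(embed_dim, 128), nn.Tanh(), nn.Linear(128, 1)
        )
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x):                        # x: (B, S, 1, H, W) slice stack
        B, S = x.shape[:2]
        x = x.view(B * S, *x.shape[2:])          # flatten slices into the batch
        x = x.repeat(1, 3, 1, 1)                 # grayscale -> 3 channels for VGG
        feats = self.pool(self.backbone(x)).flatten(1).view(B, S, -1)
        alpha = torch.softmax(self.attention(feats), dim=1)  # (B, S, 1) weights
        volume = (alpha * feats).sum(dim=1)      # attention-weighted aggregation
        return self.classifier(volume), alpha.squeeze(-1)
```

For example, `logits, alpha = SliceAttentionNet()(torch.randn(2, 64, 1, 224, 224))` returns class logits and the learned per-slice attention weights.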
To use the code in this repository, follow these steps:
git clone https://github.com/GabrieleLozupone/AXIAL.git
cd AXIAL
pip install -r requirements.txt
Before performing preprocessing on the ADNI dataset, follow the steps below to prepare the necessary data.
Register for access on the ADNI website at https://ida.loni.usc.edu/login.jsp.
Download the desired ADNI image collection. In the case of this work, the image collection name is "ADNI1 Complete 1Yr 1.5T".
On the ADNI website, click on "Download" and then select "Study Data".
Choose "ALL" to download all available data.
In the "Tabular Data (CSV format)" section, select all the files and download them.
Some CSV files in the clinical data may have a date at the end of their name. Remove the date from the file names to ensure compatibility with the preprocessing pipeline.
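A small helper like the following can automate the renaming; it assumes the date suffix has the common ADNI form, e.g. `_23Aug2024`, just before the `.csv` extension.

```python
import re
from pathlib import Path

clinical_dir = Path("/path/to/clinical_data")  # adjust to your download folder
# Study files are often named like "DXSUM_PDXCONV_ADNIALL_23Aug2024.csv";
# strip the trailing "_<date>" so the converter finds the expected names.
pattern = re.compile(r"_\d{1,2}[A-Za-z]{3}\d{4}(?=\.csv$)")
for f in clinical_dir.glob("*.csv"):
    new_name = pattern.sub("", f.name)
    if new_name != f.name:
        f.rename(f.with_name(new_name))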
Install the Clinica software by following the instructions provided at https://aramislab.paris.inria.fr/clinica/docs/public/latest/Converters/ADNI2BIDS/. Clinica is a powerful tool that facilitates the conversion of ADNI data to the BIDS structure.
To convert the ADNI data to the BIDS structure, use the following command:
clinica convert adni-to-bids -m T1 DATASET_DIRECTORY CLINICAL_DATA_DIRECTORY BIDS_DIRECTORY
Replace `DATASET_DIRECTORY` with the path to the downloaded ADNI dataset, `CLINICAL_DATA_DIRECTORY` with the path to the downloaded clinical data, and `BIDS_DIRECTORY` with the desired output path for the BIDS-formatted dataset. The `-m T1` option specifies that only T1-weighted MRI data should be converted.
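For example, assuming everything was downloaded under the home directory (the paths below are illustrative, not required locations):

clinica convert adni-to-bids -m T1 ~/ADNI ~/ADNI_clinical_data ~/ADNI_BIDS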
This section describes the preprocessing steps for the sMRI data.
To run the preprocessing pipeline on the sMRI data, execute the following command:
python data_preprocessing.py --bids_path /path/to/bids-dataset --n_proc 10 --checkpoint checkpoint.txt
Replace `/path/to/bids-dataset` with the path to the BIDS-formatted dataset obtained from the data preparation steps. The preprocessing pipeline performs MNI152 registration, brain extraction, and bias field correction with the N4 algorithm; the preprocessed images are stored in the same path as the original images. The `--n_proc` argument sets the number of processes used by the pipeline. The `--checkpoint` argument gives the path to a checkpoint file that keeps track of the images that have already been preprocessed, so this time-consuming step can be interrupted and resumed later.
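The checkpoint logic amounts to an append-only list of finished images; here is a minimal sketch of the idea (the `preprocess` function and file names are placeholders, not the repository's API):

```python
from pathlib import Path

def preprocess(image_path: str) -> None:
    """Placeholder for the actual pipeline (N4, MNI152 registration, ...)."""
    ...

images_to_process = ["sub-01_T1w.nii.gz", "sub-02_T1w.nii.gz"]  # illustrative

checkpoint = Path("checkpoint.txt")
done = set(checkpoint.read_text().splitlines()) if checkpoint.exists() else set()

for image in images_to_process:
    if image in done:
        continue                            # already handled on a previous run
    preprocess(image)
    with checkpoint.open("a") as f:         # record success so a restart skips it
        f.write(image + "\n")
```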
The script creates a `dataset.csv` file in the BIDS directory containing the paths to the preprocessed images and their corresponding labels; the deep learning models use this file to load the data. It also creates a `dataset_conversion{num_months}.csv` file with the paths to the preprocessed images and the corresponding labels for the progression task (sMCI vs. pMCI).
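As an illustration, a PyTorch dataset could consume the generated CSV roughly as follows; the column names `path` and `label` are assumptions about the file layout, so check the generated header before reusing this.

```python
import pandas as pd
import nibabel as nib
import torch
from torch.utils.data import Dataset

class MRIDataset(Dataset):
    """Sketch of loading the generated CSV of preprocessed images."""

    def __init__(self, csv_file):
        self.df = pd.read_csv(csv_file)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        volume = nib.load(row["path"]).get_fdata()   # (X, Y, Z) voxel array
        return torch.from_numpy(volume).float(), int(row["label"])
```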
To train a model, run the `train.py` script. The configuration is specified in a YAML file, `config.yaml`.
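A minimal illustration of what such a configuration might contain; every key below is an assumption for illustration, so consult the `config.yaml` shipped with the repository for the actual fields:

```yaml
# All keys are illustrative, not the repository's actual schema.
dataset_csv: /path/to/bids-dataset/dataset.csv
model: Axial3D          # e.g. Axial3D, TransformerConv3D, ...
backbone: vgg16
task: AD_vs_CN          # or sMCI_vs_pMCI
num_folds: 5
batch_size: 8
learning_rate: 1.0e-4
epochs: 100
```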
python train.py
The script reports results averaged over 5-fold cross-validation. The tables below show the results obtained on ADNI1 Complete 1Yr 1.5T by the proposed model and by the other models that can be tested with the framework.
AD vs. CN:

| Networks | ACC | SPE | SEN | MCC |
|---|---|---|---|---|
| Majority Voting (2D)(VGG16) | 0.804 | 0.897 | 0.688 | 0.605 |
| Attention Transformer (TransformerConv3D)(VGG16) | 0.826 | 0.914 | 0.717 | 0.651 |
| AwareNet Diagnosis (AwareNet)(3D) | 0.832 | 0.875 | 0.778 | 0.659 |
| Ours (Axial3D)(VGG16) | 0.856 | 0.910 | 0.792 | 0.712 |
| Attention-Guided Majority Voting (2D)(VGG16) | 0.843 | 0.894 | 0.780 | 0.683 |
| Majority Voting 3D (3D)(VGG16) | 0.836 | 0.867 | 0.797 | 0.667 |

sMCI vs. pMCI:

| Networks | ACC | SPE | SEN | MCC |
|---|---|---|---|---|
| Majority Voting (2D)(VGG16) | 0.614 | 0.601 | 0.629 | 0.229 |
| Attention Transformer (TransformerConv3D)(VGG16) | 0.623 | 0.665 | 0.4873 | 0.238 |
| AwareNet Diagnosis (AwareNet)(3D) | 0.4841 | 0.774 | 0.258 | 0.039 |
| Ours (Axial3D)(VGG16) | 0.725 | 0.763 | 0.678 | 0.443 |
| Attention-Guided Majority Voting (2D)(VGG16) | 0.633 | 0.624 | 0.643 | 0.266 |
| Majority Voting 3D (3D)(VGG16) | 0.629 | 0.653 | 0.601 | 0.254 |
Note: The data splitting is performed by subject to avoid data leakage problems (a sketch of subject-wise splitting follows below). Make sure to adjust the configuration files according to your specific paths and requirements.
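Subject-wise splitting can be reproduced with scikit-learn's `GroupKFold`, as in this sketch (the `subject_id` column name is an assumption about `dataset.csv`):

```python
import pandas as pd
from sklearn.model_selection import GroupKFold

df = pd.read_csv("dataset.csv")          # columns assumed: path, label, subject_id
gkf = GroupKFold(n_splits=5)
# Grouping by subject guarantees that all scans of one subject fall in the
# same fold, so no subject appears in both train and test sets (no leakage).
for fold, (train_idx, test_idx) in enumerate(
        gkf.split(df, df["label"], groups=df["subject_id"])):
    train_df, test_df = df.iloc[train_idx], df.iloc[test_idx]
    print(f"fold {fold}: {len(train_df)} train / {len(test_df)} test scans")
```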
The `attention_xai_analysis.py` script generates distributions of attention weights for the three planes (axial, coronal, sagittal), starting from three models that each output a slice-level attention distribution, and aggregates results over the 5 cross-validation folds to validate consistency. It also produces an explainable MRI overlaid on the template MRI; this explainable image is used to produce XAI metrics that help identify which regions are more important for discerning an AD patient from a healthy one.
To run the explainability analysis, execute the following command:
python attention_xai_analysis.py
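For intuition, the following sketch lifts three 1D slice-attention distributions to a 3D voxel map via an outer product; the actual fusion rule used in the paper may differ, and the vectors here are random stand-ins for real model outputs.

```python
import numpy as np

# Slice-attention vectors from the three single-plane models (illustrative
# shapes for a 182x218x182 MNI152 volume).
a_sag = np.random.rand(182); a_sag /= a_sag.sum()   # sagittal slice weights
a_cor = np.random.rand(218); a_cor /= a_cor.sum()   # coronal slice weights
a_ax  = np.random.rand(182); a_ax  /= a_ax.sum()    # axial slice weights

# One illustrative fusion rule: the voxel attention at (x, y, z) is the
# product of the three plane weights (an outer product). This only shows
# the idea of lifting 1D slice attention to a 3D voxel map.
voxel_map = np.einsum("x,y,z->xyz", a_sag, a_cor, a_ax)
voxel_map /= voxel_map.max()                         # normalise to [0, 1]
print(voxel_map.shape)                               # (182, 218, 182)
```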
Results:
Brain Region | 𝑉𝑟 | 𝜇𝑟 | 𝜎𝑟 | 𝐴𝑚𝑎𝑥,𝑟 | 𝐴𝑚𝑖𝑛,𝑟 | 𝑃𝑟 |
---|---|---|---|---|---|---|
Hippocampus left | 1562 | 0.136 | 0.139 | 0.762 | 0.028 | 0.333 |
Hippocampus right | 1426 | 0.126 | 0.133 | 0.783 | 0.028 | 0.304 |
Parahippocampal left | 688 | 0.129 | 0.137 | 0.884 | 0.028 | 0.254 |
Parahippocampal right | 534 | 0.129 | 0.148 | 1.000 | 0.028 | 0.197 |
Amygdala left | 480 | 0.097 | 0.092 | 0.620 | 0.028 | 0.291 |
Amygdala right | 427 | 0.095 | 0.087 | 0.569 | 0.028 | 0.259 |
Inferior Lateral Ventricle right | 232 | 0.113 | 0.129 | 0.677 | 0.028 | 0.219 |
Inferior Lateral Ventricle left | 212 | 0.106 | 0.105 | 0.589 | 0.028 | 0.200 |
Cerebellum Gray Matter left | 208 | 0.035 | 0.005 | 0.052 | 0.028 | 0.003 |
Lateral Orbitofrontal left | 194 | 0.033 | 0.004 | 0.045 | 0.028 | 0.013 |
Fusiform right | 184 | 0.045 | 0.015 | 0.107 | 0.028 | 0.014 |
Lateral Orbitofrontal right | 140 | 0.034 | 0.004 | 0.046 | 0.028 | 0.009 |
Cerebellum Gray Matter right | 119 | 0.034 | 0.005 | 0.054 | 0.028 | 0.002 |
Fusiform left | 88 | 0.040 | 0.010 | 0.070 | 0.028 | 0.007 |
Entorhinal left | 16 | 0.034 | 0.003 | 0.041 | 0.030 | 0.005 |
Ventral Diencephalon left | 6 | 0.029 | 0.001 | 0.031 | 0.028 | 0.001 |
Entorhinal right | 2 | 0.033 | 0.003 | 0.036 | 0.031 | 0.001 |
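Given a fused 3D attention map and a brain atlas, per-region statistics like those in the table can be computed along these lines; the thresholding rule and exact metric definitions here are assumptions, not the paper's formulas.

```python
import numpy as np

def region_stats(attention: np.ndarray, atlas: np.ndarray, region_id: int,
                 threshold: float = 0.028):
    """Per-region attention statistics (illustrative definitions)."""
    vals = attention[atlas == region_id]       # attention values inside the region
    salient = vals[vals >= threshold]          # keep non-negligible voxels only
    return {
        "V_r": salient.size,                       # salient voxel count
        "mu_r": float(salient.mean()),             # mean attention
        "sigma_r": float(salient.std()),           # attention spread
        "A_max_r": float(salient.max()),
        "A_min_r": float(salient.min()),
        "P_r": salient.size / max(vals.size, 1),   # fraction of region highlighted
    }
```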
The `cam_xai_analysis.py` script performs similar tasks using CAM maps produced with GradCAM-like methods.
To run the explainability analysis, execute the following command:
python cam_xai_analysis.py
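For reference, here is a bare-bones GradCAM sketch (the repository uses GradCAM-like variants such as GradCAM++, whose channel weighting differs, but the hook-based recipe is the same):

```python
import torch.nn.functional as F

def grad_cam(model, target_layer, image, class_idx):
    """Minimal GradCAM sketch: weight activations by pooled gradients."""
    acts, grads = {}, {}
    h1 = target_layer.register_forward_hook(
        lambda m, i, o: acts.update(a=o))
    h2 = target_layer.register_full_backward_hook(
        lambda m, gi, go: grads.update(g=go[0]))
    score = model(image)[0, class_idx]              # image: (1, 3, H, W)
    model.zero_grad(); score.backward()
    h1.remove(); h2.remove()
    w = grads["g"].mean(dim=(2, 3), keepdim=True)   # GAP of gradients per channel
    cam = F.relu((w * acts["a"]).sum(dim=1))        # weighted sum + ReLU
    return cam / cam.max()                          # normalised 2D heatmap
```

For a VGG16 backbone, `target_layer` would typically be its last convolutional layer; slice-wise heatmaps can then be stacked into the 3D maps described above.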
Results:
Figures: visualizations of the mean 3D GradCAM++ map over the entire dataset, overlaid on the MNI152 template, for Axial3D (VGG16) and for TransformerConv3D (VGG16).
This project is licensed; please review the LICENSE file for more information.
If you find this work useful for your research, please 🌟 our project and cite our paper:
@article{lozupone2024axial,
title={AXIAL: Attention-based eXplainability for Interpretable Alzheimer's Localized Diagnosis using 2D CNNs on 3D MRI brain scans},
author={Lozupone, Gabriele and Bria, Alessandro and Fontanella, Francesco and De Stefano, Claudio},
journal={arXiv preprint arXiv:2407.02418},
year={2024}
}