AXIAL

This is the original code implementation of the AXIAL framework proposed in the manuscript "AXIAL: Attention-based eXplainability for Interpretable Alzheimer's Localized Diagnosis using 2D CNNs on 3D MRI brain scans" by Gabriele Lozupone. [Paper]

Abstract

This study presents an innovative method for Alzheimer's disease diagnosis using 3D MRI designed to enhance the explainability of model decisions. Our approach adopts a soft attention mechanism, enabling 2D CNNs to extract volumetric representations. At the same time, the importance of each slice in decision-making is learned, allowing the generation of a voxel-level attention map that produces an explainable MRI. To test our method and ensure the reproducibility of our results, we chose a standardized collection of MRI data from the Alzheimer's Disease Neuroimaging Initiative (ADNI). On this dataset, our method significantly outperforms state-of-the-art methods in (i) distinguishing AD from cognitively normal (CN) subjects, with an accuracy of 0.856 and a Matthews correlation coefficient (MCC) of 0.712, representing improvements of 2.4% and 5.3% respectively over the second-best method, and (ii) the prognostic task of discerning stable from progressive mild cognitive impairment (MCI), with an accuracy of 0.725 and an MCC of 0.443, showing improvements of 10.2% and 20.5% respectively over the second-best method. We achieved this prognostic result by adopting a double transfer learning strategy, which enhanced sensitivity to morphological changes and facilitated early-stage AD detection. With voxel-level precision, our method identified the specific areas the model attends to, with the following brain regions predominating: the hippocampus, the amygdala, the parahippocampal gyrus, and the inferior lateral ventricles. All these areas are clinically associated with AD development. Furthermore, our approach consistently found the same AD-related areas across different cross-validation folds, proving its robustness and precision in highlighting areas that align closely with known pathological markers of the disease.

Framework overview

This repository contains code for preprocessing structural Magnetic Resonance Imaging (sMRI) data from the Alzheimer's Disease Neuroimaging Initiative (ADNI) dataset. The code converts the ADNI dataset into the Brain Imaging Data Structure (BIDS) format and applies preprocessing algorithms, including N4 bias field correction, MNI152 registration, and brain extraction. The data preparation is performed using the Clinica software.

Additionally, this repository provides the necessary code to train, validate, and test various deep learning models using PyTorch. Furthermore, it includes two explainability approaches: one based on GradCAM and the other on the attention mechanism.

  1. 3D Models: These approaches use a 2D backbone to extract feature maps from the slices and an attention mechanism to capture slice-level features and their spatial relationships (a minimal sketch follows this list).

  2. 2D Models: This approach directly classifies each slice by attaching a classifier to the backbone. The final label for the entire image is determined using a majority voting approach based on the slice predictions.

  3. Explainability: The first proposed explainability approach generates attention activation maps at the voxel level, highlighting the brain regions that are most important for the model's decision-making process. The second approach uses GradCAM to generate 3D heatmaps. These 3D maps are then used to compute XAI metrics that help identify the most important brain regions for the model's decision.
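
For intuition, the following is a minimal, hypothetical PyTorch sketch of the attention-based slice aggregation in approach 1 (illustrative only, not the repository's actual Axial3D implementation): a 2D backbone embeds each slice, a softmax-normalized score weights the slices, and the weighted sum forms the volumetric representation.

```python
import torch
import torch.nn as nn

class SliceAttentionNet(nn.Module):
    """Toy soft-attention aggregation over slice features (illustrative only)."""
    def __init__(self, backbone, feat_dim, num_classes=2):
        super().__init__()
        self.backbone = backbone            # any 2D CNN mapping (N, C, H, W) -> (N, feat_dim)
        self.attn = nn.Linear(feat_dim, 1)  # one importance score per slice
        self.head = nn.Linear(feat_dim, num_classes)

    def forward(self, volume):              # volume: (B, S, C, H, W), S slices along one plane
        b, s = volume.shape[:2]
        feats = self.backbone(volume.flatten(0, 1)).view(b, s, -1)  # (B, S, feat_dim)
        w = torch.softmax(self.attn(feats).squeeze(-1), dim=1)      # (B, S) slice weights
        pooled = (w.unsqueeze(-1) * feats).sum(dim=1)               # volumetric representation
        return self.head(pooled), w         # class logits and learned slice-attention weights

# Toy usage with a tiny stand-in backbone:
backbone = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.AdaptiveAvgPool2d(1), nn.Flatten())
model = SliceAttentionNet(backbone, feat_dim=8)
logits, weights = model(torch.randn(2, 16, 1, 64, 64))  # 2 volumes of 16 slices each
```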

The repository aims to provide a comprehensive framework for sMRI preprocessing and deep learning analysis, enabling researchers to efficiently analyze ADNI data and develop advanced models for Alzheimer's disease detection and classification.

Table of Contents

  1. Installation
  2. Data Preparation
  3. Preprocessing
  4. Deep Learning approaches
  5. Acknowledgement
  6. License
  7. Citation

Installation

To use the code in this repository, follow these steps:

  1. Clone the repository:
git clone https://github.com/GabrieleLozupone/AXIAL.git
  2. Install the required dependencies:
pip install -r requirements.txt

Data Preparation

Before performing preprocessing on the ADNI dataset, follow the steps below to prepare the necessary data.

ADNI Data Download

  1. Register for access on the ADNI website at https://ida.loni.usc.edu/login.jsp.

  2. Download the desired ADNI image collection. In the case of this work, the image collection name is "ADNI1 Complete 1Yr 1.5T".

Clinical Data Download

  1. On the ADNI website, click on "Download" and then select "Study Data".

  2. Choose "ALL" to download all available data.

  3. In the "Tabular Data (CSV format)" section, select all the files and download them.

Rename CSV Files

Some CSV files in the clinical data may have a date appended to their names. Remove the date suffix from the file names to ensure compatibility with the preprocessing pipeline.
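
For example, a small script like the following could strip such date suffixes (a hypothetical helper, assuming suffixes of the form "_23Aug2023" before the ".csv" extension):

```python
import re
from pathlib import Path

csv_dir = Path("/path/to/clinical_data")  # directory containing the downloaded CSVs
for f in csv_dir.glob("*.csv"):
    # e.g. "ADNIMERGE_23Aug2023.csv" -> "ADNIMERGE.csv"
    new_name = re.sub(r"_\d{1,2}[A-Za-z]{3}\d{4}$", "", f.stem) + f.suffix
    if new_name != f.name:
        f.rename(f.with_name(new_name))
```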

Install Clinica Software

Install the Clinica software by following the instructions provided at https://aramislab.paris.inria.fr/clinica/docs/public/latest/Converters/ADNI2BIDS/. Clinica is a powerful tool that facilitates the conversion of ADNI data to the BIDS structure.

Convert ADNI to BIDS

To convert the ADNI data to the BIDS structure, use the following command:

clinica convert adni-to-bids -m T1 DATASET_DIRECTORY CLINICAL_DATA_DIRECTORY BIDS_DIRECTORY

Replace DATASET_DIRECTORY with the path to the downloaded ADNI dataset, CLINICAL_DATA_DIRECTORY with the path to the downloaded clinical data, and BIDS_DIRECTORY with the desired output path for the BIDS-formatted dataset. The -m T1 option specifies that only MRI data with T1 weighting should be converted.

Preprocessing

This section describes the preprocessing steps for the sMRI data.

Run Preprocessing Pipeline

To run the preprocessing pipeline on the sMRI data, execute the following command:

python data_preprocessing.py --bids_path /path/to/bids-dataset --n_proc 10 --checkpoint checkpoint.txt

Replace /path/to/bids-dataset with the path to the BIDS-formatted dataset obtained from the data preparation steps. The preprocessing pipeline includes MNI152 registration, brain extraction, and bias field correction with the N4 algorithm. The preprocessed images will be stored in the same path as the original images.

The n_proc argument specifies the number of parallel processes used by the preprocessing pipeline.

The checkpoint argument specifies the path to the checkpoint file, which keeps track of the images that have already been preprocessed, allowing the pipeline to be interrupted and resumed later.

Please note that preprocessing is time-consuming and may take a significant amount of time to complete.
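
The checkpoint logic amounts to a few lines; here is a minimal sketch of the idea (names are hypothetical, not the repository's actual code):

```python
from pathlib import Path

def preprocess(image_path: str) -> None:
    ...  # stand-in for N4 bias correction, MNI152 registration, brain extraction

ckpt = Path("checkpoint.txt")
done = set(ckpt.read_text().splitlines()) if ckpt.exists() else set()
for img in ["sub-01_T1w.nii.gz", "sub-02_T1w.nii.gz"]:  # stand-in image list
    if img in done:
        continue  # already handled on a previous run
    preprocess(img)
    with ckpt.open("a") as f:
        f.write(img + "\n")  # record progress immediately so a crash loses little work
```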

This script creates a dataset.csv file in the BIDS directory, which contains the paths to the preprocessed images and their corresponding labels. This file is used by the deep learning models to load the data.

The script also creates a dataset_conversion{num_months}.csv file, which contains the paths to the preprocessed images and the corresponding labels for the progression task (sMCI, pMCI).
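
A minimal sketch of consuming these CSVs with pandas (the column names used here are assumptions, not the repository's documented schema):

```python
import pandas as pd

df = pd.read_csv("/path/to/bids-dataset/dataset.csv")
print(df["label"].value_counts())            # e.g. AD / CN / MCI class balance
ad_cn = df[df["label"].isin(["AD", "CN"])]   # subset for the diagnosis task
```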

Deep Learning approaches

Diagnosis Network proposed in the paper (Axial3D)

Training and test a model with 5-fold cross-validation

To train a model, run the train.py script. Training is configured through a YAML file: config.yaml.

python train.py

The script reports results averaged over the 5-fold cross-validation. The table below shows the results obtained on ADNI1 Complete 1Yr 1.5T by the proposed model and by the other models that can be tested with the framework.

In the table, the first four metric columns (ACC, SPE, SEN, MCC) refer to AD vs. CN and the last four to sMCI vs. pMCI.

| Networks | ACC | SPE | SEN | MCC | ACC | SPE | SEN | MCC |
|----------|-----|-----|-----|-----|-----|-----|-----|-----|
| Majority Voting (2D)(VGG16) | 0.804 | 0.897 | 0.688 | 0.605 | 0.614 | 0.601 | 0.629 | 0.229 |
| Attention Transformer (TransformerConv3D)(VGG16) | 0.826 | 0.914 | 0.717 | 0.651 | 0.623 | 0.665 | 0.4873 | 0.238 |
| AwareNet Diagnosis (AwareNet)(3D) | 0.832 | 0.875 | 0.778 | 0.659 | 0.4841 | 0.774 | 0.258 | 0.039 |
| Ours (Axial3D)(VGG16) | 0.856 | 0.910 | 0.792 | 0.712 | 0.725 | 0.763 | 0.678 | 0.443 |
| Attention-Guided Majority Voting (2D)(VGG16) | 0.843 | 0.894 | 0.780 | 0.683 | 0.633 | 0.624 | 0.643 | 0.266 |
| Majority Voting 3D (3D)(VGG16) | 0.836 | 0.867 | 0.797 | 0.667 | 0.629 | 0.653 | 0.601 | 0.254 |

Note: the data splitting is performed by subject to avoid data leakage problems (see the sketch below). Make sure to adjust the configuration files according to your specific paths and requirements.
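
A minimal sketch of subject-level splitting with scikit-learn (the subject_id column name is an assumption):

```python
import pandas as pd
from sklearn.model_selection import GroupKFold

df = pd.read_csv("/path/to/bids-dataset/dataset.csv")
gkf = GroupKFold(n_splits=5)
for fold, (tr_idx, te_idx) in enumerate(gkf.split(df, groups=df["subject_id"])):
    train, test = df.iloc[tr_idx], df.iloc[te_idx]
    # No subject appears in both splits, so no scan of a test subject leaks into training.
    assert set(train["subject_id"]).isdisjoint(test["subject_id"])
```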

Attention-based Explainability

Attention XAI approach proposed in the paper with Axial3D

The attention_xai_analysis.py script generates the distributions of attention weights for the three anatomical planes (axial, coronal, sagittal), starting from three models that each output an attention-weight distribution over slices, and it aggregates the results over the 5-fold cross-validation to validate their consistency. It also produces an explainable MRI overlaid on the template MRI. This explainable image is used to compute XAI metrics that help identify which regions are most important for discerning an AD patient from a healthy one.
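
One plausible way to combine the three per-plane slice-attention distributions into a voxel-level map is an outer product; the sketch below illustrates this reading, while the paper defines the exact construction:

```python
import numpy as np

# Dummy per-plane attention distributions (sagittal, coronal, axial slice counts assumed):
w_sag, w_cor, w_ax = (np.random.rand(n) for n in (160, 192, 160))
w_sag, w_cor, w_ax = (w / w.sum() for w in (w_sag, w_cor, w_ax))  # normalize each plane
vox = np.einsum("i,j,k->ijk", w_sag, w_cor, w_ax)  # (X, Y, Z) attention volume
vox /= vox.max()                                   # rescale to [0, 1] for template overlay
```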

To run the explainability analysis, execute the following command:

python attention_xai_analysis.py

Results:

| Brain Region | $V_r$ | $\mu_r$ | $\sigma_r$ | $A_{max,r}$ | $A_{min,r}$ | $P_r$ |
|--------------|-------|---------|------------|-------------|-------------|-------|
| Hippocampus left | 1562 | 0.136 | 0.139 | 0.762 | 0.028 | 0.333 |
| Hippocampus right | 1426 | 0.126 | 0.133 | 0.783 | 0.028 | 0.304 |
| Parahippocampal left | 688 | 0.129 | 0.137 | 0.884 | 0.028 | 0.254 |
| Parahippocampal right | 534 | 0.129 | 0.148 | 1.000 | 0.028 | 0.197 |
| Amygdala left | 480 | 0.097 | 0.092 | 0.620 | 0.028 | 0.291 |
| Amygdala right | 427 | 0.095 | 0.087 | 0.569 | 0.028 | 0.259 |
| Inferior Lateral Ventricle right | 232 | 0.113 | 0.129 | 0.677 | 0.028 | 0.219 |
| Inferior Lateral Ventricle left | 212 | 0.106 | 0.105 | 0.589 | 0.028 | 0.200 |
| Cerebellum Gray Matter left | 208 | 0.035 | 0.005 | 0.052 | 0.028 | 0.003 |
| Lateral Orbitofrontal left | 194 | 0.033 | 0.004 | 0.045 | 0.028 | 0.013 |
| Fusiform right | 184 | 0.045 | 0.015 | 0.107 | 0.028 | 0.014 |
| Lateral Orbitofrontal right | 140 | 0.034 | 0.004 | 0.046 | 0.028 | 0.009 |
| Cerebellum Gray Matter right | 119 | 0.034 | 0.005 | 0.054 | 0.028 | 0.002 |
| Fusiform left | 88 | 0.040 | 0.010 | 0.070 | 0.028 | 0.007 |
| Entorhinal left | 16 | 0.034 | 0.003 | 0.041 | 0.030 | 0.005 |
| Ventral Diencephalon left | 6 | 0.029 | 0.001 | 0.031 | 0.028 | 0.001 |
| Entorhinal right | 2 | 0.033 | 0.003 | 0.036 | 0.031 | 0.001 |
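
For illustration, per-region statistics of this kind can be computed from an attention volume and a labeled atlas roughly as follows (a sketch; the exact metric definitions, including $P_r$, are given in the paper):

```python
import numpy as np

def region_stats(attn: np.ndarray, atlas: np.ndarray, region_id: int) -> dict:
    vals = attn[atlas == region_id]  # attention values inside region r
    vals = vals[vals > 0]            # keep attended voxels only (an assumption)
    return {"V_r": int(vals.size),   # number of attended voxels in the region
            "mu_r": float(vals.mean()), "sigma_r": float(vals.std()),
            "A_max_r": float(vals.max()), "A_min_r": float(vals.min())}

attn = np.random.rand(4, 4, 4)                   # dummy attention volume
atlas = np.random.randint(0, 3, (4, 4, 4))       # dummy region labels
print(region_stats(attn, atlas, region_id=1))
```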

GradCAM-based Explainability

The cam_xai_analysis.py script performs the same analysis using CAM maps produced with GradCAM-like methods.
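
As a rough illustration, a 3D heatmap can be assembled by stacking per-slice CAMs; the sketch below uses the pytorch-grad-cam package, though whether the repository relies on that package is an assumption:

```python
import numpy as np
import torch
from torchvision.models import vgg16
from pytorch_grad_cam import GradCAMPlusPlus
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

model = vgg16(weights=None).eval()    # stand-in for a trained 2D backbone + classifier
cam = GradCAMPlusPlus(model=model, target_layers=[model.features[-1]])
slices = torch.randn(8, 3, 224, 224)  # dummy slices along one anatomical plane
maps = cam(input_tensor=slices, targets=[ClassifierOutputTarget(0)] * 8)
heatmap_3d = np.stack(list(maps))     # (S, H, W): per-slice CAMs stacked into a volume
```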

To run the explainability analysis, execute the following command:

python cam_xai_analysis.py

Results:

Figure: visualization of the mean 3D GradCAM++ map over the entire dataset, overlaid on the MNI152 template, with Axial3D (VGG16) and with TransformerConv3D (VGG16).

Acknowledgement

License

This project is released under the terms described in the License file. Please review it for more information.

Citation

If you find this work useful for your research, please 🌟 our project and cite our paper:

@article{lozupone2024axial,
  title={AXIAL: Attention-based eXplainability for Interpretable Alzheimer's Localized Diagnosis using 2D CNNs on 3D MRI brain scans},
  author={Lozupone, Gabriele and Bria, Alessandro and Fontanella, Francesco and De Stefano, Claudio},
  journal={arXiv preprint arXiv:2407.02418},
  year={2024}
}