nyuad-cai / CXR-ML-GZSL

Code for the MLHC 2021 paper "Multi-Label Generalized Zero Shot Learning for the Classification of Disease in Chest Radiographs"

Background

Despite the success of deep neural networks in chest X-ray (CXR) diagnosis, supervised learning only allows the prediction of disease classes that were seen during training. At inference, these networks cannot predict an unseen disease class. Incorporating a new class requires collecting labeled data, which is not a trivial task, especially for less frequently occurring diseases. As a result, it becomes impractical to build a model that can diagnose all possible disease classes. This repo contains the PyTorch implementation of our proposed network, multi-label generalized zero-shot learning (CXR-ML-GZSL), which can simultaneously predict multiple seen and unseen diseases in CXR images. Given an input image, CXR-ML-GZSL learns visual representations guided by the input's corresponding semantics, extracted from a rich medical text corpus.

Overview of the CXR-ML-GZSL network

The network consists of (i) a trainable visual encoder, (ii) a fixed semantic encoder, and (iii) a projection module that maps the encoded features to a joint latent space. Our approach is end-to-end trainable and does not require offline training of the visual feature encoder.
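
To make the role of the joint latent space concrete, here is a minimal, dependency-free sketch of how an image could be scored against class embeddings once both are projected. It is not the code in this repo. The function names (`linear`, `cosine`, `score_classes`), the random weights, the toy dimensions, and the choice of cosine similarity as the relevance score are all assumptions made for illustration, and plain Python lists stand in for PyTorch tensors.

```python
import math
import random


def linear(weights, vector):
    """Apply a linear map, given as a list of rows, to a vector."""
    return [sum(w * x for w, x in zip(row, vector)) for row in weights]


def cosine(u, v):
    """Cosine similarity between two vectors (0.0 if either has zero norm)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def score_classes(visual_feature, class_embeddings, w_visual, w_semantic):
    """Project the image feature and every class embedding into the joint
    space, then score each class by cosine similarity (an assumed choice)."""
    z_image = linear(w_visual, visual_feature)
    return {
        name: cosine(z_image, linear(w_semantic, embedding))
        for name, embedding in class_embeddings.items()
    }


# Toy stand-ins: in a real model these come from the visual encoder,
# the fixed semantic encoder, and the learned projection module.
rng = random.Random(0)


def random_vector(size):
    return [rng.uniform(-1.0, 1.0) for _ in range(size)]


w_visual = [random_vector(8) for _ in range(4)]    # visual dim 8 -> joint dim 4
w_semantic = [random_vector(6) for _ in range(4)]  # semantic dim 6 -> joint dim 4
visual_feature = random_vector(8)
class_embeddings = {
    "Atelectasis": random_vector(6),
    "Edema": random_vector(6),
    "Hernia": random_vector(6),
}

scores = score_classes(visual_feature, class_embeddings, w_visual, w_semantic)
ranked = sorted(scores, key=scores.get, reverse=True)
print(ranked)
```

Because every class is scored through its semantic embedding rather than through a dedicated output unit, an unseen class only needs an embedding to be ranked alongside the seen ones, which is the property the zero-shot setting relies on.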

Environment setup

git clone https://github.com/nyuad-cai/CXR-ML-GZSL.git
cd CXR-ML-GZSL
conda env create -f environment.yml
conda activate zsl

Dataset

We evaluated the proposed method on the NIH Chest X-ray dataset with a random split of 10 seen and 4 unseen classes. To train and evaluate the network, download the NIH Chest X-ray dataset.
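
As a rough illustration of what a 10/4 seen/unseen split looks like, the sketch below shuffles the 14 NIH disease labels with a fixed seed and encodes a multi-label annotation as a multi-hot vector. The helper names (`split_seen_unseen`, `encode_labels`) and the seed are hypothetical, and the resulting split is not the one used in the paper. It also assumes that multiple findings for one image are joined with `|` in the dataset metadata.

```python
import random

# The 14 disease labels of the NIH Chest X-ray dataset.
NIH_CLASSES = [
    "Atelectasis", "Cardiomegaly", "Effusion", "Infiltration", "Mass",
    "Nodule", "Pneumonia", "Pneumothorax", "Consolidation", "Edema",
    "Emphysema", "Fibrosis", "Pleural_Thickening", "Hernia",
]


def split_seen_unseen(classes, num_unseen=4, seed=0):
    """Randomly partition class names into (seen, unseen) lists.
    The seed is illustrative; it does not reproduce the paper's split."""
    rng = random.Random(seed)
    shuffled = list(classes)
    rng.shuffle(shuffled)
    unseen = sorted(shuffled[:num_unseen])
    seen = sorted(shuffled[num_unseen:])
    return seen, unseen


def encode_labels(finding_labels, classes):
    """Turn a '|'-separated label string into a multi-hot vector.
    Labels outside `classes` (for example 'No Finding') are ignored."""
    present = set(finding_labels.split("|"))
    return [1 if name in present else 0 for name in classes]


seen, unseen = split_seen_unseen(NIH_CLASSES)
print("seen:", seen)
print("unseen:", unseen)
print(encode_labels("Cardiomegaly|Emphysema", NIH_CLASSES))
```

During training only the seen classes would contribute labels, while evaluation in the generalized setting scores all 14 classes together.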

Model evaluation

Model training

Citation

If you use this code for your research, please consider citing:

@misc{hayat2021multilabel,
      title={Multi-Label Generalized Zero Shot Learning for the Classification of Disease in Chest Radiographs}, 
      author={Nasir Hayat and Hazem Lashen and Farah E. Shamout},
      year={2021},
      eprint={2107.06563},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Results