yyfang678 / DMIB


Overview

We include our Dynamic Multimodal Information Bottleneck (DMIB) codebase for the prognosis tasks on ITAC and iCTCF and the diagnosis tasks on BRCA and ROSMAP.

Quick start

Prognosis Task for ITAC

cd prognosis_tasks
python train_fuse_covid.py --use_fuse_clinical_axial --use_clinical --preprocess_clinical=2 --model_name="proposed_crossmodal" --use_IB --expname="fuse_crossmodal"

Rundown of the arguments:

Prognosis Task for iCTCF

First, run the following scripts to download and preprocess the iCTCF dataset (2D montage generation), or use our prepared iCTCF montages at https://drive.google.com/file/d/1VV9KFMyR3-jRZRBbxJqAAunbKtoD8EgP/view?usp=sharing

cd dataset/iCTCF
python no1_patient_enrollment.py
python no2_slice_to_3D.py
python no3_generate_montage.py
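no3_generate_montage.py turns each patient's CT volume into a single 2D montage. The script's exact slice selection and layout are not described here, so the sketch below is only an assumption: it samples evenly spaced axial slices from a (depth, height, width) NumPy array and tiles them row by row into one grid image. The function make_montage and its grid parameters are hypothetical and not part of the repository.

```python
import numpy as np


def make_montage(volume: np.ndarray, grid_rows: int, grid_cols: int) -> np.ndarray:
    """Hypothetical helper: tile evenly spaced axial slices of a (depth, H, W) volume into one 2D image."""
    depth, h, w = volume.shape
    n_tiles = grid_rows * grid_cols
    # Pick n_tiles slice indices spread evenly from the first to the last slice.
    idx = np.linspace(0, depth - 1, n_tiles).round().astype(int)
    montage = np.zeros((grid_rows * h, grid_cols * w), dtype=volume.dtype)
    for k, s in enumerate(idx):
        # Fill the grid row by row: tile k goes to row k // grid_cols, column k % grid_cols.
        r, c = divmod(k, grid_cols)
        montage[r * h:(r + 1) * h, c * w:(c + 1) * w] = volume[s]
    return montage
```

For example, a 40-slice volume of 64x64 slices tiled on a 4x4 grid gives a 256x256 montage whose top-left tile is the first slice and whose bottom-right tile is the last slice.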

Once dataset preprocessing is complete, train the DMIB model with:

cd prognosis_tasks
python train_fuse_covid.py --use_fuse_clinical_axial --use_clinical --preprocess_clinical=2 \
    --model_name="proposed_crossmodal" \
    --clinical_continuous="Age,Body temperature,MCHC,MCH,MCV,HCT,HGB,RBC,PDW,PLCT,MPV,PLT,BA,EO,MO,LY,NE,BAP,EOP,MOP,LYP,NEP,WBC,PLCR,RDWSD,RDWCV,ESR,CRP,PCT,ALG,ALB,ALP,ALT,AST,BUN,CA,CL,CO2,CREA,GGT,GLB,K,MG,Na,PHOS,TBIL,TP,URIC,CHOL,CK,HDLC,LDH,TG,AnG,DBIL,GLU,LDLC,OSM,PA,TBA,HBDH,CysC,LAP,5NT,HC,SAA,SdLDL,CD3+,CD4+,CD8+,BC,NKC,CD4/CD8,IL-2,IL-4,IL-6,IL-10,TNF,IFN" \
    --clinical_category="Gender,Underlying diseases" \
    --patient_died_ct_csv="../dataset/iCTCF/patients_enrol_list/enrolled_1.csv" \
    --patient_survived_ct_csv="../dataset/iCTCF/patients_enrol_list/enrolled_0.csv" \
    --patients_info_csv="../dataset/iCTCF/patients_enrol_list/enrolled_all.csv" \
    --data_clinical="../dataset/iCTCF/patients_enrol_list/enrolled_all.csv" \
    --datapath_train="../dataset/iCTCF/2D_montage/" \
    --datapath_test="../dataset/iCTCF/2D_montage/" \
    --use_IB --expname="iCTCF_all"
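The iCTCF command splits the clinical variables into continuous (--clinical_continuous) and categorical (--clinical_category) columns. How --preprocess_clinical=2 actually transforms them is not documented here, so the sketch below is an assumption for illustration only: it parses the two comma-separated lists, fills missing continuous values with the column median and z-scores them, and one-hot encodes the categorical columns with pandas. The function preprocess_clinical is hypothetical and is not the repository's code.

```python
import pandas as pd


def preprocess_clinical(df: pd.DataFrame, continuous: str, categorical: str) -> pd.DataFrame:
    """Hypothetical helper: build a numeric feature table from comma-separated column lists."""
    cont_cols = [c.strip() for c in continuous.split(",") if c.strip()]
    cat_cols = [c.strip() for c in categorical.split(",") if c.strip()]

    # Continuous features: median imputation, then z-scoring (constant columns keep std = 1).
    cont = df[cont_cols].astype(float)
    cont = cont.fillna(cont.median())
    std = cont.std(ddof=0).replace(0, 1.0)
    cont = (cont - cont.mean()) / std

    # Categorical features: one indicator column per observed category value.
    cat = pd.get_dummies(df[cat_cols].astype(str), prefix=cat_cols, dtype=float)
    return pd.concat([cont, cat], axis=1)
```

Keeping the column lists as plain comma-separated strings mirrors the command-line flags, so the same strings could be passed straight through from the arguments.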

Prognosis Task Comparisons

Diagnosis Task for BRCA, ROSMAP

cd diagnosis_tasks
python train_fuse.py --dataset BRCA
python train_fuse.py --dataset ROSMAP

Codebase Structure

└── dataset
    └── BRCA
    └── ROSMAP
    └── ITAC (not included at the moment)
    └── iCTCF
└── diagnosis_tasks
    ├── model.py (includes DMIB + ablation models)
    ├── train_fuse.py (script for DMIB training and testing)
    ├── trainer.py (training utils)
└── prognosis_tasks
    └── loss
        ├── loss_auc.py (differentiable ROC AUC loss)
        ├── loss_conf.py (dynamic confidence loss)
        ├── loss_focal.py (focal loss)
        ├── loss_IB.py (our DMIB loss)
        ├── loss_sinkhorn.py (entropy-regularized optimal transport, a.k.a. Sinkhorn loss)
    └── model
        ├── clinical_only.py (clinical only model)
        ├── fusion_attention.py (attention-based fusion model)
        ├── fusion_concat.py (concatenation-based fusion model)
        ├── fusion_dynamic.py (dynamic fusion model)
        ├── fusion_crossmodal.py (our proposed DMIB fusion model)
        ├── fusion_transformer.py (transformer-based fusion model)
    └── utils
        ├── data.py (utils for preprocessing CT images)
        ├── img_process.py (utils for processing CT images)
        ├── trainer.py (training utils)
    ├── select_dataloader.py (dataloader for CT image only / clinical data only / fused)
    ├── select_model.py (initialize models according to experimental configurations)
    ├── select_optimizer.py (select optimizer according to experimental configurations)
    ├── select_parameters.py (experimental configurations)
    ├── train_fuse_covid.py (script for DMIB training and testing)
    ├── train_fuse_covid_ML.py (script for DMIB training and testing with ML methods, e.g. XGBoost, Random Forests, SVMs)
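loss_sinkhorn.py is listed as an entropy-regularized optimal transport (Sinkhorn) loss. The sketch below is a generic NumPy version of the Sinkhorn-Knopp iterations between two uniformly weighted point sets, not the repository's implementation; the function name, the squared-Euclidean cost, and the eps and n_iters defaults are assumptions. It uses plain (not log-domain) updates, which can underflow when eps is small relative to the costs, and it is not differentiable, so it shows the computation rather than how the loss is trained.

```python
import numpy as np


def sinkhorn_distance(x: np.ndarray, y: np.ndarray, eps: float = 1.0, n_iters: int = 200):
    """Hypothetical helper: entropic OT cost and plan between point sets x (n, d) and y (m, d)."""
    n, m = len(x), len(y)
    a = np.full(n, 1.0 / n)  # uniform source weights
    b = np.full(m, 1.0 / m)  # uniform target weights
    # Pairwise squared Euclidean cost matrix, shape (n, m).
    cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
    kernel = np.exp(-cost / eps)
    u, v = np.ones(n), np.ones(m)
    for _ in range(n_iters):
        # Alternately rescale rows and columns so the plan matches both marginals.
        u = a / (kernel @ v)
        v = b / (kernel.T @ u)
    plan = u[:, None] * kernel * v[None, :]
    return float((plan * cost).sum()), plan
```

Larger eps gives a smoother, more spread-out transport plan and faster convergence; smaller eps approaches the unregularized optimal transport cost.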