biomedia-mira / causal-gen

(ICML 2023) High Fidelity Image Counterfactuals with Probabilistic Causal Models
https://arxiv.org/abs/2306.15764
MIT License
causality computer-vision counterfactual generative-model machine-learning medical-imaging neural-network python pytorch uncertainty

Causal Generative Modelling: Image Counterfactuals

🤗 Hugging Face demo here! 🤗

Code for the ICML 2023 paper:

High Fidelity Image Counterfactuals with Probabilistic Causal Models
Fabio De Sousa Ribeiro¹, Tian Xia¹, Miguel Monteiro¹, Nick Pawlowski², Ben Glocker¹
¹Imperial College London, ²Microsoft Research Cambridge, UK

BibTeX:

@InProceedings{pmlr-v202-de-sousa-ribeiro23a,
  title={High Fidelity Image Counterfactuals with Probabilistic Causal Models},
  author={De Sousa Ribeiro, Fabio and Xia, Tian and Monteiro, Miguel and Pawlowski, Nick and Glocker, Ben},
  booktitle={Proceedings of the 40th International Conference on Machine Learning},
  pages={7390--7425},
  year={2023},
  volume={202},
  series={Proceedings of Machine Learning Research},
  month={23--29 Jul},
  url={https://proceedings.mlr.press/v202/de-sousa-ribeiro23a.html}
}

Example Results:

Project Structure:

📦src                                  # main source code directory
 ┣ 📂pgm                               # graphical models for all SCM mechanisms except the image's
 ┃ ┣ 📜dscm.py                         # deep structural causal model PyTorch module
 ┃ ┣ 📜flow_pgm.py                     # flow mechanisms in Pyro
 ┃ ┣ 📜layers.py                       # utility modules/layers
 ┃ ┣ 📜resnet.py                       # resnet model definition
 ┃ ┣ 📜run.sh                          # example launch script for counterfactual training (Slurm)
 ┃ ┣ 📜train_cf.py                     # counterfactual training code
 ┃ ┣ 📜train_pgm.py                    # SCM mechanisms training code (Pyro)
 ┃ ┗ 📜utils_pgm.py                    # graphical model utilities
 ┣ 📜datasets.py                       # dataset definitions
 ┣ 📜dmol.py                           # discretized mixture of logistics likelihood
 ┣ 📜hps.py                            # hyperparameters for all datasets
 ┣ 📜main.py                           # main file
 ┣ 📜run_local.sh                      # example launch script for HVAE causal mechanism training
 ┣ 📜run_slurm.sh                      # same as above but for Slurm jobs
 ┣ 📜simple_vae.py                     # single stochastic layer VAE
 ┣ 📜trainer.py                        # training code for image x's causal mechanism
 ┣ 📜train_setup.py                    # training helpers
 ┣ 📜utils.py                          # utilities for training/plotting
 ┗ 📜vae.py                            # HVAE definition; exogenous prior and latent mediator models 

Overview

Our deep structural causal models (DSCMs) are designed to be modular: in all instances, the causal mechanism for the structured variable (i.e., the image $\mathbf{x}$) is trained separately from the other mechanisms in the associated causal graph. This enables direct and fair comparisons between different causal mechanisms for $\mathbf{x}$, since the remaining mechanisms are held fixed.

We use the universal probabilistic programming language (PPL) Pyro for the following:

  1. Modelling and training all SCM mechanisms except the image $\mathbf{x}$'s (see src/pgm);
  2. The counterfactual inference engine (see src/pgm/flow_pgm.py);
  3. Our proposed constrained counterfactual training technique (see src/pgm/train_cf.py).
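Counterfactual inference in an SCM follows Pearl's three steps of abduction, action and prediction. As a minimal, framework-free sketch of those steps (not the Pyro implementation in src/pgm/flow_pgm.py), assume a hypothetical two-variable SCM, thickness t → intensity i, with invertible affine mechanisms; the class name and all coefficients below are made up for illustration.

```python
from dataclasses import dataclass


@dataclass
class AffineMechanism:
    """Hypothetical invertible mechanism: v = scale * u + weight * parent + bias."""
    scale: float   # must be non-zero for the mechanism to be invertible
    weight: float  # effect of the parent on v
    bias: float

    def forward(self, u: float, parent: float = 0.0) -> float:
        return self.scale * u + self.weight * parent + self.bias

    def inverse(self, v: float, parent: float = 0.0) -> float:
        # Recover the exogenous noise u that produced v given its parent.
        return (v - self.weight * parent - self.bias) / self.scale


# Toy SCM (assumed): t := f_t(u_t), i := f_i(u_i; t)
f_t = AffineMechanism(scale=0.5, weight=0.0, bias=2.0)
f_i = AffineMechanism(scale=10.0, weight=30.0, bias=50.0)


def counterfactual(t: float, i: float, t_new: float) -> dict:
    # 1. Abduction: infer each variable's exogenous noise from the observation.
    u_t = f_t.inverse(t)
    u_i = f_i.inverse(i, parent=t)
    # 2. Action: do(t = t_new) replaces t's mechanism, so u_t is no longer used.
    t_cf = t_new
    # 3. Prediction: push the abducted noise through the intervened model.
    i_cf = f_i.forward(u_i, parent=t_cf)
    return {"u_t": u_t, "t": t_cf, "i": i_cf}


print(counterfactual(t=2.5, i=135.0, t_new=3.5))  # → {'u_t': 1.0, 't': 3.5, 'i': 165.0}
```

Intervening with the observed value (t_new=2.5) reproduces the factual intensity, which is a useful sanity check for any counterfactual engine.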

Pyro enables flexible and expressive deep probabilistic modelling; for more details, refer to the official site.

Our HVAE-based causal mechanisms (src/vae.py) are trained outside of Pyro using PyTorch, and all trained mechanisms are subsequently merged into a single PyTorch module to form a DSCM. See src/pgm/dscm.py for an example.
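The modular composition can be illustrated with a framework-free sketch: separately trained parent mechanisms and an image mechanism sit behind one wrapper, so the image mechanism can be swapped while everything else stays fixed. All names here (`ToyDSCM`, `AdditiveImageMechanism`, `abduct`, `predict`) are hypothetical stand-ins; the repository's actual composition is the PyTorch module in src/pgm/dscm.py, and a real HVAE only approximately abducts its exogenous noise.

```python
from typing import Callable, Dict, List, Protocol, Tuple

Parents = Dict[str, float]
Image = List[float]


class ImageMechanism(Protocol):
    """Anything that can abduct exogenous noise for x and regenerate x from it."""

    def abduct(self, x: Image, pa: Parents) -> Image: ...

    def predict(self, z: Image, pa: Parents) -> Image: ...


class AdditiveImageMechanism:
    """Toy invertible image mechanism: x = z + brightness * intensity."""

    def __init__(self, brightness: float) -> None:
        self.brightness = brightness

    def abduct(self, x: Image, pa: Parents) -> Image:
        return [xi - self.brightness * pa["intensity"] for xi in x]

    def predict(self, z: Image, pa: Parents) -> Image:
        return [zi + self.brightness * pa["intensity"] for zi in z]


class ToyDSCM:
    """Wraps separately trained mechanisms behind a single interface."""

    def __init__(self, parent_cf: Callable[[Parents, Parents], Parents],
                 image_mech: ImageMechanism) -> None:
        self.parent_cf = parent_cf    # counterfactuals for the non-image variables
        self.image_mech = image_mech  # swappable mechanism for x

    def counterfactual(self, x: Image, pa: Parents, do: Parents) -> Tuple[Image, Parents]:
        pa_cf = self.parent_cf(pa, do)            # counterfactual parents
        z = self.image_mech.abduct(x, pa)         # abduct x's exogenous noise
        x_cf = self.image_mech.predict(z, pa_cf)  # regenerate x under pa_cf
        return x_cf, pa_cf


# Hypothetical parent model with no downstream effects: just apply the intervention.
dscm = ToyDSCM(lambda pa, do: {**pa, **do}, AdditiveImageMechanism(brightness=0.5))
x_cf, pa_cf = dscm.counterfactual([1.0, 2.0], {"intensity": 2.0}, {"intensity": 4.0})
```

Reassigning `dscm.image_mech` changes only how x is regenerated, which is what makes side-by-side comparisons of image mechanisms fair.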

Requirements

To run the code, install the requirements listed in requirements.txt, e.g., by running the following from inside your environment of choice:

pip install -r requirements.txt
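After installing, a quick standard-library check can confirm that the core libraries are importable. The module names `torch` and `pyro` (the import name of the `pyro-ppl` package) are assumptions about what requirements.txt pins, so adjust the list to match it.

```python
import importlib.util
from typing import Iterable, List


def missing_modules(names: Iterable[str]) -> List[str]:
    """Return the top-level module names that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(["torch", "pyro", "numpy"])
    print("all found" if not missing else f"missing: {', '.join(missing)}")
```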

Data

For ease of use, we provide the Morpho-MNIST dataset we used in datasets/morphomnist. For more details on the associated SCM and data-generating process, see the source code here and the original DSCM paper here.
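To inspect the provided data outside the training code, a small standard-library reader is enough, assuming the folder follows the original Morpho-MNIST release, which ships MNIST-style gzipped IDX files alongside morphometrics CSVs. Check the actual filenames in datasets/morphomnist; the function names below are illustrative, not part of the repository.

```python
import gzip
import struct
from typing import List, Tuple


def parse_idx_ubyte(buf: bytes) -> Tuple[Tuple[int, ...], bytes]:
    """Parse an unsigned-byte IDX buffer into (shape, raw data bytes)."""
    # Header: 2 zero bytes, a type code (0x08 = unsigned byte), the number of dims.
    zero, type_code, ndim = struct.unpack(">HBB", buf[:4])
    if zero != 0 or type_code != 0x08:
        raise ValueError("not an unsigned-byte IDX buffer")
    # Each dimension size is a big-endian 32-bit unsigned integer.
    shape = struct.unpack(f">{ndim}I", buf[4:4 + 4 * ndim])
    data = buf[4 + 4 * ndim:]
    expected = 1
    for size in shape:
        expected *= size
    if len(data) != expected:
        raise ValueError(f"expected {expected} data bytes, got {len(data)}")
    return shape, data


def load_idx_gz(path: str) -> Tuple[Tuple[int, ...], bytes]:
    """Read and parse a gzipped IDX file."""
    with gzip.open(path, "rb") as f:
        return parse_idx_ubyte(f.read())


def get_image(shape: Tuple[int, ...], data: bytes, i: int) -> List[List[int]]:
    """Return image i of an (N, H, W) IDX tensor as nested lists of pixel values."""
    _, h, w = shape
    start = i * h * w
    return [list(data[start + r * w:start + (r + 1) * w]) for r in range(h)]
```

The accompanying CSV files can be read with the standard csv module.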

The Colour-MNIST dataset we used was generated according to this paper.

Unfortunately, we are unable to share the UK Biobank brain data or the MIMIC-CXR chest x-ray data.

If you're interested in gaining access, we recommend consulting the provided documentation, which covers the application process and eligibility criteria in full. Application and eligibility criteria for gaining access are detailed here and here, respectively.

Run

To launch local training of the HVAE mechanism, run the following script from inside the src directory:

bash run_local.sh your_experiment_name

To run in the background, append nohup as an extra argument: bash run_local.sh your_experiment_name nohup. Adjust the run_command inside the script as needed; hyperparameters can be found in src/hps.py. If you are using the Slurm Workload Manager, adjust src/run_slurm.sh as needed and launch it with bash run_slurm.sh.

Rough steps to add your own dataset and associated SCM:

  1. Add a dataset class definition to src/datasets.py and set up the dataloader in src/train_setup.py
  2. Add associated causal graph and mechanism definitions in src/pgm/flow_pgm.py
  3. Adjust the HVAE hyperparameters for your dataset (input resolution, architecture, etc.) in src/hps.py
  4. Train the HVAE mechanism as above, and train all other mechanisms (separately) using src/pgm/train_pgm.py
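For step 1, a map-style dataset is any object implementing `__getitem__` and `__len__`, which is the protocol `torch.utils.data.DataLoader` expects. Below is a minimal framework-free sketch that returns one dict per sample holding the image and its parent variables. The class name, the dict keys, and the min-max scaling of parents are illustrative assumptions, not the conventions of src/datasets.py, so mirror the existing dataset classes there.

```python
from typing import Dict, List, Sequence


class ParentsImageDataset:
    """Hypothetical map-style dataset pairing each image with its causal parents."""

    def __init__(self, images: Sequence[List[float]],
                 parents: Dict[str, Sequence[float]]) -> None:
        n = len(images)
        for name, values in parents.items():
            if len(values) != n:
                raise ValueError(f"parent '{name}' has {len(values)} values, expected {n}")
        self.images = images
        # Min-max scale each parent to [0, 1] (an assumed preprocessing choice).
        self.parents: Dict[str, List[float]] = {}
        for name, values in parents.items():
            lo, hi = min(values), max(values)
            span = (hi - lo) or 1.0  # avoid dividing by zero for constant parents
            self.parents[name] = [(v - lo) / span for v in values]

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, idx: int) -> Dict[str, object]:
        sample: Dict[str, object] = {"x": self.images[idx]}
        for name, values in self.parents.items():
            sample[name] = values[idx]
        return sample
```

Returning a dict keeps the parent names explicit, which makes it easier to match samples to the nodes of the causal graph defined in src/pgm/flow_pgm.py.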

Note: src/pgm/train_cf.py implements the optional counterfactual training/fine-tuning procedure outlined in Section 3.4 of the paper. This step may not be necessary if the model already performs well enough at counterfactual inference.

To make the HVAE more lightweight, try reducing the number of blocks at each resolution and the block width (hyperparameters enc_arch, dec_arch, and width in src/hps.py). The "light" block version in src/vae.py (version == "light") also uses half as much VRAM.
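As a back-of-the-envelope guide to why width matters: a k×k convolution from c_in to c_out channels has k²·c_in·c_out weights plus c_out biases, so halving the channels of a width-to-width convolution cuts its parameters roughly fourfold, while activation memory shrinks roughly in proportion to the width. The helper below models each block as a single hypothetical width-to-width 3×3 convolution, a simplification rather than the actual block structure in src/vae.py.

```python
def conv2d_params(c_in: int, c_out: int, k: int, bias: bool = True) -> int:
    """Learnable parameters in a standard (ungrouped) 2D convolution."""
    return k * k * c_in * c_out + (c_out if bias else 0)


def stack_params(width: int, n_blocks: int, k: int = 3) -> int:
    """Parameters of n_blocks hypothetical blocks, each one width->width conv."""
    return n_blocks * conv2d_params(width, width, k)


full = stack_params(width=128, n_blocks=4)
light = stack_params(width=64, n_blocks=2)
print(full, light, round(full / light, 2))
```

Here halving both the width and the number of blocks shrinks the parameter count about eightfold; the real savings depend on the block design and on the enc_arch/dec_arch settings.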

To resume training from a checkpoint, pass the argument --resume=/path/to/your/checkpoint.pt.
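Python's argparse accepts both --resume=PATH and --resume PATH for an option like this. Below is a minimal sketch of resume handling, in which `build_parser`, `maybe_resume`, the `load_checkpoint` callback, and the empty-string default are hypothetical rather than the repository's trainer code.

```python
import argparse
import os
from typing import Callable, Optional


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="hypothetical training entry point")
    parser.add_argument("--resume", type=str, default="",
                        help="path to a checkpoint .pt file to resume from")
    return parser


def maybe_resume(resume: str, load_checkpoint: Callable[[str], dict]) -> Optional[dict]:
    """Load a checkpoint if a path was given; fail loudly if it does not exist."""
    if not resume:
        return None  # no path given: start training from scratch
    if not os.path.isfile(resume):
        raise FileNotFoundError(f"checkpoint not found: {resume}")
    return load_checkpoint(resume)  # e.g. torch.load(resume, map_location="cpu")
```

In a PyTorch training script the callback could be `torch.load`, with the returned dictionary then used to restore model and optimizer state.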