valeoai / STEEX

STEEX: Steering Counterfactual Explanations with Semantics

Introduction

This is the official repository for the paper:

STEEX: Steering Counterfactual Explanations with Semantics, Paul Jacob, Éloi Zablocki, Hédi Ben-Younes, Mickaël Chen, Patrick Pérez, Matthieu Cord, in proceedings of ECCV 2022

This repository contains the code and instructions to train the required models, generate counterfactual explanations with STEEX, and evaluate them.

Figure: Counterfactual explanations with STEEX

Installation

Clone this repo.

git clone https://github.com/valeoai/STEEX.git
cd STEEX/

This code requires PyTorch (1.8.1), Python 3+, and CUDA (11.1). Install the dependencies with:

pip install -r requirements.txt

All checkpoints are provided in the release; please extract them. The checkpoint directory should look like this:

/path/to/checkpoints/decision_densenet
/path/to/checkpoints/deeplabv3
/path/to/checkpoints/sean
/path/to/checkpoints/oracle_attribute
/path/to/checkpoints/vggface2_pretrainings_for_oracle
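
Before running anything, it can help to check the checkpoint folder. The following is a minimal sketch, not part of the repository: the helper `missing_checkpoints` and the `EXPECTED_CHECKPOINTS` dict are hypothetical names, and the subfolders are compiled from the checkpoint structures listed in the training sections of this README.

```python
from pathlib import Path

# Hypothetical summary of the layout described in this README:
# top-level folder -> dataset subfolders it should contain.
EXPECTED_CHECKPOINTS = {
    "decision_densenet": ["celeba", "celebamaskhq", "bdd"],
    "deeplabv3": ["celebamaskhq", "bdd"],
    "sean": ["celeba", "celebamaskhq", "bdd"],
    "oracle_attribute": ["celeba", "celebamaskhq"],
    "vggface2_pretrainings_for_oracle": [],  # no subfolders documented
}


def missing_checkpoints(root):
    """Return the expected checkpoint folders (relative to root) that are absent."""
    root = Path(root)
    missing = []
    for top, subdirs in EXPECTED_CHECKPOINTS.items():
        for path in [root / top] + [root / top / sub for sub in subdirs]:
            if not path.is_dir():
                missing.append(path.relative_to(root).as_posix())
    return missing
```

For example, `missing_checkpoints("/path/to/checkpoints")` returns an empty list once every folder is in place.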

Dataset Preparation

This code uses four datasets: CelebA, CelebAMask-HQ, BDD100k, and the BDD-OIA extension. Please download and unzip the datasets from the respective websites.

To preprocess the masks for CelebAMask-HQ, please follow the instructions provided here.

Following previous works, we preprocess CelebA by cropping the images and resizing them to (128, 128). This is done with our script:

python preprocess_celeba.py
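
The cropping step can be sketched as follows. This is an assumption about what preprocess_celeba.py does (a centered square crop followed by a resize to (128, 128) with Pillow); the function names are hypothetical and the actual script may crop differently.

```python
def center_square_box(width, height):
    """Return the (left, upper, right, lower) box of the largest centered square."""
    side = min(width, height)
    left = (width - side) // 2
    upper = (height - side) // 2
    return (left, upper, left + side, upper + side)


def crop_and_resize(src_path, dst_path, size=128):
    """Center-crop an image to a square and resize it to (size, size)."""
    # Imported lazily so center_square_box can be used without Pillow installed.
    from PIL import Image

    with Image.open(src_path) as img:
        square = img.convert("RGB").crop(center_square_box(*img.size))
        # Uses Pillow's default resampling filter for RGB images.
        square.resize((size, size)).save(dst_path)
```

For the 178x218 aligned CelebA images, `center_square_box(178, 218)` gives `(0, 20, 178, 198)`, i.e. 20 rows are dropped at the top and at the bottom before resizing.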

At the end, the dataset directory should look like this.

CelebAMask-HQ:

image_train     = "/path/to/dataset/CelebAMask-HQ/train/images"
mask_train      = "/path/to/dataset/CelebAMask-HQ/train/labels"
image_val       = "/path/to/dataset/CelebAMask-HQ/test/images"
attributes_path = "/path/to/dataset/CelebAMask-HQ/CelebAMask-HQ-attribute-anno.txt"

CelebA:

image_train     = "/path/to/dataset/celeba_squared_128/img_squared128_celeba_train"
image_val       = "/path/to/dataset/celeba_squared_128/img_squared128_celeba_test"
attributes_path = "/path/to/dataset/celeba_squared_128/list_attr_celeba.txt"

BDD-OIA:

bddoia_data    = "/path/to/dataset/bdd-oia/lastframe/data"
image_train    = "/path/to/dataset/bdd-oia/lastframe/train_25k_images_actions.json"
decision_train = "/path/to/dataset/bdd-oia/lastframe/train_25k_images_reasons.json"
image_val      = "/path/to/dataset/bdd-oia/lastframe/val_25k_images_actions.json"
decision_val   = "/path/to/dataset/bdd-oia/lastframe/val_25k_images_reasons.json"

BDD100k:

images_train = "/path/to/dataset/BDD/bdd100k/seg/images/train"
mask_train   = "/path/to/dataset/BDD/bdd100k/seg/labels/train"
images_val   = "/path/to/dataset/BDD/bdd100k/seg/images/val"

Train decision models

For CelebA and CelebAMask-HQ, train the respective decision models (DenseNet) with the command below. Specify the data path <data_dir> and the checkpoint directory <checkpoints_dir>. To switch between CelebA and CelebAMask-HQ, toggle the commented code block at the top of the Args class.

python train_decision_model_celeba_or_celebamhq.py

To train the decision model on BDD-OIA, run the command below, again specifying the data path <data_dir> and the checkpoint directory <checkpoints_dir>.

python train_decision_model_bdd.py

At the end, the checkpoints must be stored with the following structure:

/path/to/checkpoints/decision_densenet/celeba
/path/to/checkpoints/decision_densenet/celebamaskhq
/path/to/checkpoints/decision_densenet/bdd

Train Semantic Segmentation Model

Train two semantic segmentation models, one for CelebAMask-HQ and one for BDD, with the following command:

python train_deeplabv3.py

Simply specify the path to the checkpoints directory where the models will be saved (checkpoints_dir in the Args class) and toggle the commented block to switch from CelebAMask-HQ to BDD100K-seg.

At the end, the checkpoints must be stored with the following structure:

/path/to/checkpoints/deeplabv3/celebamaskhq
/path/to/checkpoints/deeplabv3/bdd

Generate Semantic Segmentation Masks on All Data

After the semantic segmentation models are trained, generate the masks for the validation splits of CelebAMask-HQ and BDD with the script below. First update the Args class at the top of the file: set the dataroot path, the name of the segmentation network trained in the previous step (segmentation_network_name), and the directory where the masks should be saved (save_dir_masks).

python infer_masks.py

CelebA has no ground-truth semantic segmentation masks, so its masks are predicted with the network trained on CelebAMask-HQ: first resize the CelebA images to (256, 256), forward them through the DeepLabv3 network trained on CelebAMask-HQ, and downsample the obtained masks back to the original (128, 128) resolution.
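
The resize, segment, resize procedure for CelebA can be sketched in PyTorch as below. This is a hypothetical helper, not the repository's code; it assumes `seg_model` is the DeepLabv3 network trained on CelebAMask-HQ, already in eval mode, and that it returns per-class logits of shape (N, C, H, W), or a dict with an `"out"` entry as torchvision's DeepLabV3 models do.

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def infer_celeba_masks(images, seg_model, work_size=256):
    """Predict label masks for a batch of CelebA images of shape (N, 3, 128, 128)."""
    out_size = tuple(images.shape[-2:])
    # Upsample the images to the resolution the segmentation network expects.
    x = F.interpolate(images, size=(work_size, work_size),
                      mode="bilinear", align_corners=False)
    logits = seg_model(x)
    if isinstance(logits, dict):  # e.g. torchvision DeepLabV3 returns {"out": ...}
        logits = logits["out"]
    # One class index per pixel, kept as a 1-channel map for resizing.
    labels = logits.argmax(dim=1, keepdim=True).float()
    # Nearest-neighbour downsampling so every pixel stays a valid class index.
    masks = F.interpolate(labels, size=out_size, mode="nearest")
    return masks.squeeze(1).long()
```

Taking the argmax before downsampling and using nearest-neighbour interpolation keeps the stored masks made of valid class indices; bilinear interpolation of a label map would produce fractional values that do not correspond to any class.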

At the end, the masks should then be stored at:

Train SEAN AutoEncoder

For CelebA and BDD, train a SEAN autoencoder with the official GitHub repository (https://github.com/ZPdesu/SEAN), specifying the paths to the images and to the semantic segmentation masks generated in the previous step. For CelebAMask-HQ, we used the original SEAN weights released by the authors in their repository: CelebA-HQ_pretrained.

At the end, the checkpoints must be stored with the following structure:

/path/to/checkpoints/sean/celeba
/path/to/checkpoints/sean/celebamaskhq
/path/to/checkpoints/sean/bdd

Generating Counterfactual Explanations

Run the following command:

python generate_counterfactuals.py --dataset_name <bdd_or_celeba_or_celebamhq> --checkpoints_dir <path_to_checkpoints_dir> --dataroot <path_to_dataroot> --name_exp <experiment_name>

The checkpoints_dir contains the checkpoints for the decision model as well as the SEAN encoder-decoder. The dataroot is the folder where the data is stored.

The option --target_attribute specifies the class of interest. Keep the default value (0) for Move forward vs. Stop on BDD. For the face datasets (CelebA and CelebAMask-HQ), --target_attribute 1 is for the Smile classification and --target_attribute 2 is for the Young classification.

By default, all regions are targeted. To generate region-targeted counterfactual explanations, list the regions (separated by commas) with the option --specified_regions; the possible regions are listed at the top of generate_counterfactuals.py. For example, for CelebA or CelebAMask-HQ: --specified_regions skin,nose,glasses. For BDD: --specified_regions road,car,motorcycle.
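
When scripting several runs, the command can be assembled programmatically. The sketch below is hypothetical: `counterfactual_command` and `TARGET_ATTRIBUTES` are illustrative names, and only the flags and attribute indices documented in this README are used.

```python
# --target_attribute values documented in this README, per --dataset_name value.
TARGET_ATTRIBUTES = {
    "bdd": {"move_forward_vs_stop": 0},
    "celeba": {"smile": 1, "young": 2},
    "celebamhq": {"smile": 1, "young": 2},
}


def counterfactual_command(dataset, checkpoints_dir, dataroot, name_exp,
                           attribute=None, regions=None):
    """Build the argument list for generate_counterfactuals.py."""
    cmd = ["python", "generate_counterfactuals.py",
           "--dataset_name", dataset,
           "--checkpoints_dir", checkpoints_dir,
           "--dataroot", dataroot,
           "--name_exp", name_exp]
    if attribute is not None:
        # Omitted otherwise, so the script's default (0) applies.
        cmd += ["--target_attribute", str(TARGET_ATTRIBUTES[dataset][attribute])]
    if regions:
        # Regions are passed as a single comma-separated value.
        cmd += ["--specified_regions", ",".join(regions)]
    return cmd
```

The resulting list can be passed to `subprocess.run(cmd, check=True)`. Joining the regions into one value matches the `--specified_regions skin,nose,glasses` form shown above.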

The whole list of options is available in the options/ folder.

Evaluation

  1. Success Rate

To get the success rate, run the following command:

python compute_success_rate.py --expe </path/to/results_dir/name_exp>
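
For reference, the metric can be sketched as follows, assuming a counterfactual counts as successful when the decision model assigns it to the target class (for a binary decision, the opposite of the model's prediction on the query). compute_success_rate.py may compute it differently; `success_rate` is a hypothetical helper.

```python
def success_rate(target_classes, counterfactual_preds):
    """Fraction of counterfactuals classified as their target class."""
    if len(target_classes) != len(counterfactual_preds):
        raise ValueError("inputs must have the same length")
    if not target_classes:
        raise ValueError("no counterfactuals to evaluate")
    hits = sum(int(t == p) for t, p in zip(target_classes, counterfactual_preds))
    return hits / len(target_classes)
```

With binary query predictions `q`, the targets would be `[1 - p for p in q]` under this assumption.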

  2. Fréchet Inception Distance (FID)

Run the following bash command to get the FID score:

python compute_fid.py --expe </path/to/results_dir/name_exp>

The processed folders used by compute_fid.py contain the query images resized to the correct dimensions. To obtain them, run generate_counterfactuals.py with the flag --save_query_image True, then move the saved images to the path indicated in compute_fid.py.
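
As a reminder of what the score measures: the FID fits a Gaussian to Inception features of each image set (here, presumably the processed query images and the counterfactual explanations) and computes the Fréchet distance between the two Gaussians. With feature means mu_r, mu_g and covariances Sigma_r, Sigma_g for the real and generated sets:

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2 \left( \Sigma_r \Sigma_g \right)^{1/2} \right)
```

Lower values indicate that the two feature distributions are closer.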

  3. Face Verification Accuracy (FVA)

Run the following bash command to get the FVA score:

python compute_fva.py --expe </path/to/results_dir/name_exp>
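
The FVA can be sketched as the fraction of (query, counterfactual) pairs that a face verification model judges to show the same person. The sketch below assumes identity embeddings (e.g. from a VGGFace2-trained network) compared by cosine similarity against a threshold; the helper names and the threshold are hypothetical, and compute_fva.py may use a different verification rule.

```python
import math


def cosine_similarity(u, v):
    """Cosine of the angle between two non-zero embedding vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm


def face_verification_accuracy(query_embs, cf_embs, threshold):
    """Fraction of pairs whose embeddings are similar enough to count as one identity."""
    if len(query_embs) != len(cf_embs) or not query_embs:
        raise ValueError("need the same, non-zero number of query and counterfactual embeddings")
    same = [cosine_similarity(q, c) >= threshold for q, c in zip(query_embs, cf_embs)]
    return sum(same) / len(same)
```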

  4. Mean Number of Attribute Changes (MNAC)

4.1. The first time, train the VGGFace2 oracle model on the 40 attributes with train_vggface2_oracle.py. Train one oracle for CelebA (--celeba_or_celebamhq celeba) and one for CelebAMask-HQ (--celeba_or_celebamhq celebamhq):

python train_vggface2_oracle.py --dataroot </path/to/dataroot> --checkpoint_dir </path/to/checkpoints/dir>

At the end, the checkpoints must be stored with the following structure:

/path/to/checkpoints/oracle_attribute/celeba
/path/to/checkpoints/oracle_attribute/celebamaskhq

4.2. Once the oracle model is trained and saved, put its name at the top of compute_mnac.py (placeholder <oracle_name>), then run the following bash command to get the MNAC score:

python compute_mnac.py --expe </path/to/results_dir/name_exp> --checkpoint_dir </path/to/checkpoints/dir>
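
The MNAC can be sketched from the oracle's binary attribute predictions on each query and its counterfactual. This is a minimal sketch with hypothetical helper names; whether the target attribute is excluded (the `ignore` argument here) and whether only successful counterfactuals are counted are assumptions to check against compute_mnac.py.

```python
def num_attribute_changes(query_attrs, cf_attrs, ignore=()):
    """Count attributes whose predicted value differs between query and counterfactual."""
    return sum(1 for i, (q, c) in enumerate(zip(query_attrs, cf_attrs))
               if i not in ignore and q != c)


def mnac(query_preds, cf_preds, ignore=()):
    """Mean number of attribute changes over all (query, counterfactual) pairs."""
    if len(query_preds) != len(cf_preds) or not query_preds:
        raise ValueError("need the same, non-zero number of query and counterfactual predictions")
    counts = [num_attribute_changes(q, c, ignore) for q, c in zip(query_preds, cf_preds)]
    return sum(counts) / len(counts)
```

Passing the index of the target attribute in `ignore` counts only the side effects on the other attributes.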

Disclaimer

There might be some bugs or errors. Feel free to open an issue and/or contribute to improve the repo.

Citation

If this code helps your research, please consider citing:

@inproceedings{jacob2022steex,
  author    = {Paul Jacob and
               {\'{E}}loi Zablocki and
               Hedi Ben{-}Younes and
               Micka{\"{e}}l Chen and
               Patrick P{\'{e}}rez and
               Matthieu Cord},
  title     = {{STEEX:} Steering Counterfactual Explanations with Semantics},
  booktitle = {ECCV},
  year      = {2022},
}