This repository is a re-implementation of "Explaining in Style: Training a GAN to explain a classifier in StyleSpace" by Lang et al. (2021).
[Re] Explaining in Style: Training a GAN to explain a classifier in StyleSpace
Noah Van der Vleuten, Tadija Radusinović, Rick Akkerman, Meilina Reksoprodjo
Paper: https://rescience.github.io/bibliography/Vleuten_2022.html
Presented at NeurIPS 2022: https://nips.cc/virtual/2022/poster/56097 (the page contains a poster and presentation slides about our reproducibility efforts).
If you use this for research, please cite our paper:
@article{Vleuten:2022,
author = {van der Vleuten, Noah and Radusinović, Tadija and Akkerman, Rick and Reksoprodjo, Meilina},
title = {{[Re] Explaining in Style: Training a GAN to explain a classifier in StyleSpace}},
journal = {ReScience C},
year = {2022},
month = may,
volume = {8},
number = {2},
pages = {{#42}},
doi = {10.5281/zenodo.6574709},
url = {https://zenodo.org/record/6574709/files/article.pdf},
code_url = {https://github.com/NoahVl/Explaining-In-Style-Reproducibility-Study},
code_doi = {10.5281/zenodo.6512392},
code_swh = {swh:1:dir:04e11a55f476b115b40fd6af9d06ed70eb248535},
data_url = {},
data_doi = {},
review_url = {https://openreview.net/forum?id=SYUxyazQh0Y},
type = {Replication},
language = {Python},
domain = {ML Reproducibility Challenge 2021},
keywords = {rescience c, machine learning, deep learning, python, pytorch, explainable ai, xai, gan, stylegan2, stylex}
}
Running the results notebook, stylex/all_results_notebook.ipynb, requires a CUDA-enabled graphics card. Installing the environment requires Conda.
Set model_to_choose in the notebook to pick the dataset/model on which to show results. The default is 'plant'.
The all_results_notebook.ipynb notebook works with pre-calculated latent vectors to generate results and run the experiments. If you want to generate the latent embeddings yourself, use the run_attfind_combined.ipynb notebook (again, select the appropriate model_to_choose). Note that you will have to download the datasets if you want to run AttFind (you can use the notebooks in the data folder).
Warning: The AttFind procedure is quite slow and may take over an hour depending on your hardware.
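To give a feel for what AttFind computes, here is a heavily simplified, AttFind-inspired greedy search in plain Python. It is a toy sketch, not the authors' algorithm and not the PyTorch translation in run_attfind_combined.ipynb: the "classifier" is a hypothetical logistic function over a 4-dimensional style vector, and the names attfind_greedy, shift and num_attributes are made up for illustration. In each round it shifts every unused style coordinate by a fixed amount, measures the average increase in the target-class probability over the images that are not yet flipped, keeps the best coordinate, and drops the images whose prediction that coordinate flips.

```python
import math
from typing import Callable, List, Sequence


def attfind_greedy(
    styles: Sequence[Sequence[float]],
    prob_target: Callable[[Sequence[float]], float],
    shift: float = 1.0,
    num_attributes: int = 2,
    threshold: float = 0.5,
) -> List[int]:
    """Greedily pick style coordinates that push images toward the target class.

    Hypothetical toy helper. Simplifications (assumptions): only positive
    shifts are tried, and effects are measured on probabilities, not logits.
    """
    remaining = [list(s) for s in styles]  # images whose prediction is not yet flipped
    chosen: List[int] = []
    for _ in range(num_attributes):
        if not remaining:
            break
        best_dim, best_gain = None, -math.inf
        for d in range(len(styles[0])):
            if d in chosen:
                continue
            gain = 0.0
            for s in remaining:
                shifted = list(s)
                shifted[d] += shift
                gain += prob_target(shifted) - prob_target(s)
            gain /= len(remaining)  # average effect over the remaining images
            if gain > best_gain:
                best_dim, best_gain = d, gain
        if best_dim is None:
            break
        chosen.append(best_dim)
        # Keep only the images that this coordinate does NOT flip to the target class.
        still = []
        for s in remaining:
            shifted = list(s)
            shifted[best_dim] += shift
            if not (prob_target(s) < threshold <= prob_target(shifted)):
                still.append(s)
        remaining = still
    return chosen


def toy_prob(style: Sequence[float]) -> float:
    """Hypothetical logistic 'classifier' over a 4-dimensional style vector."""
    z = 0.1 * style[0] + 2.0 * style[1] - 1.0 * style[2] + 0.5 * style[3] - 1.0
    return 1.0 / (1.0 + math.exp(-z))


styles = [
    [0.0, 0.0, 0.0, 0.0],
    [0.2, -0.5, 0.1, 0.3],
    [-0.1, 0.1, 0.4, -0.2],
    [0.0, -2.0, 0.0, 0.0],
]
print(attfind_greedy(styles, toy_prob, shift=1.0, num_attributes=2))  # [1, 3]
```

Coordinate 1 has the largest weight, so it is picked first and flips three of the four toy images; the remaining, far-from-boundary image then selects coordinate 3.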
The StylEx framework consists of two parts: the "pretrained" classifier and the Encoder+GAN.
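As a rough illustration of how the two parts fit together (following our reading of Lang et al., 2021), the sketch below concatenates the encoder's latent with the frozen classifier's output before it reaches the generator, and adds a KL term that encourages the classifier to respond to the reconstruction as it did to the original image. The tiny linear encoder, classifier and generator, their sizes, the helper name stylex_forward, and the use of softmax probabilities are all assumptions for illustration; this is not the training code in this repository.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy sizes and linear layers (assumptions); the real encoder, classifier and
# StyleGAN2 generator are convolutional networks.
IMG_DIM, LATENT_DIM, NUM_CLASSES = 32, 16, 2

encoder = nn.Linear(IMG_DIM, LATENT_DIM)                  # image -> latent
classifier = nn.Linear(IMG_DIM, NUM_CLASSES)              # stands in for the pretrained classifier
generator = nn.Linear(LATENT_DIM + NUM_CLASSES, IMG_DIM)  # conditioned latent -> image

# The classifier is only explained, never trained: eval mode, no gradients.
classifier.eval().requires_grad_(False)


def stylex_forward(x: torch.Tensor):
    """Return the reconstruction and a classifier-consistency KL loss."""
    probs = F.softmax(classifier(x), dim=1)      # C(x)
    w = torch.cat([encoder(x), probs], dim=1)    # latent conditioned on C(x)
    x_rec = generator(w)
    log_probs_rec = F.log_softmax(classifier(x_rec), dim=1)
    # F.kl_div(log q, p) computes KL(p || q); here KL(C(x) || C(x_rec)).
    cls_loss = F.kl_div(log_probs_rec, probs, reduction="batchmean")
    return x_rec, cls_loss
```

In a full training loop this term would be added to the reconstruction and adversarial losses; only the encoder and generator receive gradients.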
If you want to train a StylEx model on a new dataset, we suggest first training a new classifier and then providing it to cli.py, which trains the StylEx model on this dataset with the new classifier in evaluation mode. If you use a ResNet/MobileNet model, you should only have to change the classifier_name parameter in cli.py, or pass it as an argument using --classifier_name <mobilenet/resnet> when you call cli.py.
If you want to use a new classifier architecture, you should add support for it in one of the stylex_train.py files.
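What exactly "support" requires depends on the stylex_train.py code, but loosely, a new classifier has to accept the generator's images and return class logits. As an illustration, the hypothetical ResizedClassifier wrapper below upscales smaller images (for example 64x64 generator output) to the 224px input that ImageNet-style backbones typically expect before calling an arbitrary backbone. The class name, the bilinear mode and the 224px default are assumptions, not this repository's API.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResizedClassifier(nn.Module):
    """Hypothetical wrapper: resize images to the backbone's input size, return logits."""

    def __init__(self, backbone: nn.Module, input_size: int = 224):
        super().__init__()
        self.backbone = backbone
        self.input_size = input_size

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        # images: (batch, channels, height, width), e.g. 64x64 generator output.
        if images.shape[-2:] != (self.input_size, self.input_size):
            images = F.interpolate(images, size=(self.input_size, self.input_size),
                                   mode="bilinear", align_corners=False)
        return self.backbone(images)
```

Any nn.Module that maps a (batch, 3, 224, 224) tensor to logits can then be wrapped, e.g. ResizedClassifier(my_backbone), where my_backbone is a placeholder name.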
Natively, we support the MobileNet V2 and ResNet architectures. Of the two, ResNet seemed to give much better results than MobileNet on small images upscaled to 224px. The MobileNet classifier training code has been included; however, to reiterate, we advise training a ResNet classifier when using small images. We also found that unfreezing the layers iteratively by editing a Python file is inconvenient.
Therefore, we have also created and included the notebook we used to train the ResNet-18 CelebA gender classifier. Following the directions of the original paper, this classifier was then explained by the StyleGAN model trained on the FFHQ dataset. The notebook can also be used to train a MobileNet classifier.
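Our actual procedure lives in the notebook; as a hedged sketch of the general idea of unfreezing a backbone in stages (from the head backwards) without editing a Python file between stages, the hypothetical helper below freezes every parameter and then re-enables gradients for the first n entries of a list of stage names. The names in RESNET18_STAGES are real top-level attributes of torchvision's ResNet-18; the schedule in the comments is an assumption.

```python
from typing import Sequence

import torch.nn as nn

# Top-level stages of torchvision's ResNet-18, ordered from the head backwards.
RESNET18_STAGES = ("fc", "layer4", "layer3", "layer2", "layer1")


def unfreeze_last_stages(model: nn.Module, stages: Sequence[str], n: int) -> int:
    """Freeze all parameters, then unfreeze the first ``n`` entries of ``stages``.

    Hypothetical helper. Returns the number of trainable parameters for logging.
    """
    for p in model.parameters():
        p.requires_grad_(False)
    for name in stages[:n]:
        for p in getattr(model, name).parameters():
            p.requires_grad_(True)
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Hypothetical schedule: train the head first, then progressively deeper blocks,
# rebuilding the optimizer whenever the set of trainable parameters changes.
# for n in [1, 2, 3]:
#     unfreeze_last_stages(model, RESNET18_STAGES, n)
#     optimizer = torch.optim.Adam(p for p in model.parameters() if p.requires_grad)
#     ...train for a few epochs...
```

Note that freezing parameters does not stop BatchNorm layers from updating their running statistics in training mode.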
The files of the user study discussed in the paper are included in this repository in the /all_user_studies folder.
For more information, please see both the open and closed issues on the GitHub issues page.
Our repository is based on the PyTorch StyleGAN2 training code from the amazing stylegan2-pytorch repository by GitHub user lucidrains, to which we added the StylEx training code.
The authors' original TensorFlow notebook, including their AttFind algorithm, has been translated to PyTorch. It was also used to run their pretrained age StylEx model to extract experimental results. Both notebooks have been included.