# decepticon

*teaching machines to lie*
This repository contains a TensorFlow/Keras implementation of the elegant object-removal model described in *Adversarial Scene Editing: Automatic Object Removal from Weak Supervision* by Shetty, Fritz, and Schiele.
We've tested with Python 3.6 and TensorFlow 1.14.
*Example: object removal on the Kaggle "Ships in Satellite Imagery" dataset.*
So far we've only tested single-class object removal. You'll need two lists of file paths: one of image patches containing objects and one of patches without. All patches should be prepared to the same size.
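For example, the two lists might be built with Python's `glob` module (the directory names below are just placeholders for wherever your prepared patches live):

```python
from glob import glob

# placeholder paths: point these at your own prepared patch directories
posfiles = sorted(glob("patches/ships/*.png"))       # patches containing the object
negfiles = sorted(glob("patches/background/*.png"))  # patches without it
```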
Shetty et al.'s model has several components; `decepticon` expects a Keras `Model` object for each. You can reproduce the models from the paper or substitute your own, so long as they have the correct inputs and outputs (a minimal example follows the table):
| Component | Description | Input Shape | Output Shape | Code |
|---|---|---|---|---|
| mask generator | fully-convolutional network that generates a mask from an input image | `(None, None, None, 3)` | `(None, None, None, 1)` | `decepticon.build_mask_generator()` |
| classifier | standard convolutional classifier that maps an image to a probability over categories | `(None, None, None, 3)` | `(None, num_classes+1)` | `decepticon.build_classifier()` |
| inpainter | fully-convolutional network that inputs a partially-masked image (with a 4th channel containing the mask) and attempts to generate the original unmasked version (like a denoising autoencoder or context encoder) | `(None, None, None, 4)` | `(None, None, None, 3)` | `decepticon.build_inpainter()` |
| discriminator | fully-convolutional network that inputs an image and makes a pixel-wise assessment of whether the image is real or fake | `(None, None, None, 3)` | `(None, None, None, 1)` | `decepticon.build_discriminator()` |
| mask discriminator | convolutional classifier that tries to classify mask generator outputs as real or fake | `(None, None, None, 1)` | `(None, 1)` | `decepticon.build_mask_discriminator()` |
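For example, a custom mask generator only needs to map a 3-channel image to a single-channel mask in [0, 1]; the tiny sketch below (not the architecture from the paper) satisfies the interface in the table:

```python
import tensorflow as tf

def build_tiny_mask_generator():
    # fully-convolutional, so it accepts images of any height and width
    inpt = tf.keras.layers.Input(shape=(None, None, 3))
    net = tf.keras.layers.Conv2D(32, 3, padding="same", activation="relu")(inpt)
    net = tf.keras.layers.Conv2D(32, 3, padding="same", activation="relu")(net)
    # sigmoid keeps the mask in [0, 1]; output shape is (None, None, None, 1)
    mask = tf.keras.layers.Conv2D(1, 1, activation="sigmoid")(net)
    return tf.keras.Model(inpt, mask)
```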
If you're training on a consumer GPU, you may run into memory limitations using the models from the paper with a reasonable batch size; if you pass the keyword argument `downsample=n` to any of the above functions, the number of filters in every hidden convolutional layer will be reduced by a factor of `n`.
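For example, to build the paper's inpainter and discriminator with half the filters per hidden layer:

```python
import decepticon

# downsample=2 halves the number of filters in every hidden convolutional layer
inpainter = decepticon.build_inpainter(downsample=2)
disc = decepticon.build_discriminator(downsample=2)
```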
Before starting, initialize your own custom models or the ones from the paper:
```python
import decepticon

# build the architectures from the paper (or substitute your own Keras models)
classifier = decepticon.build_classifier()
inpainter = decepticon.build_inpainter()
maskgen = decepticon.build_mask_generator()
disc = decepticon.build_discriminator()
```
## The `Trainer` object

The `decepticon.Trainer` class has convenience code for all the training steps. Instantiate it with the initialized models, lists of paths to positive and negative image patches, hyperparameters, and a path to a log directory:
```python
trainer = decepticon.Trainer(maskgen, classifier,
                             inpainter, disc,
                             posfiles, negfiles,
                             lr=1e-5, disc_weight=10,
                             logdir="/path/to/log/directory/",
                             batch_size=32,
                             exponential_loss_weight=1)
```
The image classifier is trained on randomly-masked images. You can use `decepticon.classifier_training_dataset()` to build a randomly-masked `tf.data.Dataset` object that you can use with the standard Keras `Model.fit()` API, or use `trainer.pretrain_classifier()`:
```python
trainer.pretrain_classifier(epochs=10)
```
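If you'd rather assemble the pretraining data yourself, random masking can be done with ordinary `tf.data` operations. The sketch below only illustrates the idea (it is not `decepticon`'s internal implementation, and the patch size and rectangle sizes are arbitrary choices):

```python
import tensorflow as tf

def random_mask(image, label, size=80):
    # image: float32 tensor of shape (size, size, 3)
    # zero out a randomly-placed rectangle so the classifier learns to
    # recognize objects even when part of the patch is blanked
    h = tf.random.uniform([], 10, size // 2, dtype=tf.int32)
    w = tf.random.uniform([], 10, size // 2, dtype=tf.int32)
    top = tf.random.uniform([], 0, size - h, dtype=tf.int32)
    left = tf.random.uniform([], 0, size - w, dtype=tf.int32)
    rows = tf.range(size)[:, tf.newaxis]
    cols = tf.range(size)[tf.newaxis, :]
    inside = ((rows >= top) & (rows < top + h) &
              (cols >= left) & (cols < left + w))
    mask = 1.0 - tf.cast(inside, tf.float32)[:, :, tf.newaxis]
    return image * mask, label

# applied on the fly to a dataset of decoded (image, label) pairs:
# dataset = dataset.map(random_mask).batch(32)
```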
Finally, run the alternating-epoch (mask generator vs inpainter) and alternating-batch (inpainter and discriminator) training:
```python
trainer.fit(10)
```
You can also run the different training phases independently with `trainer.fit_mask_generator()` and `trainer.fit_inpainter()`.
If a log directory is specified, models and TensorBoard logs will be saved there:

* All the terms in the loss function are recorded as scalars.
* Histograms record classification and discriminator probabilities on reconstructed examples, to visualize how well the system is fooling them.
* Images record examples of raw inputs, the corresponding mask generator and inpainter outputs, and the final reconstructed image.
The end-to-end object removal network is stored as a Keras model in `trainer.full_model`; run inference using the normal `predict` interface.
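For example, assuming your patches are RGB images scaled to the unit interval (match whatever preprocessing you trained with), removal on a batch of new patches could look like this (`test_files` is a hypothetical list of image paths):

```python
import numpy as np
from PIL import Image

# load a few patches as a float32 batch of shape (batch, height, width, 3)
x = np.stack([np.array(Image.open(f), dtype=np.float32) / 255.0
              for f in test_files])
# object-removed reconstructions, same shape as the input batch
removed = trainer.full_model.predict(x)
```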
This package was created with Cookiecutter and the audreyr/cookiecutter-pypackage project template.