JustinhoCHN / SRGAN_Wasserstein

Apply Wasserstein GAN to SRGAN, a deep learning super-resolution model

SRGAN_Wasserstein

Applying Wasserstein GAN to SRGAN, a GAN-based super-resolution algorithm.

This repo is forked from @zsdonghao's tensorlayer/srgan repo. Starting from that original code, I changed parts of it to apply the Wasserstein loss, which makes the training procedure more stable. Thanks again to @zsdonghao for his great reimplementation.

SRGAN Architecture

TensorFlow Implementation of "Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network"

Wasserstein GAN

When SRGAN was first proposed in 2016, Wasserstein GAN (2017) did not exist yet. WGAN uses the Wasserstein distance to measure the difference between two data distributions. With original GAN training, the losses give no clear signal for when to stop training the discriminator or the generator to get a good result. With the Wasserstein loss, the loss tracks quality: as it decreases, the results get better. So we use WGAN here. We are not going to explain the mathematical details of WGAN, only give the following steps for applying it.

The steps above come from an excellent article [4], written in Chinese, whose author explains WGAN in a very straightforward way.
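As a rough illustration of what applying WGAN usually involves, here is a minimal sketch in plain Python. It follows the general WGAN recipe rather than this repo's TensorFlow code. The critic (discriminator) outputs raw scores with no sigmoid, the losses contain no log, and the critic weights are clipped to a small range after each update. The function names and sample scores are hypothetical. The clipping constant c = 0.01 is the default from the WGAN paper, which also trains with RMSProp rather than a momentum-based optimizer.

```python
# Sketch of the WGAN losses and weight clipping.
# Assumption: this mirrors the general WGAN recipe, not this repo's exact code.

def mean(values):
    return sum(values) / len(values)

def critic_loss(real_scores, fake_scores):
    """Critic loss: difference of mean raw scores, with no log and no sigmoid."""
    return mean(fake_scores) - mean(real_scores)

def generator_loss(fake_scores):
    """Adversarial generator loss: push the critic's score on fakes upward."""
    return -mean(fake_scores)

def clip_weights(weights, c=0.01):
    """Clip every critic weight into [-c, c] after each critic update."""
    return [max(-c, min(c, w)) for w in weights]

# Hypothetical raw critic outputs for a batch of real and generated images.
real_scores = [0.9, 1.1, 1.0]
fake_scores = [-0.2, 0.0, 0.2]

d_loss = critic_loss(real_scores, fake_scores)   # negative when real scores are higher
g_loss = generator_loss(fake_scores)
clipped = clip_weights([0.5, -0.003, -0.2])      # large weights are pulled to +/- c
```

In SRGAN the generator is also trained with content losses (MSE and VGG), so the adversarial term above would be only one part of the total generator loss.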

Loss curve and Result

Prepare Data and Pre-trained VGG

Run

This script runs under TensorFlow 1.4 and TensorLayer 1.8.0+.

pip install tensorlayer==1.8.0
conda install tensorflow-gpu==1.3.0
pip install tensorflow-gpu==1.4.0
pip install easydict

Set the training image folder:

config.TRAIN.img_path = "your_image_folder/"

I added TensorBoard callbacks to monitor the training procedure. Please change the logdir to your own folder.

config.VALID.logdir = 'your_tensorboard_folder'

Start training:

python main.py

Start evaluation:

python main.py --mode=evaluate

What's new?

Compared with the original version, I made the following changes:

  1. Added WGAN, as described in the Wasserstein GAN section.
  2. Added TensorBoard to monitor the training procedure.
  3. Modified the last conv layer of 'SRGAN_g' in model.py (line 100), changing the kernel size from (1, 1) to (9, 9), as the paper proposes.
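To give a feel for what the kernel size change in item 3 means, the sketch below counts the trainable parameters of a standard 2-D convolution for both kernel sizes. The channel counts (64 feature maps in, 3 RGB channels out) are assumptions based on the SRGAN paper's generator, and `conv2d_params` is a hypothetical helper, not a function from model.py.

```python
def conv2d_params(kernel_h, kernel_w, in_channels, out_channels, bias=True):
    """Number of trainable parameters in a standard 2-D convolution layer."""
    weights = kernel_h * kernel_w * in_channels * out_channels
    return weights + (out_channels if bias else 0)

# Assumed shapes: 64 feature maps in, 3 RGB channels out.
params_1x1 = conv2d_params(1, 1, 64, 3)   # 1*1*64*3 + 3
params_9x9 = conv2d_params(9, 9, 64, 3)   # 9*9*64*3 + 3
```

Under these assumptions the layer grows from 195 to 15555 parameters, and each output pixel sees a 9x9 neighbourhood instead of a single position.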

Reference

Author

License