JuheonYi / WESPE-TensorFlow

TensorFlow implementation of "Weakly Supervised Photo Enhancer for Digital Cameras"

Mode collapse #4

Open hcl14 opened 5 years ago

hcl14 commented 5 years ago

Hi,

I wanted to share the way to help with this problem:

Training seems to suffer from mode collapse: when training goes wrong, the enhanced image looks like a negative of the input. When the Enhanced-Original PSNR stayed below 10 dB during training, I stopped the training and started again from the beginning.

Restarting did not help in my case. I was trying to do semi-supervised colorization with my own images plus some colorful photos collected from Google. What did help was adding a strong penalty on the generated image's histogram, so it cannot deviate much from the original. For simplicity I used a grayscale histogram (a weighted sum of the channels), but you could also try each color channel separately, etc. In WESPE_DIV2K.py:

    # get the grayscale histogram of a batch (img values are in [0, 1])
    def histogram(self, img):
        img1 = tf.image.rgb_to_grayscale(img)  # rgb_weights = [0.2989, 0.5870, 0.1140]
        img1 = tf.squeeze(img1, -1)  # remove the last dim (= 1)
        values_range = tf.constant([0., 255.], dtype=tf.float32)
        histogram = tf.histogram_fixed_width(img1 * 255.0, values_range, nbins=25)
        # cast the int32 bin counts to float32 so they mix cleanly with the float loss
        return tf.cast(histogram, tf.float32)
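For a quick sanity check outside the graph, here is a NumPy equivalent of that histogram (my own sketch, not repo code; `histogram_np` is a hypothetical helper, and the luma weights match what `tf.image.rgb_to_grayscale` uses):

```python
import numpy as np

def histogram_np(img):
    """25-bin grayscale histogram of a batch of RGB images in [0, 1],
    counted over the whole batch like tf.histogram_fixed_width."""
    luma = np.array([0.2989, 0.5870, 0.1140])  # same weights as rgb_to_grayscale
    gray = img @ luma                          # (N, H, W, 3) -> (N, H, W)
    hist, _ = np.histogram(gray * 255.0, bins=25, range=(0.0, 255.0))
    return hist

# an all-black batch puts every pixel into the first bin
print(histogram_np(np.zeros((2, 4, 4, 3))))  # -> 32 in bin 0, zeros elsewhere
```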

And then add the histogram term to the content loss in build_generator_loss():

        # content loss (VGG feature distance between original & reconstructed)
        original_vgg = net(self.vgg_dir, self.phone_patch * 255)
        reconstructed_vgg = net(self.vgg_dir, self.reconstructed_patch * 255)

        # histogram loss: L1 distance between the batch histograms
        hist_phone = self.histogram(self.phone_patch)
        hist_rec = self.histogram(self.reconstructed_patch)
        histogram_loss = 1000 * tf.losses.absolute_difference(hist_phone, hist_rec)

        self.content_loss = tf.reduce_mean(
            tf.square(original_vgg[self.content_layer] - reconstructed_vgg[self.content_layer])
        ) + histogram_loss
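To see why this works against the "negative image" collapse, here is a small NumPy illustration (my own sketch, not repo code; `hist_loss` mirrors the term above, using a mean absolute difference as `tf.losses.absolute_difference` does by default):

```python
import numpy as np

def hist_loss(a, b, nbins=25):
    """Scaled L1 distance between 25-bin grayscale batch histograms."""
    luma = np.array([0.2989, 0.5870, 0.1140])
    ha, _ = np.histogram((a @ luma) * 255.0, bins=nbins, range=(0.0, 255.0))
    hb, _ = np.histogram((b @ luma) * 255.0, bins=nbins, range=(0.0, 255.0))
    return 1000.0 * np.abs(ha - hb).mean()  # mean reduction, like absolute_difference

rng = np.random.default_rng(0)
dark = rng.random((4, 32, 32, 3)) ** 3   # fake dark phone patches in [0, 1]
print(hist_loss(dark, dark))             # identical output -> 0.0
print(hist_loss(dark, 1.0 - dark))       # inverted ("negative") output -> large penalty
```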


JuheonYi commented 5 years ago

Thanks for the constructive comment! An additional loss term that keeps the enhanced image's color distribution from drifting too far from the original seems like a nice idea.

I think it would be even stronger if we could push the enhanced image toward a "desirable" color distribution. In the example you shared, the histogram loss clearly affects the color distribution of the enhanced image, but the overall tone became somewhat red. Maybe adding an additional discriminator that distinguishes whether an image's color histogram is "high quality" or "low quality" would help, too.

Thanks for suggesting a way to improve performance. Please share if you make any further progress!

Thanks.