victorca25 / traiNNer

traiNNer: Deep learning framework for image and video super-resolution, restoration and image-to-image translation, for training and testing.
Apache License 2.0

[Suggestion]: Relativistic GAN Type #41

Closed N0manDemo closed 3 years ago

N0manDemo commented 3 years ago

I was reading about the GAN types (Vanilla, LSGAN, and WGAN-GP) already included in BasicSR, and I found a new type that may bring a sizable performance increase to the discriminators used in upscaling methods like ESRGAN and PPON.

https://arxiv.org/abs/1807.00734

This paper outlines the idea behind a relativistic discriminator and showcases new variants of existing GANs that were created to use this approach. There is also source code available: https://www.github.com/AlexiaJM/RelativisticGAN

The one that stood out to me was RaLSGAN.

It performs better than the other variants in most tests that involve generating images at 128x128 or smaller, and it outperforms the standard GAN (SGAN) by a large margin.

Interested to hear your thoughts on this,

N0man

victorca25 commented 3 years ago

Hello again!

The behavior of the GANs is currently already relativistic. Instead of classifying labels as only Real or Fake, the discriminator produces a measure of "realness" that ranges from -1 (fake label) to +1 (real label) when using the VGG-like discriminators.

During training you can check the D_real and D_fake outputs in the losses; these are samples of the discriminator's label results. When using patchgan or multiscale patchgan, they might exceed that range, since those discriminators evaluate patches of the image instead of producing a single value.

Let me know if you find out more interesting stuff or notice something in the current code, it's great if we find more techniques to improve results!

victorca25 commented 3 years ago

If interested, you can check the relativistic GAN formulation here: https://github.com/victorca25/BasicSR/blob/14aced7d1049a283761c145f3cf300a94c6ac4b9/codes/models/losses.py#L312

You can compare it with that repository; the operations match. It's a bit hard to see at first glance because I handle multiple different cases, but I can probably add some comments to the code for reference.
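For reference, the relativistic average LSGAN (RaLSGAN) objective from the paper compares each sample's score against the mean score of the opposite set. A minimal PyTorch sketch (function names are illustrative, not the exact ones in losses.py; d_real and d_fake are the raw, un-sigmoided discriminator outputs):

```python
import torch

def ralsgan_d_loss(d_real, d_fake):
    # Discriminator: real samples should score about 1 above the average
    # fake score, fake samples about 1 below the average real score.
    loss_real = torch.mean((d_real - d_fake.mean() - 1.0) ** 2)
    loss_fake = torch.mean((d_fake - d_real.mean() + 1.0) ** 2)
    return (loss_real + loss_fake) / 2

def ralsgan_g_loss(d_real, d_fake):
    # Generator: the same objective with the roles of real and fake swapped.
    loss_real = torch.mean((d_real - d_fake.mean() + 1.0) ** 2)
    loss_fake = torch.mean((d_fake - d_real.mean() - 1.0) ** 2)
    return (loss_real + loss_fake) / 2
```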

N0manDemo commented 3 years ago

That's great that BasicSR already has relativistic GAN support!

The only additional enhancement I can think of related to RaLSGAN is a vgg128 discriminator with spectral normalization and feature extraction.

https://ieeexplore.ieee.org/document/9220103

According to this paper, spectral normalization improves relativistic LSGAN by 5% (an 8% improvement over ESRGAN instead of only 3%).

BasicSR already has a vgg128 discriminator with spectral normalization, but it does not have feature extraction, which is useful when using a feature loss, especially with PPON.
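To illustrate what I mean by feature loss: it would match intermediate discriminator activations between the real image and the generated output, roughly like this (just a sketch; the actual interface of a feature-extracting discriminator may differ):

```python
import torch.nn.functional as F

def feature_matching_loss(feats_fake, feats_real):
    # feats_* are lists of intermediate discriminator feature maps,
    # one tensor per selected layer, for the generated and real images.
    loss = 0.0
    for f_fake, f_real in zip(feats_fake, feats_real):
        loss = loss + F.l1_loss(f_fake, f_real.detach())
    return loss / len(feats_fake)
```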

Also, if the discriminator had the option of specifying convtype (using B.conv_block), it would allow selecting partialconv, which could improve the discriminator's performance by an additional 1-2%.

victorca25 commented 3 years ago

Some of the discriminators already support using spectral norm.

For example, discriminator_vgg_fea (which can also use feature maps to calculate a "feature loss", something other projects later called feature matching) has the "spectral_norm" flag: https://github.com/victorca25/BasicSR/blob/14aced7d1049a283761c145f3cf300a94c6ac4b9/codes/models/modules/architectures/discriminators.py#L531

Meanwhile, patchgan has the equivalent "use_spectral_norm" flag: https://github.com/victorca25/BasicSR/blob/14aced7d1049a283761c145f3cf300a94c6ac4b9/codes/models/modules/architectures/discriminators.py#L600
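In practice these flags just wrap the discriminator's conv layers with PyTorch's spectral normalization, roughly like this (a simplified sketch, not the exact code in discriminators.py):

```python
import torch.nn as nn
from torch.nn.utils import spectral_norm

def conv2d(in_nc, out_nc, kernel_size=3, stride=1, use_spectral_norm=False):
    # Build a conv layer and optionally wrap it with spectral normalization,
    # which constrains the layer's largest singular value to stabilize D.
    conv = nn.Conv2d(in_nc, out_nc, kernel_size, stride,
                     padding=kernel_size // 2)
    return spectral_norm(conv) if use_spectral_norm else conv
```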

This version of discriminator_vgg_fea will automatically adapt to the image crop size (128, 256, etc.), while patchgan works on 70x70 patches by default.

Neither of the two flags appears in the example configuration templates (YAML or JSON), but both are fully set up and work by just adding the flags to the file:

https://github.com/victorca25/BasicSR/blob/14aced7d1049a283761c145f3cf300a94c6ac4b9/codes/models/networks.py#L322

https://github.com/victorca25/BasicSR/blob/14aced7d1049a283761c145f3cf300a94c6ac4b9/codes/models/networks.py#L332
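Concretely, the flags just need to end up in the parsed options that networks.py reads. A minimal sketch (the exact key placement is assumed from the links above):

```python
# Minimal illustration: the flags end up in the parsed options dict
# that networks.py reads (key placement assumed, shown here as Python).
opt = {'network_D': {}}                       # parsed from the .yml/.json template
opt['network_D']['spectral_norm'] = True      # flag for discriminator_vgg_fea
opt['network_D']['use_spectral_norm'] = True  # equivalent flag for patchgan
```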

I haven't been able to do a full ablation study with spectral norm. Initially I didn't notice much of a difference in the results, but it could be that I didn't train for long enough.

If you could test it out, that would be very valuable feedback for deciding whether I should expose the flags and add spectral norm to multiscale patchgan as well!

N0manDemo commented 3 years ago

Okay, I'll train a couple of ESRGAN models with LSGAN and feature extraction, one with spectral normalization and one without, and show you the results.

I'll work on it this week.

N0manDemo commented 3 years ago

Here are the results from the training:

https://drive.google.com/file/d/1qLymsiDAkQy5LJgSJYjn-hwjmmUrT0OW/view?usp=sharing

I used the Flickr2K dataset, with the validation set being crops taken from the image files with the largest file sizes.

It seems that there might be a bug with spectral normalization. If you look at the discriminator results for SN in TensorBoard, they have no negative values, which could mean the discriminator is no longer behaving relativistically.

What do you think?

victorca25 commented 3 years ago

Hello! I haven't been able to visualize the TensorBoard logs yet, but I'll check the spectral normalization bits. The logic itself doesn't change with or without SN, so my initial guess is that other parameters may need fine-tuning to account for the change introduced by SN. If all discriminator results come out positive, that means all images are being labelled as "real", so the discriminator is unable to tell fake from real.

Besides fine-tuning, another option could be to make the discriminator focus only on the high-frequency parts of the images (by enabling the "Frequency Separator" part of the configuration file) or to make the generator work harder with batch augmentations ("Batch (Mixup) augmentations" in the config file).
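To illustrate the frequency separation idea (just a sketch, not the exact implementation behind the config option): the discriminator would only see the high-frequency residual, i.e. the image minus a blurred copy of itself:

```python
import torch.nn.functional as F

def high_frequency(img, kernel_size=5):
    # Low-pass the image with a simple box blur, then keep only the residual,
    # so the discriminator focuses on textures/edges rather than colors.
    pad = kernel_size // 2
    low = F.avg_pool2d(F.pad(img, [pad] * 4, mode='reflect'),
                       kernel_size, stride=1)
    return img - low
```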

Other than that, how do the image results compare with and without using SN?

Thanks for your tests, they help a lot!

N0manDemo commented 3 years ago

I couldn't tell a visual difference between sn and no_sn, even when using ImageMagick to show where the differences were. https://drive.google.com/drive/folders/1VZ5xe4J8r5_2MA7v1wJvzmycaZTZK8Px?usp=sharing

The diffs are in diff.zip. Red pixels mean at least a 3% difference in pixel color.

Thank you for taking a look at the spectral normalization bits. If you find something please let me know.

victorca25 commented 3 years ago

I'll close the ticket, but I'll update it with a few comments on changes added in more recent versions of the code: