lucidrains / lightweight-gan

Implementation of 'lightweight' GAN, proposed in ICLR 2021, in PyTorch. High-resolution image generation that can be trained within a day or two.
MIT License

How to improve quality #14

Open vibe007 opened 3 years ago

vibe007 commented 3 years ago

Thanks for putting this together! I'm having some success creating stylized artwork with this. I'm wondering what the avenues are to improve quality? It sounds like training for longer helps, along with adding attention. Is there a --network-capacity flag similar to your stylegan2 project? Should increasing the number of feature maps (fmap_max) help? What about increasing the size of latent_dim?

If we scale up to multi-GPU should we scale the learning rate a corresponding amount?

lucidrains commented 3 years ago

@vibe007 the author reported that changing the discriminator output size helped specifically for artwork: https://github.com/lucidrains/lightweight-gan#discriminator-output-size. Other than that, you'll have to reach for the state-of-the-art solution: https://github.com/lucidrains/stylegan2-pytorch
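For reference, that setting is the --disc-output-size flag that also shows up later in this thread, e.g. (path and image size are just illustrative):

lightweight_gan --data ./artwork-images --image-size 512 --disc-output-size 5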

lucidrains commented 3 years ago

@vibe007 what size are you training at and for how long? you should keep training until it collapses

vibe007 commented 3 years ago

So far I've tested resolutions 128 and 256, training for 150K iters (the default). I'll try training for longer since it definitely hasn't collapsed yet. The SOTA project you mentioned takes much much longer to train in my testing

I am using the recommended discriminator-output-size for artwork. Thanks!

lucidrains commented 3 years ago

@vibe007 yup, you can raise the number of training steps with --num-train-steps
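for example (path and step count are just illustrative; the default is 150000):

lightweight_gan --data ./images --num-train-steps 300000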

lucidrains commented 3 years ago

@vibe007 with GANs, training isn't over if the game hasn't collapsed!

vibe007 commented 3 years ago

Thanks! Hope to update with some cool images soon...

druidOfCode commented 3 years ago

Hey, great tips! I have some similar questions. Currently I'm feeding the model some very cartoon-styled artwork. I'm running it in Colab with a Tesla V100-SXM2-16GB GPU. I haven't let any run go to the full 150k (usually you've released a new version and I'm just too curious to see if it offers better results), but even so, by 60k it's getting an FID of ~100 and looks like it's properly converging. Which is fantastic to me! Still, I wonder if there's something I could do to make it even more awesome. Here are the settings I'm using; see anything glaring that I should change? Thanks in advance!

lightweight_gan --data /content/images --sle_spatial --attn-res-layers [32,64] --amp --disc-output-size 5 --models_dir "path" --results_dir "path" --calculate_fid_every 10000

I'm pretty sure I'm leaving processing power on the table here with my Colab notebook. P.S. Thanks for the auto aug_prob work, I always doubted I was using the right odds there 😊

lucidrains commented 3 years ago

@bckwalton good to hear! keep training! you can extend training by increasing --num-train-steps. I also added truncation to the input normal latents, which should help generation quality a bit. Do share a sample if you are allowed to do so :)
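(Here is a minimal sketch of one common way to truncate normal latents, by resampling any value beyond a threshold; it's illustrative only and not necessarily the exact implementation in this repo.)

import torch

def truncated_normal(shape, truncation = 1.5):
    # sample standard normal latents, then resample any entry whose
    # magnitude exceeds the truncation threshold
    z = torch.randn(shape)
    while True:
        mask = z.abs() > truncation
        if not mask.any():
            return z
        z[mask] = torch.randn(int(mask.sum()))

# e.g. latents = truncated_normal((4, 256))  # (batch size, latent_dim)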

druidOfCode commented 3 years ago

Awesome! Right now everything looks a bit like soup since the last posting. 🤔 May have to change some settings; I'll report back with better results ⚡. All the images in the dataset are cartoon characters in portrait shots, and you can kinda see that here. Rather soupy. (This is step 69,000, generated with 0.12.2, the truncation version.)

tannisroot commented 3 years ago

In my results (20k iterations), I see mesh-like artifacts that are very noticeable, far more so than in some of the demo samples. Is this a bug, or an inherent flaw of this kind of GAN that can't really be avoided? If not, is there a way to compensate for them?

lucidrains commented 3 years ago

@tannisroot you can try an antialiased version of this GAN by using the --antialias flag at the start of training. it'll be a bit slower though
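for example (path is just illustrative):

lightweight_gan --data ./images --antialias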

lucidrains commented 3 years ago

@tannisroot otherwise, just train for longer, 20k is still quite early!

druidOfCode commented 3 years ago

Back with results. After toying with some additional attention layers, here are the settings I landed on for Colab. (A warning to anyone following along for Colab settings: check your GPU version before using these. Colab likes to alternate between P100 and V100, and these settings are for the V100; if you get a P100, you won't have enough VRAM to cram in this many attention layers.)

lightweight_gan --data dataset_path --sle_spatial --attn-res-layers [32,64,128] --image-size 512 --amp --disc-output-size 5 --models_dir models_path --results_dir results_path --calculate_fid_every 10000

I also sanity-checked my dataset. Most of it is from random boorus, filtered by tags. I tried my best to clean out as many mistags as I could and culled about 300 of 5,148 images (some were just outright incorrect, others were whole comic pages that happened to have a portrait shot in one of the frames).

The results are markedly improved. It's only at 52,000 steps right now (35%), but it's actually "understandable"; the FID is reporting 108. Since I changed parts of the dataset and added to the model, I can't be certain what's responsible here, but here are my results regardless, in case there's something to learn from them. 😀

52,000 steps: it's converging already. Link for GIF of latent space exploration: Imgur: Image and Gif (52,000 steps)

70,000 steps: making progress. Link for GIF of latent space exploration: Imgur: Image and Gif (70,000 steps)

150,000 steps: needs more time. Link for GIF of latent space exploration: Imgur: Image and Gif (150,000 steps)

tannisroot commented 3 years ago

I've noticed that this GAN doesn't have the issue where augmentation leaks into the results, which is very common with https://github.com/lucidrains/stylegan2-pytorch. Is it known why it doesn't suffer from this? If so, can the state-of-the-art implementation be improved to avoid this problem? Also, could the auto-augmentation be backported as well? Very nifty feature!

vibe007 commented 3 years ago

See these two papers on how to augment without leaking the augmentations into the generator:

Karras, Tero, et al. "Training Generative Adversarial Networks with Limited Data." Advances in Neural Information Processing Systems 33 (2020).

Zhao, Shengyu, et al. "Differentiable Augmentation for Data-Efficient GAN Training." Advances in Neural Information Processing Systems 33 (2020).

I think the main idea is that we want to augment the discriminator inputs as opposed to the generator inputs, and it appears the augmentations in https://github.com/lucidrains/stylegan2-pytorch do this correctly.

You can grab this file to use these augmentations in another project: https://github.com/lucidrains/stylegan2-pytorch/blob/master/stylegan2_pytorch/diff_augment.py
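A rough sketch of that idea (the augment argument here stands in for something like the DiffAugment helper linked above; the names and the hinge loss are illustrative, not the exact code from either repo):

import torch
import torch.nn.functional as F

def discriminator_step(G, D, real_images, latent_dim, augment):
    # the same differentiable augmentation is applied to BOTH real and
    # generated images right before the discriminator, so the generator
    # never has a reason to produce augmented-looking images
    z = torch.randn(real_images.size(0), latent_dim, device = real_images.device)
    fake_images = G(z).detach()

    real_logits = D(augment(real_images))
    fake_logits = D(augment(fake_images))

    # hinge loss for the discriminator (one common choice)
    return F.relu(1 - real_logits).mean() + F.relu(1 + fake_logits).mean()

def generator_step(G, D, batch_size, latent_dim, augment, device):
    # because the augmentation is differentiable, gradients flow back
    # through it into the generator
    z = torch.randn(batch_size, latent_dim, device = device)
    return -D(augment(G(z))).mean()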

tannisroot commented 3 years ago

It's actually the opposite: lightweight-gan's augmentation doesn't leak, but stylegan2-pytorch's does. Thanks for the hint, though; I'll try swapping in the augmentation bits!

woctezuma commented 3 years ago

The augmentation in the papers mentioned above should not leak. If you experience leaks with the implementation, maybe there is an issue with the implementation or the parameters.

See: https://github.com/lucidrains/stylegan2-pytorch#low-amounts-of-training-data

If one were to augment at a low enough probability, the augmentations will not 'leak' into the generations.
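For instance, using this repo's augmentation probability flag (the value is just illustrative):

lightweight_gan --data ./images --aug-prob 0.25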

vibe007 commented 3 years ago

If the volume of training data is high, it's possible that data augmentations can hurt image quality (as discussed in "Training Generative Adversarial Networks with Limited Data"). It's also possible that you may not be training for long enough; see https://github.com/NVlabs/stylegan2 for expected training times for StyleGAN2 (it's days to weeks).

iScriptLex commented 2 years ago

This model works surprisingly well even for small (fewer than 10,000 images), complex, and highly variable datasets, and it took only about 15 hours on a single GeForce 1070 Ti. Some tips to improve quality for cartoon images:

rickyars commented 2 years ago

@iScriptLex would you elaborate on that last bullet? How would I replace the GlobalContext block?

iScriptLex commented 2 years ago

@rickyars You should edit the file lightweight_gan.py. First, add the original SEBlock implementation from FastGAN before the Generator class definition (just paste this code before the class Generator(nn.Module): line):

class Swish(nn.Module):
    def forward(self, feat):
        return feat * torch.sigmoid(feat)

# Skip-layer excitation (SLE) block, as in the original FastGAN code
class SEBlock(nn.Module):
    def __init__(self, ch_in, ch_out):
        super().__init__()

        self.main = nn.Sequential(
            nn.AdaptiveAvgPool2d(4),                  # pool the low-res feature map down to 4x4
            nn.Conv2d(ch_in, ch_out, 4, bias=False),  # 4x4 conv reduces it to a 1x1 channel descriptor
            Swish(),
            nn.Conv2d(ch_out, ch_out, 1, bias=False),
            nn.Sigmoid()                              # per-channel gates in (0, 1)
        )

    def forward(self, x):
        # returns only the gating map; the Generator applies it to the
        # higher-resolution feature map, just as it does for GlobalContext
        return self.main(x)

Then, in the __init__ function of the Generator class, replace this code:

sle = GlobalContext(
    chan_in = chan_out,
    chan_out = sle_chan_out
)

with this:

sle = SEBlock(ch_in = chan_out, ch_out = sle_chan_out)

rickyars commented 2 years ago

Thank you, @iScriptLex. I'll give this a try tonight. One more question for you: whenever I run with --sle_spatial, it tells me:

ERROR: Could not consume arg: --sle_spatial

Any idea what I'm doing wrong?

woctezuma commented 2 years ago

Any idea what I'm doing wrong?