lucidrains / lightweight-gan

Implementation of 'lightweight' GAN, proposed in ICLR 2021, in Pytorch. High resolution image generations that can be trained within a day or two
MIT License

question about generated images #115

Open rickyars opened 2 years ago

rickyars commented 2 years ago

I've run the same data through two GANs: lightweight-gan and Playform. The results are similar but not identical. I'm wondering if anyone has suggestions for cleaning up the outputs from lightweight-gan; they seem a bit muddier and noisier. (Attached images: 170, Untitled Project_21_0043.)

iScriptLex commented 2 years ago

These are called checkerboard artifacts. They are caused by upsampling the data between layers of the network. The artifacts are highly visible in the early stages of training, but they fade with every step and become barely noticeable after ~80,000 iterations. If you want to get rid of them completely, even in the early stages, there are several ways to do it:

  1. Use the --antialias flag. This significantly increases training time, but entirely removes these artifacts. You can make training faster by writing your own smoothing function, because the default one is unoptimized and slow.
  2. Change the upsampling type to bilinear (this also effectively removes the artifacts and is faster than the --antialias method). To do so, edit lightweight_gan.py and, in the upsample function, change this line: return nn.Upsample(scale_factor = scale_factor) to this: return nn.Upsample(scale_factor = scale_factor, mode='bilinear', align_corners=False)
  3. Rewrite the main layers of the Generator and Discriminator so that the convolution kernel size is a multiple of the scaling factor (here, 2).
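The intuition behind option 3 is easiest to see in the standard analysis of strided transposed convolutions: when the kernel size is not divisible by the stride, neighbouring output pixels receive different numbers of kernel contributions, which shows up as a regular grid pattern. The pure-Python sketch below counts those contributions in 1D for a stride of 2. It models a transposed convolution, not lightweight-gan's actual upsample-then-convolve layers, so treat it as intuition only; the `tap_counts` helper is hypothetical.

```python
def tap_counts(input_len, kernel_size, stride=2):
    """Count how many kernel taps land on each output position of a 1D
    transposed convolution (all-ones input and kernel, no padding)."""
    output_len = (input_len - 1) * stride + kernel_size
    counts = [0] * output_len
    for i in range(input_len):          # each input sample...
        for k in range(kernel_size):    # ...is spread over kernel_size outputs
            counts[i * stride + k] += 1
    return counts

# Ignore the borders and look at the interior pattern.
print(tap_counts(6, kernel_size=3)[2:-2])  # [2, 1, 2, 1, 2, 1, 2, 1, 2]
print(tap_counts(6, kernel_size=4)[2:-2])  # [2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
```

With a kernel of 3 the interior alternates between 2 and 1 contributions (uneven coverage), while a kernel of 4 gives every interior position the same coverage.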
rickyars commented 2 years ago

@iScriptLex What timing. Earlier today I found the --antialias flag. (1) Training takes much, much longer and (2) it does fix my problem! Now I can't use --show-progress because of an issue with some dict. I've opened an issue with the error.

rickyars commented 2 years ago

Option 1: The intermediate results from --antialias look amazing, but it's very slow. After 4 hours it still doesn't look anywhere close to converging, but the result is trippy. (Attached image: 43.)

rickyars commented 2 years ago

@iScriptLex sorry to bother you again. Looking at the code, the generator seems pretty easy to modify: the kernel size is currently 3, so I'll bump it up to 4. However, the discriminator is a lot more complicated. Which lines should I focus on if I want to modify the "main layers"? Thank you!
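One practical catch when bumping a kernel from 3 to 4: with a stride of 1, `padding = 1` only preserves the spatial size for a 3×3 kernel. The arithmetic below uses the standard convolution output-size formula, out = floor((in + pad_before + pad_after - kernel) / stride) + 1; the `conv_out_size` helper is hypothetical. An even kernel needs one extra pixel of padding on one side, for example an asymmetric `nn.ZeroPad2d` before a `padding = 0` convolution, or `padding = 'same'` in recent PyTorch versions (stride 1 only).

```python
def conv_out_size(size, kernel, pad_before, pad_after, stride=1):
    """Spatial output size of a convolution along one axis."""
    return (size + pad_before + pad_after - kernel) // stride + 1

size = 64
print(conv_out_size(size, kernel=3, pad_before=1, pad_after=1))  # 64: size preserved
print(conv_out_size(size, kernel=4, pad_before=1, pad_after=1))  # 63: shrinks by one
print(conv_out_size(size, kernel=4, pad_before=1, pad_after=2))  # 64: asymmetric padding fixes it
```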

lucidrains commented 2 years ago

@rickyars let me know how well this works for you! https://github.com/lucidrains/lightweight-gan/commit/bfd0e8ad4ee9e8ef4c4a2a7af5a5e134afd8c3bd (separate branch for now)

rickyars commented 2 years ago

@lucidrains wow! you are fast! i need to sleep but i'll test it out tomorrow. thank you!

iScriptLex commented 2 years ago

FastGAN suppresses these artifacts using an interesting method: noise injection. I assume the same method will work for LightweightGAN. First, add this code before the Generator class:

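import torch
from torch import nn

class NoiseInjection(nn.Module):
    def __init__(self):
        super().__init__()
        # learned noise strength, initialized to zero so the layer starts as an identity
        self.weight = nn.Parameter(torch.zeros(1))

    def forward(self, feat, noise=None):
        if noise is None:
            # one single-channel noise map per sample, broadcast across all channels
            batch, _, height, width = feat.shape
            noise = torch.randn(batch, 1, height, width, device=feat.device)

        return feat + self.weight * noise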

Then change the Generator layer structure from this:

upsample(),
Blur(),
nn.Conv2d(chan_in, chan_out * 2, 3, padding = 1),
norm_class(chan_out * 2),
nn.GLU(dim = 1)

to this (just add NoiseInjection before norm_class):

upsample(),
Blur(),
nn.Conv2d(chan_in, chan_out * 2, 3, padding = 1),
NoiseInjection(),
norm_class(chan_out * 2),
nn.GLU(dim = 1)
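Putting the two suggestions in this thread together (bilinear upsampling and noise injection), a self-contained version of one generator upsampling block might look like the sketch below. It is not the repository's code: `Blur` is left out, `nn.BatchNorm2d` stands in for `norm_class`, and the `make_upsample_block` name and its `bilinear` argument are hypothetical.

```python
import torch
from torch import nn

class NoiseInjection(nn.Module):
    def __init__(self):
        super().__init__()
        # learned noise strength; starts at zero so the block initially ignores the noise
        self.weight = nn.Parameter(torch.zeros(1))

    def forward(self, feat, noise=None):
        if noise is None:
            batch, _, height, width = feat.shape
            noise = torch.randn(batch, 1, height, width, device=feat.device)
        return feat + self.weight * noise

def make_upsample_block(chan_in, chan_out, bilinear=True):
    # hypothetical helper: nearest-neighbour upsampling is nn.Upsample's default,
    # bilinear is the smoother alternative discussed in this thread
    if bilinear:
        upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
    else:
        upsample = nn.Upsample(scale_factor=2)
    return nn.Sequential(
        upsample,
        nn.Conv2d(chan_in, chan_out * 2, 3, padding=1),
        NoiseInjection(),               # added before the normalization layer
        nn.BatchNorm2d(chan_out * 2),   # stand-in for norm_class
        nn.GLU(dim=1),                  # gates and halves the channels: chan_out * 2 -> chan_out
    )

block = make_upsample_block(64, 32)
x = torch.randn(2, 64, 16, 16)
print(block(x).shape)  # torch.Size([2, 32, 32, 32])
```

Because the noise weight starts at zero, the network only uses the noise once it learns that doing so helps, which matches the observation below that the benefit depends on how textured the dataset is.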
rickyars commented 2 years ago

Again, @iScriptLex, thank you for the quick reply! Do you ever sleep? I'll give this a try after my run.

lucidrains commented 2 years ago

@iScriptLex thank you! 🙏

@rickyars added it for you https://github.com/lucidrains/lightweight-gan/commit/30b2e05031f11af2cb74feb99e9e2a7f5e625175

lucidrains commented 2 years ago

(Attached image: 2.)

adding noise really helps with texture formation for my dataset (probably not for the smooth drawings in @rickyars' dataset). i'm using bilinear upsampling + noise injection. i knew i should have carried over a few more elements from StyleGAN2!

@iScriptLex thank you, anonymous stranger :)

lucidrains commented 2 years ago

while we are focused on this repository, i'll improve the efficient attention greatly :) i've grown so much (in terms of my knowledge of attention) since leaving GANs behind

lucidrains commented 2 years ago

@rickyars cool, let me know how your experiment goes :) the bilinear + noise works very well for me, mainly because my dataset (flowers) has a lot of texture and noisy green nature stuff in the background, and the network learns to utilize the noise well. your dataset seems to be mostly smooth vector graphics, so you probably won't see the benefit

lucidrains commented 2 years ago

(Attached image: 6-ema.)

lucidrains commented 2 years ago

@rickyars what is that?? pacman ghosts or atari space invaders?

rickyars commented 2 years ago

@lucidrains I'm trying to make a bitGAN: https://bitgans.com/ (attached image: unnamed)

rickyars commented 2 years ago

@lucidrains Heads up. I saw you pushed some changes to attention, but I'm not getting results as good as before with version 0.21.1.

lucidrains commented 2 years ago

@rickyars darn, do you want to try 0.21.2? I can always revert it back to the old linear attention if it still doesn't do well

rickyars commented 2 years ago

@lucidrains I'll try 0.21.2 and see how it looks. Otherwise, I might go back to 0.20.8.

Two questions for you:

  1. Any reason why I'm suddenly seeing the cutout leak through in early training?
  2. How are the initial generator images created? Right now they look all black. The other GAN I've been using has a very colorful first guess and then converges rather quickly. I'm hoping to poke around that section of the code. (Attached image: Untitled Project_01_0043.)
lucidrains commented 2 years ago

it has been a while since i've done any GAN training, nor am i caught up with literature, so take what i say with a grain of salt

iirc, the augmentations were not supposed to leak if you kept it at a low enough probability during training. however, i have vague memories of someone else complaining about cutout augmentation leaking, so perhaps the paper's conclusions weren't true, which happens. you can customize the augmentations with a command-line flag (remove the cutout augmentation altogether)
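To illustrate why the probability matters: an augmentation like cutout is applied to each image only with some probability, so at a low probability most images the discriminator sees are untouched. The pure-Python sketch below shows the idea on a tiny grayscale image stored as nested lists. The `maybe_cutout` name, its `prob` and `ratio` parameters and the square mask are assumptions for illustration, not lightweight-gan's implementation.

```python
import random

def maybe_cutout(image, prob, ratio=0.5, rng=random):
    """With probability `prob`, zero out a random square covering `ratio` of each side.

    `image` is a list of rows (grayscale). Returns a new image; the input is not modified.
    """
    out = [row[:] for row in image]
    if rng.random() >= prob:            # with a low prob, most images are left untouched
        return out
    h, w = len(image), len(image[0])
    ch, cw = max(1, int(h * ratio)), max(1, int(w * ratio))
    top, left = rng.randrange(h - ch + 1), rng.randrange(w - cw + 1)
    for y in range(top, top + ch):
        for x in range(left, left + cw):
            out[y][x] = 0
    return out

rng = random.Random(0)
image = [[1] * 8 for _ in range(8)]
augmented = [maybe_cutout(image, prob=0.25, rng=rng) for _ in range(1000)]
frac = sum(any(0 in row for row in img) for img in augmented) / len(augmented)
print(f"fraction of images with a cutout: {frac:.2f}")  # roughly 0.25
```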

the initial black images are likely due to the initialization of the network. i don't think the convergence rate should differ that much. lightweight-gan is among the fastest GANs out there for training.

rickyars commented 2 years ago

i wasn't seeing any issue with the cutouts until version 0.21, which is why i was asking if something else changed when you pushed through the new noise injection. i looked at the commits and didn't see anything, so i'm baffled!

i'm running 0.21.2 with no cutout and i'll let you know how it goes. my training set is randomly generated so i can also create more images to feed the GAN.

btw, the reason i'm using lightweight GAN is its speed. as you can see, i don't need a lot of detail in my outputs, and i'm actually more interested in watching the GAN learn and in quick convergence.

lucidrains commented 2 years ago

@rickyars i think the cutout leaking happens on and off, regardless of version

yes, watching GANs learn right in front of you is a magical experience :) ok keep me posted about your results. my flowers seem to be blooming inside the machine quite well, but i'll let it train until 10k steps to make sure the new attention didn't break anything

rickyars commented 2 years ago

@lucidrains Thanks again for the help today.

Btw, I love this line: "my flowers seem to be blooming inside the machine quite well"

rickyars commented 2 years ago

@lucidrains my Google Colab instance shut down last night. I restarted it this morning and I'm getting very, very different results today. Normally I would be getting something that looks like my end result by this iteration. Now I'm getting this: (attached image: 20)

Did anti-alias somehow get turned on by default? This is the exact same code as yesterday, just with a fresh pip install lightweight-gan.

rickyars commented 2 years ago

This is with !pip install lightweight-gan==0.20.8 at the same iteration: (attached image: 20_0 20 8)

lucidrains commented 2 years ago

Ohh, I turned on bilinear upsampling by default, which would produce what you are seeing, at least early in training.

rickyars commented 2 years ago

Wow. I'm an idiot. I was making code changes but not actually running that code. I will delete my test runs from the thread so as not to confuse other people. Sorry!
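One way to catch this kind of mix-up early is to ask Python which version is installed and which file it actually imports. The sketch below uses only the standard library (`importlib.metadata` needs Python 3.8+). The helper names are hypothetical; the distribution name `lightweight-gan` and module name `lightweight_gan` come from the pip command and file name used in this thread.

```python
import importlib.metadata
import importlib.util

def installed_version(dist_name):
    """Version string of an installed distribution, or None if it is not installed."""
    try:
        return importlib.metadata.version(dist_name)
    except importlib.metadata.PackageNotFoundError:
        return None

def module_origin(module_name):
    """Path of the file Python would import for a top-level module, or None if not found."""
    spec = importlib.util.find_spec(module_name)
    return spec.origin if spec is not None else None

print(installed_version("lightweight-gan"))
print(module_origin("lightweight_gan"))
```

If the printed path is inside site-packages rather than the folder you edited, the changes are not the code being run.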

rickyars commented 2 years ago

@lucidrains i feel like such an idiot. i've gone back to 0.21.0 to test just the noise injection by turning off bilinear filtering (i think i know how to do this now). do you know if 0.21.2 has any runtime impact? i don't know why, but my iterations seem to be taking longer.

lucidrains commented 2 years ago

@rickyars yea, there is a slight increase in run-time, mainly due to the attention modifications, and i may keep it (you can always turn it off by setting --attn-res-layers [])

rickyars commented 2 years ago

@lucidrains now that i know what i'm doing, i started a test with the power of two convolutional window. it's looking pretty good! i'll probably add the noise and use that for my final set of runs. will post a pic in a bit.

rickyars commented 2 years ago

@lucidrains i'm actually really happy with these results so far: (attached image: 25)

lucidrains commented 2 years ago

@rickyars very nice! maybe i'll add the power of 2 kernels :) it'll be slightly slower yet again, but GPUs keep getting more powerful by the year, so who cares :)

lucidrains commented 2 years ago

@rickyars you turned off the bilinear upsampling?

rickyars commented 2 years ago


Yes. This output was generated on bfd0e8ad4ee9e8ef4c4a2a7af5a5e134afd8c3bd.

Bilinear is off. It seems to really hurt my use case.

lucidrains commented 2 years ago

@rickyars ok, i'll turn it off for now :)

rickyars commented 2 years ago

@lucidrains thank you!