lucidrains / imagen-pytorch

Implementation of Imagen, Google's Text-to-Image Neural Network, in Pytorch

Noise in output #50

Closed · deepglugs closed this issue 2 years ago

deepglugs commented 2 years ago

Using the latest master, I'm noticing a big improvement over 0.0.60. The output from the upscaling unet isn't nearly as "swirly", but I am noticing red or green bits of noise in the output images:

(Top row: unet1, second row: unet2) [image: imagen_2_262]

I added a torch.clamp(0, 1.0) after the image is created in Imagen.sample(), but that didn't seem to help. Any ideas where the noise is coming from?

lucidrains commented 2 years ago

cool! so i think it is also worth trying the old discrete version by setting https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py#L1231 to False and seeing if the artifacts are still there

i'm still slightly concerned about https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py#L248 , so you could also try setting all noise schedulers to be linear (turn off the cosine schedule)

lucidrains commented 2 years ago

also could try turning off the p2 loss https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py#L1232 for the continuous time case (set to 0)
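
For reference, a rough sketch of the two experiments suggested above, assuming the keyword names mentioned in this thread (continuous_times, p2_loss_weight_gamma); the unets and image sizes are placeholders, not a definitive setup:

    from imagen_pytorch import Imagen

    # experiment 1: fall back to the discrete-time DDPM formulation
    imagen_discrete = Imagen(
        unets = (unet1, unet2),        # placeholder unets
        image_sizes = (64, 256),       # placeholder resolutions
        continuous_times = False,
    )

    # experiment 2: keep continuous time, but disable the p2 loss reweighting
    imagen_continuous = Imagen(
        unets = (unet1, unet2),
        image_sizes = (64, 256),
        continuous_times = True,
        p2_loss_weight_gamma = 0.,     # 0 turns the reweighting off
    )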

lucidrains commented 2 years ago

@deepglugs thank you for reporting these results, and the artifacts!

lucidrains commented 2 years ago

@deepglugs actually may have identified the issue! let me know if the latest version fixes the problem https://github.com/lucidrains/imagen-pytorch/commit/ba50fc8c5dd8e651b87cab578cf7752bcc34ec91

s4sarath commented 2 years ago

Will there be much of a difference between continuous-time diffusion and discrete-time diffusion?

deepglugs commented 2 years ago

I kept continuous_times=True and updated to 0.1.14. Setting p2_loss_weight_gamma=0 (and changing nothing else) doesn't seem to affect the noise.

There is still some noise; it is usually more pronounced after upscaling, though not always.

[image: imagen_28_18]

Could it have been something introduced between 0.0.60 and 0.1.10? While these samples are not very good, there isn't any noise for the first unet:

[image: imagen_3_6666]

lucidrains commented 2 years ago

@deepglugs ohh as i feared, there are still some remaining issues with continuous time then :disappointed: maybe i should default back to discrete ddpm

although it seems like you are getting good results with continuous time (save for the artifacts)! it does appear to make a difference (although i'm not sure how much of it is due to p2 loss)

lucidrains commented 2 years ago

@deepglugs if you could run an experiment with linear only, that would help me narrow down if the issue is with continuous time or just the alpha cosine schedule in continuous time :pray:

deepglugs commented 2 years ago

How do I set linear only? Will setting continuous_times=False do that?

lucidrains commented 2 years ago

> How do I set linear only? Will setting continuous_times=False do that?

you would just set beta_schedules = ('linear',) * num_unets when initializing Imagen
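
As a sketch (assuming three unets), that expression just repeats 'linear' once per unet:

    num_unets = 3
    beta_schedules = ('linear',) * num_unets   # -> ('linear', 'linear', 'linear')

    # then pass it when constructing Imagen, e.g.
    # imagen = Imagen(..., beta_schedules = beta_schedules)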

hychiang-git commented 2 years ago

I had this problem and I solved it by setting use_linear_attn = False.

    unet1 = Unet(
        dim = 32,
        image_embed_dim = 512,
        num_resnet_blocks = 3,
        dim_mults = (1, 2, 3, 4),
        attn_dim_head = 64,
        layer_attns = (False, True, True, True),
        use_linear_attn = False,
    )

I think the flag replaces the transformer blocks with linear attention layers at this line. However, I am not sure which setting performs better.

lucidrains commented 2 years ago

@ken012git ohh no, those are two different things

use_linear_attn will supplement with linear attention wherever quadratic full attention is not used

the linear i'm referring to is for the DDPM noise schedule

deepglugs commented 2 years ago

I think it might just be a clamping issue. Before, I mistakenly had torch.clamp(img, 0, 1.0), but it should have been img = torch.clamp(img, 0, 1.0) in Imagen.sample(), before appending the image to the outputs array.
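
For context, torch.clamp is not an in-place operation, so the clamped result has to be assigned back. A minimal illustration (not the actual Imagen.sample() code):

    import torch

    img = torch.rand(3, 64, 64) * 1.2 - 0.1   # toy tensor with values outside [0, 1]

    torch.clamp(img, 0., 1.)                   # returns a clamped copy that is discarded; img unchanged
    img = torch.clamp(img, 0., 1.)             # assignment keeps the clamped result
    # img.clamp_(0., 1.) would also clamp in place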

Trying this with both use_linear_attn=False and use_linear_attn=True, the noise clears up even in early samples:

[image: imagen_1_1000]

lucidrains commented 2 years ago

@deepglugs oh nice, yea that should be taken care of for you in the latest version https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py#L1445

lucidrains commented 2 years ago

@deepglugs and what i meant was beta_schedules = ('linear', 'linear', 'linear')

lucidrains commented 2 years ago

i should just rename beta_schedules to noise_schedules 🤔

hychiang-git commented 2 years ago

The loss curve seems strange when use_linear_attn = True is set, and there is noise in the early results.

[image: unknown]

[image: Screen_Shot_2022-06-12_at_11.13.46_AM]

lucidrains commented 2 years ago

@ken012git ohh hmm, do you want to open a new issue for the linear attention issue?

hychiang-git commented 2 years ago

Sure, it's here. Thanks!

deepglugs commented 2 years ago

The first batch of samples looks good noise-wise with the latest main. I'll keep training this test for a few more iterations and then close this.

lucidrains commented 2 years ago

@deepglugs when you turned off the p2 loss, do you see the same dramatic improvements you noted earlier? just wondering how much of it is due to continuous time vs the p2 loss reweighting

lucidrains commented 2 years ago

@deepglugs depending on how you answer that, i may also port the p2 loss weight over to the discrete time

deepglugs commented 2 years ago

> @deepglugs when you turned off the p2 loss, do you see the same dramatic improvements you noted earlier? just wondering how much of it is due to continuous time vs the p2 loss reweighting

I didn't run long enough to determine if it was much better. I can do a longer run if you want once this test is complete.

lucidrains commented 2 years ago

@deepglugs yes, that'd be great, if you have the time! 🙏

lucidrains commented 2 years ago

@deepglugs i'm just wondering if continuous time is better than discrete or just an unimportant detail

deepglugs commented 2 years ago

Ran 6 epochs each. First is the default p2_loss: [image: imagen_6_166]. Second is p2_loss=0: [image: imagen_6_166].

I want to say the default p2_loss helps the upscaler.

lucidrains commented 2 years ago

@deepglugs thank you! i'll probably add the p2 loss as an option for discrete time gaussian diffusion tomorrow then!

marunine commented 2 years ago

For what it's worth, I've been getting good results on my dataset with continuous_times=True and p2_loss_weight_gamma=1.0.

I tend to get color shifting at the default of 0.5 until about my second epoch, whereas it starts converging within the first epoch at 1.0. I think you can play around with it a bit as a hyperparameter per the paper it comes from, but most of the experiments they ran seem to have had it at 1.0.
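
A sketch of that setting, with the keyword names as referenced in this thread and placeholder unets/image sizes:

    imagen = Imagen(
        unets = (unet1, unet2),        # placeholder unets
        image_sizes = (64, 256),       # placeholder resolutions
        continuous_times = True,
        p2_loss_weight_gamma = 1.0,    # paper mostly uses 1.0; the default reported above is 0.5
    )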

lucidrains commented 2 years ago

@marunine @deepglugs done! https://github.com/lucidrains/imagen-pytorch/commit/110855e11109c12c4c458d42d65399e872e7d0d3 thank you both for sharing your experimental results!