lucidrains / imagen-pytorch

Implementation of Imagen, Google's Text-to-Image Neural Network, in Pytorch
MIT License

Possibly broken unet2+ #165

Closed deepglugs closed 2 years ago

deepglugs commented 2 years ago

It seems that unet1 progresses quickly. On my dataset, within 10-20 epochs I can get pretty good results at 64px:

image

But when trying unet2, the results are poor even after more than 20-40 epochs:

image

Here's another example of unet2 output on mel-spectrograms after many, many steps:

imagen_158_84_loss0 02229

Granted, it is possible that the unet1 results also look similar, but it's hard to tell with the blur going on. I assume that's the result of the noising function? Is there a way to turn off the noising functionality during sampling if stop_at_unet is set to 1? That said, with the spectrogram, unet1 produces very good blacks in the padding area, but unet2 can't get it right.
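For reference, the kind of sampling call I mean, sketched from memory (I'm assuming the sample call still takes a stop_at_unet_number argument, which I think is the full name of what I abbreviated as stop_at_unet above; the prompt is just a placeholder):

# minimal sketch, assuming `trainer` is an ImagenTrainer wrapping both unets
images = trainer.sample(
    texts=['a placeholder prompt'],
    stop_at_unet_number=1   # return the base unet's 64px output, skip unet2 entirely
)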

I know others have voiced issues with unet2 as well. Here are my unet2 settings for reference:

unet2 = dict(
            dim=128,
            cond_dim=512,
            dim_mults=(1, 2, 3, 4),
            cond_images_channels=cond_images_channels,
            num_resnet_blocks=2,
            layer_attns=(False, False, False, True),
            layer_cross_attns=(True, True, True, True),
            # final_conv_kernel_size=1,
            memory_efficient=True
        )

Note: I've also tried dim_mults=(1, 2, 4, 6) and num_resnet_blocks=(2, 2, 4, 8), with similar results.
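For anyone reproducing this, here's roughly how these dicts get wired in, as I understand the API (unet1 is a dict of Unet kwargs defined like unet2 above; the image sizes, timesteps, and other values here are illustrative, not my exact setup):

from imagen_pytorch import Unet, Imagen, ImagenTrainer

# sketch: build the two unets from their kwarg dicts and wrap them in Imagen
imagen = Imagen(
    unets=(Unet(**unet1), Unet(**unet2)),
    image_sizes=(64, 256),   # illustrative resolutions: 64px base, 256px super-res
    timesteps=1000,
    cond_drop_prob=0.1       # needed for classifier-free guidance (cond_scale) at sampling
)

trainer = ImagenTrainer(imagen)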

integer753 commented 2 years ago

I'm also having problems upscaling. I'm on version 0.25.4 and have a well-trained unet1 that's giving good results. Training unet2 for 200k steps (batch size 64) gives me very noisy output:

image

When I trained on version 0.8.8, the output was never noisy and I ultimately had great results:

image

I'm trying the absolute latest version now, but it takes a while to train unet1 first, so it will be some time before I have results.

My 0.25.4 config for unet2:

unet2 = Unet(
    dim = 80,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = (2, 4, 8, 8),
    cond_images_channels = 3,
    layer_attns = (False, False, False, True),
    layer_cross_attns = (False, False, False, True),
    memory_efficient = True,
)

My config in 0.8.8 is exactly the same, except I was only using dim 64 there. (So even with such a low dim, I had pretty decent results at the time.)

lucidrains commented 2 years ago

@deepglugs which version of the library are you on and is this elucidated or non-elucidated? also, how high is your cond_scale for unet2? could you try a lower cond_scale if it is too high?
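For context, cond_scale is the classifier-free guidance scale supplied at sampling time; roughly where it goes (placeholder prompt, and assuming the usual trainer.sample signature):

# higher cond_scale weights the text conditioning more heavily;
# the suggestion above is to lower it if it is currently set high
images = trainer.sample(
    texts=['a placeholder prompt'],
    cond_scale=2.   # try a lower value here if unet2 samples look degraded
)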

lucidrains commented 2 years ago

@integer753 i would retry on the latest version, since someone identified a bug with the order of normalization and noising recently

integer753 commented 2 years ago

> @integer753 i would retry on the latest version, since someone identified a bug with the order of normalization and noising recently

Ahhh, I looked into it and saw that the bug was just in sampling. I can confirm I'm getting much better results just by sampling from my existing checkpoint with the latest version! Thank you, and thanks a lot for all the work you've done on this, lucidrains!

lucidrains commented 2 years ago

@integer753 oh that's great to hear! could you possibly share your results? i get a kick out of seeing what others have trained (but it is ok if you can't)

lucidrains commented 2 years ago

@integer753 also, just for my reference, you are using non-elucidated imagen?

integer753 commented 2 years ago

Yes, I'm using the non-elucidated Imagen at the moment. Here is one of my current results, but I need to retrain a bit because I couldn't transfer all my layers to the new version, so things are a bit mangled; I'll post something when I have better results. In any case, the upscaler is working for me now!

image

lucidrains commented 2 years ago

@integer753 looks really good i agree! :100: :heart:

deepglugs commented 2 years ago

My samples don't look too bad either after lots of training. I've switched back and forth between random-cropped unet2 training and full-size unet2 training, and I think that has helped. The images still aren't perfect, but there appears to be progress. Closing for now.

image
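Concretely, the random-crop training I mean is the per-unet crop option on the Imagen wrapper, something along these lines (the crop size and resolutions are placeholders, and the argument name may differ between versions):

# sketch: train unet2 on random 64px crops while unet1 sees full images
imagen = Imagen(
    unets=(unet1, unet2),
    image_sizes=(64, 256),           # placeholder resolutions
    random_crop_sizes=(None, 64),    # None = no crop for the base unet, 64 = crop size for unet2
    timesteps=1000
)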

lucidrains commented 2 years ago

@deepglugs haha yea, we figured that out over at dalle2-pytorch as well

the super-resolution unets actually take a lot more training to get good results!

lucidrains commented 2 years ago

@deepglugs seems like state-of-the-art text to image is secured, onwards to text to video!