lucidrains / imagen-pytorch

Implementation of Imagen, Google's Text-to-Image Neural Network, in Pytorch
MIT License
8.11k stars, 768 forks

Samples after training only contain white noise. #347

Open Valerie9696 opened 1 year ago

Valerie9696 commented 1 year ago

I wanted to try out training imagen and generating some samples. Therefore, I ran this part of the script from the readme:

import torch
from imagen_pytorch import Unet, Imagen, ImagenTrainer

# unet for imagen

unet1 = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = 3,
    layer_attns = (False, True, True, True),
)

unet2 = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = (2, 4, 8, 8),
    layer_attns = (False, False, False, True),
    layer_cross_attns = (False, False, False, True)
)

# imagen, which contains the unets above (base unet and super resoluting ones)

imagen = Imagen(
    unets = (unet1, unet2),
    text_encoder_name = 't5-large',
    image_sizes = (64, 256),
    timesteps = 1000,
    cond_drop_prob = 0.1
).cuda()

# wrap imagen with the trainer class

trainer = ImagenTrainer(imagen)

# mock images (get a lot of this) and text encodings from large T5

text_embeds = my text embeddings  # torch.randn(64, 256, 1024).cuda()
images = my sample images (3k for a first try)  # torch.randn(64, 3, 256, 256).cuda()

# feed images into imagen, training each unet in the cascade

loss = trainer(
    images,
    text_embeds = text_embeds,
    unet_number = 1,     # training on unet number 1 in this example, but you will have to also save checkpoints and then reload and continue training on unet number 2
    max_batch_size = 4   # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory
)

trainer.update(unet_number = 1)

# do the above for many many many many steps
# now you can sample an image based on the text embeddings from the cascading ddpm

images = trainer.sample(texts = [
    'a puppy looking anxiously at a giant donut on the table',
    'the milky way galaxy in the style of monet'
], cond_scale = 3.)

images.shape # (2, 3, 256, 256)

Thereafter, I run the following in order to make my results visible:

from torchvision import transforms

# convert each sampled tensor to a PIL image and display it
transform = transforms.ToPILImage()
for img in images:
    i = transform(img)
    i.show()

The training runs without errors, but at the end the samples are just images full of white noise. Does anyone know where I am making a mistake? Or are 3k images simply not enough to get anything out of it? (I know it's not much, but unfortunately it is the maximum that my device can handle.) This is one such result.

[image: one of the sampled results, showing only noise]

I am really thankful for any advice on the matter.

kirilllzaitsev commented 1 year ago

One reason could be an insufficient number of steps. Have you tried increasing it to the maximum that is feasible in your case to see whether your training loop works?

Try to "remember" a single image first, i.e., optimize for N steps and observe that you can sample this image with high quality.

TheFusion21 commented 1 year ago

The script in the readme only contains mock images (random noise). You need to load a sufficient dataset first and train on it.

Valerie9696 commented 1 year ago

> The script in the readme only contains mock images (random noise). You need to load a sufficient dataset first and train on it.

I know, which is why I replaced the mock data in the script with placeholders for my text embeddings and my sample images. This training set currently consists of 3k images and embedded captions.

Valerie9696 commented 1 year ago

> One reason could be an insufficient number of steps. Have you tried increasing it to the maximum that is feasible in your case to see whether your training loop works?
>
> Try to "remember" a single image first, i.e., optimize for N steps and observe that you can sample this image with high quality.

Oh, I think this might be where I am going wrong. How exactly do I increase the number of steps? Currently I am basically running the script above with my own images and embedded captions.

kirilllzaitsev commented 1 year ago

Feel free to set a huge number of steps at the start and plot your samples regularly. Note when your samples start looking like your inputs; that will give you an intuition for how long it takes to "overfit" your network.
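A minimal sketch of what "many steps with regular sampling" could look like. The script in the issue calls `trainer(...)` and `trainer.update(...)` only once, so one way to increase the number of steps is to wrap both calls in a loop. The helper `run_training` and its parameters are hypothetical names introduced here, not part of imagen-pytorch. The commented wiring assumes that `trainer.sample` accepts `return_pil_images` and `stop_at_unet_number` in the installed version; `stop_at_unet_number = 1` is suggested because only the first unet has been trained at this point, and sampling through an untrained second unet would likely produce noise.

```python
from typing import Callable, List


def run_training(
    train_step: Callable[[], float],
    sample_fn: Callable[[int], None],
    num_steps: int,
    sample_every: int,
) -> List[float]:
    """Call train_step num_steps times; call sample_fn every sample_every steps."""
    losses: List[float] = []
    for step in range(1, num_steps + 1):
        losses.append(float(train_step()))  # one training step, loss recorded
        if step % sample_every == 0:
            sample_fn(step)  # e.g. save samples to disk for visual inspection
    return losses


# Hypothetical wiring with the objects from the script in this issue (untested):
#
# def train_step():
#     loss = trainer(images, text_embeds=text_embeds, unet_number=1, max_batch_size=4)
#     trainer.update(unet_number=1)
#     return loss
#
# def sample_fn(step):
#     pil_images = trainer.sample(
#         texts=['a puppy looking anxiously at a giant donut on the table'],
#         cond_scale=3.,
#         stop_at_unet_number=1,   # assumption: only unet 1 is trained so far
#         return_pil_images=True,  # assumption: returns a list of PIL images
#     )
#     pil_images[0].save(f'sample-{step}.png')
#
# losses = run_training(train_step, sample_fn, num_steps=100_000, sample_every=1_000)

if __name__ == "__main__":
    # Stand-in callables so the sketch runs without a GPU or the library.
    sampled_at: List[int] = []
    fake_losses = iter([1.0, 0.8, 0.6, 0.5, 0.4, 0.3])
    losses = run_training(lambda: next(fake_losses), sampled_at.append,
                          num_steps=6, sample_every=2)
    print(losses, sampled_at)
```

For the single-image check suggested earlier in the thread, the same loop could be run on a batch containing just one image and its caption embedding; if the saved samples never start to resemble that image, the problem is more likely in the data or the sampling call than in the dataset size.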

stepanovD commented 1 year ago

Hey @Valerie9696! I have the same problem. Did you find a solution?

asher-lab commented 1 year ago

Hello @stepanovD and @Valerie9696, can you share how you created the training data for imagen? I'm new to PyTorch and still trying to connect the dots.