Valerie9696 opened this issue 1 year ago
One reason could be an insufficient number of steps. Have you tried increasing it to the maximum amount feasible for your case and seeing whether your training loop works?
Try to "remember" a single image first, i.e., optimize for N steps and observe that you can sample this image with high quality.
The script in the readme only contains mock images (random noise). You need to load a sufficient dataset first and train on it.
> The script in the readme only contains mock images (random noise). You need to load a sufficient dataset first and train on it.
I know, that is why I replaced the mock data in that part with my own text embeddings and sample images. The training set currently consists of 3k images and embedded captions.
> One reason could be an insufficient number of steps. Have you tried increasing it to the maximum amount feasible for your case and seeing whether your training loop works?
> Try to "remember" a single image first, i.e., optimize for N steps and observe that you can sample this image with high quality.
Oh, I think this might be where I am going wrong. How exactly do I increase the number of steps? Currently I am basically running the script from my original post with my own images and embedded captions.
Feel free to put in a huge number of steps at the beginning and plot your samples regularly. Notice when your samples start looking like your inputs, which will give you intuition about how long it takes to "overfit" your network.
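For illustration, such a loop might look roughly like this. It is only a sketch, not the library's official training loop: it assumes the `trainer`, `images` and `text_embeds` from the readme-style setup, and that `trainer.sample` accepts precomputed `text_embeds`:

```python
import torch
from torchvision.utils import save_image

num_steps = 200_000     # deliberately large; stop early once samples stop looking like noise
batch_size = 16

for step in range(num_steps):
    # draw a random mini-batch out of the full set of image / embedding pairs
    idx = torch.randint(0, images.shape[0], (batch_size,))
    loss = trainer(images[idx], text_embeds = text_embeds[idx], unet_number = 1, max_batch_size = 4)
    trainer.update(unet_number = 1)

    if step % 1000 == 0:
        print('step', step, 'loss', loss)
        # sample with a few of the training embeddings and save a grid for inspection
        # (note this runs the whole cascade, including the not-yet-trained second unet)
        samples = trainer.sample(text_embeds = text_embeds[:4], cond_scale = 3.)
        save_image(samples, f'samples_step_{step}.png')
```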
Hey @Valerie9696! I have the same problem. Have you found a solution?
Hello @stepanovD and @Valerie9696, can you share how you created your training data for imagen? I'm new to PyTorch and still trying to connect the dots.
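For reference, one possible way to build such image/embedding pairs is sketched below. This is just an illustration, not necessarily what the posters above did: it encodes captions with Hugging Face's T5 encoder directly (t5-large has hidden size 1024, matching the readme's embedding shape), and `captions` / `image_paths` are placeholders for your own data.

```python
import torch
from PIL import Image
from torchvision import transforms
from transformers import T5Tokenizer, T5EncoderModel

# placeholders for your own captions and image files
captions = ['a caption for the first image', 'a caption for the second image']
image_paths = ['images/0001.jpg', 'images/0002.jpg']

# encode captions with t5-large (hidden size 1024), padded to a fixed length
tokenizer = T5Tokenizer.from_pretrained('t5-large')
encoder = T5EncoderModel.from_pretrained('t5-large').cuda().eval()

with torch.no_grad():
    tokens = tokenizer(captions, return_tensors = 'pt', padding = 'max_length',
                       max_length = 256, truncation = True).to('cuda')
    text_embeds = encoder(**tokens).last_hidden_state   # (num_captions, 256, 1024)

# load and resize the images into a single float tensor in [0, 1]
to_tensor = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])
images = torch.stack([to_tensor(Image.open(p).convert('RGB')) for p in image_paths]).cuda()

# `images` and `text_embeds` can then be fed to the ImagenTrainer as in the readme
```

Keeping thousands of images in memory as one tensor may not scale, so a `torch.utils.data.Dataset` that loads images from disk (with the embeddings precomputed once) is the more usual approach for larger sets.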
I wanted to try out training imagen and generating some samples. Therefore, I ran this part of the script from the readme:
```python
import torch
from imagen_pytorch import Unet, Imagen, ImagenTrainer

# unet for imagen

unet1 = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = 3,
    layer_attns = (False, True, True, True),
)

unet2 = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = (2, 4, 8, 8),
    layer_attns = (False, False, False, True),
    layer_cross_attns = (False, False, False, True)
)

# imagen, which contains the unets above (base unet and super resoluting ones)

imagen = Imagen(
    unets = (unet1, unet2),
    text_encoder_name = 't5-large',
    image_sizes = (64, 256),
    timesteps = 1000,
    cond_drop_prob = 0.1
).cuda()

# wrap imagen with the trainer class

trainer = ImagenTrainer(imagen)

# mock images (get a lot of this) and text encodings from large T5

text_embeds = my_text_embeddings    # my text embeddings, instead of torch.randn(64, 256, 1024).cuda()
images = my_sample_images           # my sample images (3k for a first try), instead of torch.randn(64, 3, 256, 256).cuda()

# feed images into imagen, training each unet in the cascade

loss = trainer(
    images,
    text_embeds = text_embeds,
    unet_number = 1,      # training on unet number 1 in this example, but you will have to also save checkpoints and then reload and continue training on unet number 2
    max_batch_size = 4    # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory
)

trainer.update(unet_number = 1)

# do the above for many many many many steps
# now you can sample an image based on the text embeddings from the cascading ddpm

images = trainer.sample(texts = [
    'a puppy looking anxiously at a giant donut on the table',
    'the milky way galaxy in the style of monet'
], cond_scale = 3.)

images.shape # (2, 3, 256, 256)
```
Thereafter, I run the following in order to make my results visible:
```python
from torchvision import transforms

transform = transforms.ToPILImage()

for img in images:
    i = transform(img)
    i.show()
```
The training runs without any error, but at the end, all I get are images full of white noise. Does anyone know where I am making a mistake here? Or is 3k simply not enough to get anything out of it? (I know it's not much, but unfortunately it is the maximum that my device can handle.) This is one such result.
I am really thankful for any advice on the matter.