lucidrains / imagen-pytorch

Implementation of Imagen, Google's Text-to-Image Neural Network, in Pytorch
MIT License
8.03k stars · 758 forks

Is the text conditioning verified? #201

Closed xiankgx closed 2 years ago

xiankgx commented 2 years ago

Hi, I've been experimenting with some training and I am able to get some images sampled during training. However, it seems that the generated images do not correspond to the input text. Just wondering if anyone has verified that text conditioning works?

xiankgx commented 2 years ago

It seems that the ResNet blocks in the UNet are the primary way of incorporating the condition for conditional generation. However, from the ResNet block calls, it seems that the cond variable is not really used anywhere. Also, the AdaIN-style scale-shift denormalization is perhaps best placed after taking the text condition into account.

With reference:

- ResNet block forward: https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py#L696-L714
- ResNet block calls:
  - https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py#L1602
  - https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py#L1615
  - https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py#L1641


OK, on further inspection, some ResNet blocks do use the cond variable, e.g.:

- https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py#L1612
- https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py#L1637

Also, `t` has the text conditioning added to it: https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py#L1589

Also, there are attention blocks.
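
For clarity, the conditioning pattern being discussed (a pooled time/text embedding projected to a per-channel scale and shift that modulates the normalized feature maps) looks roughly like the sketch below. This is a simplified illustration, not the repo's actual ResnetBlock; the class and attribute names here are made up.

```python
import torch
from torch import nn

class CondResnetBlock(nn.Module):
    """Illustrative scale-shift (FiLM / ada-GN style) conditioning, not the repo's code."""

    def __init__(self, dim, cond_dim):
        super().__init__()
        self.norm = nn.GroupNorm(8, dim)
        # project the conditioning vector to a per-channel scale and shift
        self.to_scale_shift = nn.Linear(cond_dim, dim * 2)
        self.act = nn.SiLU()
        self.conv = nn.Conv2d(dim, dim, 3, padding=1)

    def forward(self, x, cond):
        # x: (b, dim, h, w) feature maps, cond: (b, cond_dim) pooled time/text embedding
        h = self.norm(x)
        scale, shift = self.to_scale_shift(cond).chunk(2, dim=-1)
        # apply the scale-shift after normalization, broadcasting over the spatial dims
        h = h * (1 + scale[..., None, None]) + shift[..., None, None]
        h = self.conv(self.act(h))
        return h + x
```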

xiankgx commented 2 years ago

*(five sample images attached)*

xiankgx commented 2 years ago

This is using latent diffusion with channels=4 and a VAE (AutoencoderKL, kl-f8) from:

- https://github.com/CompVis/latent-diffusion/tree/main/models/first_stage_models/kl-f8
- https://github.com/CompVis/latent-diffusion/blob/main/scripts/download_first_stages.sh
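
In case it helps anyone reproduce this setup, the encode/decode step looks roughly like the sketch below. It uses the diffusers port of a kl-f8 AutoencoderKL as a stand-in for the CompVis checkpoint linked above; the checkpoint name and the 0.18215 scaling factor are the usual Stable Diffusion conventions, not something this repo mandates.

```python
import torch
from diffusers import AutoencoderKL

# kl-f8 VAE: 3-channel images at (h, w) <-> 4-channel latents at (h/8, w/8).
# "stabilityai/sd-vae-ft-mse" is one publicly available kl-f8 checkpoint;
# substitute whichever autoencoder you actually downloaded.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").eval()

@torch.no_grad()
def encode_to_latents(images: torch.Tensor) -> torch.Tensor:
    # images: (b, 3, h, w) in [-1, 1]
    return vae.encode(images).latent_dist.sample() * 0.18215

@torch.no_grad()
def decode_from_latents(latents: torch.Tensor) -> torch.Tensor:
    # latents: (b, 4, h/8, w/8), as produced by the unet with channels=4
    return vae.decode(latents / 0.18215).sample
```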

lucidrains commented 2 years ago

@xiankgx i've heard from multiple researchers that it does work

xiankgx commented 2 years ago

Tried the apple2orange dataset from CycleGAN; the results don't look too good.

*(three sample images attached)*

xiankgx commented 2 years ago

@lucidrains, please have a look.

Nodja commented 2 years ago

Amateur researcher here, just to report that conditioning does seem to work.

Using my filtered anime dataset and ElucidatedImagen, conditioning doesn't kick in until ~550000 samples seen, or 68750 steps at batch_size 8.

My dataset contains 24k samples of anime girls with different colored hair: 4k each of red, green, and blue, and 12k of black (black is just there to improve the quality of results). 2k of the red-haired girls have cat ears, and 2k of the blue-haired girls have cat ears; the goal of my run is to generate green-haired catgirls, which the model has never seen.

I should also mention that I'm using Muennighoff/SGPT-1.3B-weightedmean-nli-bitfit and not T5: since I needed to precompute the embeddings anyway due to memory constraints, I figured I'd use a bigger text model. I never did a run with T5, so I can't tell you if it makes a difference.

For text I choose one of 4 permutations of a phrase that means the same thing, e.g. one of: "a drawing of a red haired anime girl", "an anime illustration of a girl with red hair", etc. The text is picked randomly during the dataset's `__getitem__` call (rough sketch below). I can't tell you whether having different phrases for the same image improves things, as I only tested it this way. There's also a 5th "permutation" that is just the raw danbooru tags for that image. The tags were there from the beginning for this run, but I added them mid-run in a previous experiment and it had no effect on conditioning, or anything else really.
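
A minimal sketch of what I mean (illustrative only, not my actual training code; the class name and tensor shapes are made up):

```python
import random
from torch.utils.data import Dataset

class CaptionPermutationDataset(Dataset):
    # Each image has several caption phrasings (plus the raw danbooru tags);
    # __getitem__ picks one at random and returns its precomputed text
    # embedding, so the text encoder never needs to be loaded during training.
    def __init__(self, images, caption_embeds):
        # images: list of (3, h, w) tensors
        # caption_embeds: for each image, a list of precomputed embeddings,
        # one per caption permutation, each of shape (seq_len, embed_dim)
        self.images = images
        self.caption_embeds = caption_embeds

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        text_embed = random.choice(self.caption_embeds[idx])
        return self.images[idx], text_embed
```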

Here's the results.

**Prompts used for sampling**

The position of the texts will match the grid pattern of the images below. This is here just for the sake of detail, but the important part is the color pattern: RGBRGB RGBRGB.

|  |  |  |  |  |  |
| - | - | - | - | - | - |
| drawing of a girl with red hair | drawing of a girl with green hair | drawing of a girl with blue hair | drawing of a catgirl with red hair | drawing of a catgirl with green hair | drawing of a catgirl with blue hair |
| danbooru drawing with the following tags:\ngeneral: 1girl, red hair | danbooru drawing with the following tags:\ngeneral: 1girl, green hair | danbooru drawing with the following tags:\ngeneral: 1girl, blue hair | danbooru drawing with the following tags:\ngeneral: 1girl, red hair, cat ears | danbooru drawing with the following tags:\ngeneral: 1girl, green hair, cat ears | danbooru drawing with the following tags:\ngeneral: 1girl, blue hair, cat ears |
**At 360000 samples seen**

cond_scale=1
![image](https://user-images.githubusercontent.com/7379193/187456336-fe5d9998-a376-4159-982b-0c725c7538db.png)

cond_scale=5
![image](https://user-images.githubusercontent.com/7379193/187456449-356305bd-842c-4e37-82e3-09e393ca5fcd.png)
**At 552000 samples seen**

cond_scale=1
![image](https://user-images.githubusercontent.com/7379193/187457270-a55e5bc5-ff52-470e-b66e-1b65f6fd0dbe.png)

cond_scale=5
![image](https://user-images.githubusercontent.com/7379193/187457230-db1ba2f9-7081-4970-984d-477bae08dfbd.png)
**At 865000 samples seen**

Note: The dataset at this point is no longer just 24k samples, but was increased to ~188k samples; the same constraints apply, i.e. no green-haired catgirls present in the dataset. The increased dataset was used after 577000 samples of the previous one.

cond_scale=1
![image](https://user-images.githubusercontent.com/7379193/187459560-5cc6d402-bb7e-4ac0-b9e2-896579beaf89.png)

cond_scale=5
![image](https://user-images.githubusercontent.com/7379193/187459596-01661ec1-0fd5-4ac8-8cf8-533d754efbbf.png)

As you can see, while the quality of my model is very low, it does seem to be conditioning. I only have a 1080ti so I can't train bigger models at decent speeds. I plan to keep training this run, and hopefully the model keeps improving in quality enough that we can see some faces and not disfigured anime girls. At this point I don't even know if GPT was the right choice :)

My config:

```python
imagen = ElucidatedImagenConfig(
    unets=[
        dict(
            dim=192,
            dim_mults=(1, 2, 3, 4),
            text_embed_dim=2048,
            num_resnet_blocks=2,
            layer_attns=(False, True, True, True),
            memory_efficient=False,
            self_cond=True,
        ),
    ],
    image_sizes=(64,),
    cond_drop_prob=0.1,
    text_embed_dim=2048,
    num_sample_steps=50,
).create()
```

For reference, I've also trained the base unet without self_cond and with memory_efficient=True, and those runs also started conditioning around the same time, using the same text encoder.
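
And for completeness, a rough sketch of how a config like the one above gets driven during training, roughly following the README's ImagenTrainer pattern (double-check the current README for the exact API; the tensor shapes and sequence length here are placeholders, and in my case the text embeddings come precomputed from SGPT):

```python
import torch
from imagen_pytorch import ImagenTrainer

trainer = ImagenTrainer(imagen)  # `imagen` is the ElucidatedImagen built above

# placeholder batch: 64x64 images plus precomputed text embeddings
# (embedding dim 2048 to match text_embed_dim in the config; seq length is arbitrary here)
images = torch.randn(8, 3, 64, 64)
text_embeds = torch.randn(8, 77, 2048)

loss = trainer(images, text_embeds=text_embeds, unet_number=1)
trainer.update(unet_number=1)

# sampling with classifier-free guidance (cond_scale > 1)
samples = trainer.sample(text_embeds=text_embeds, cond_scale=5.)
```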

xiankgx commented 2 years ago

@Nodja, do you observe that conditioning doesn't seem to work in the "early" stages?

xiankgx commented 2 years ago

It also seems that at early stages the hair colors don't quite match.

Nodja commented 2 years ago

Yes, conditioning is mostly ignored for a long time and is only visible after 24+hrs of training on my 1080ti.

xiankgx commented 2 years ago

Hmmm, I see. Guess I just need to be more patient then. Anyway, please keep this thread alive and report if you see anything. Thank you so much.

lucidrains commented 2 years ago

@xiankgx yeah, just train for longer, more data, the usual

also, the bigger the T5, the better the binding. this was the main finding of the Imagen paper, if you recall

lucidrains commented 2 years ago

@xiankgx also make sure your conditioning scale is around 5.

xiankgx commented 2 years ago

cond_scale is 1.0 according to the defaults. Yeah, I noticed it makes a ton of difference in the stable/latent diffusion model that was just released.

lucidrains commented 2 years ago

@xiankgx yup, definitely make it greater than 1.0, that's classifier-free guidance, yet another Jonathan Ho work
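
For anyone unfamiliar: because the text conditioning is randomly dropped during training (cond_drop_prob in the config above), the model learns both a conditional and an unconditional prediction, and cond_scale blends the two at sampling time. A minimal sketch of the idea, using the general classifier-free guidance formula rather than the repo's exact code:

```python
import torch

def classifier_free_guidance(pred_cond: torch.Tensor,
                             pred_uncond: torch.Tensor,
                             cond_scale: float = 5.0) -> torch.Tensor:
    # cond_scale = 1.0 reduces to the plain conditional prediction;
    # larger values push the sample toward the text condition.
    return pred_uncond + cond_scale * (pred_cond - pred_uncond)

# toy usage: random tensors standing in for the unet's two denoising outputs
pred_cond, pred_uncond = torch.randn(2, 4, 3, 64, 64)
guided = classifier_free_guidance(pred_cond, pred_uncond, cond_scale=5.0)
```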