lucidrains / DALLE2-pytorch

Implementation of DALL-E 2, OpenAI's updated text-to-image synthesis neural network, in Pytorch

Need help with decoder training #40

Closed: xiankgx closed this issue 2 years ago

xiankgx commented 2 years ago

I'm training on CC3M. Is there anything wrong with my training? The loss seems to be going down way too fast, and despite the low training loss values, sampling doesn't seem to show it is working. I sample during training by just calling decoder.sample() and giving it the CLIP image embeddings of the minibatch training images. Since I'm training a decoder with two Unets and only training the first Unet for now, I break out after sampling from the first Unet.

[image: decoder training loss curve]

These are the samples at 0k, 5k, 13k, 16k, and 17k training steps.

[images: samples at 0k, 5k, 13k, 16k, and 17k steps]
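For reference, the sampling-during-training routine described above would look roughly like the sketch below. This is a minimal sketch only; `embed_image` and the `image_embed` keyword are my assumptions about the CLIP adapter / decoder API at the time, so check them against the current code.

```python
import torch

@torch.no_grad()
def sample_like_training_images(decoder, clip, images):
    # embed the minibatch training images with CLIP
    image_embed, _ = clip.embed_image(images)  # assumed to return (embedding, encodings)
    # sample from the decoder conditioned on those embeddings (keyword name assumed)
    return decoder.sample(image_embed = image_embed)
```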

lucidrains commented 2 years ago

@xiankgx thank you for testing it out! so based on my experience with DDPMs, the loss has to hit around 0.05 before sampled images come into view. how large are your batch sizes?

i'll be testing out the decoder training too, probably tomorrow evening-ish, as i'm planning to copy the cascading unet into the other ddpm repository i have (and test out training on oxford flowers dataset)

lucidrains commented 2 years ago

@xiankgx also, i have yet to code this up, but ideally during sampling, we use the exponentially moving averaged unets https://github.com/lucidrains/DALLE2-pytorch/blob/main/dalle2_pytorch/train.py#L94 still thinking about how to architect this
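For readers unfamiliar with EMA sampling, the idea is roughly the following generic sketch (not the repo's eventual implementation): keep a frozen shadow copy of each unet, nudge it toward the online weights after every optimizer step, and sample from the shadow copy.

```python
import copy
import torch

class EMA:
    def __init__(self, model, decay = 0.995):
        self.decay = decay
        # frozen shadow copy used only for sampling
        self.ema_model = copy.deepcopy(model).eval()
        for p in self.ema_model.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model):
        for ema_p, p in zip(self.ema_model.parameters(), model.parameters()):
            # ema = decay * ema + (1 - decay) * online
            ema_p.lerp_(p, 1. - self.decay)
```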

xiankgx commented 2 years ago

@xiankgx thank you for testing it out! so based on my experience with DDPMs, the loss has to hit around 0.05 before sampled images come into view. how large are your batch sizes?

i'll be testing out the decoder training too, probably tomorrow evening-ish, as i'm planning to copy the cascading unet into the other ddpm repository i have (and test out training on oxford flowers dataset)

Batch size is 128, precision at fp32. I guess I was just not patient enough?

xiankgx commented 2 years ago

@xiankgx also, i have yet to code this up, but ideally during sampling, we use the exponentially moving averaged unets https://github.com/lucidrains/DALLE2-pytorch/blob/main/dalle2_pytorch/train.py#L94 still thinking about how to architect this

I think we can still use the non-EMA one for testing?

lucidrains commented 2 years ago

@xiankgx thank you for testing it out! so based on my experience with DDPMs, the loss has to hit around 0.05 before sampled images come into view. how large are your batch sizes? i'll be testing out the decoder training too, probably tomorrow evening-ish, as i'm planning to copy the cascading unet into the other ddpm repository i have (and test out training on oxford flowers dataset)

Batch size is 128, precision at fp32. I guess I was just not patient enough?

hmm, that sounds sufficient, but there's also very likely a bug or two remaining 🐞 that's the risk of being the first one to try a new repository 🤣

lucidrains commented 2 years ago

@xiankgx let me get the EMA-ed sampling code in there, just thought of a way to do it cleanly

lucidrains commented 2 years ago

https://github.com/lucidrains/DALLE2-pytorch/commit/ebe01749ed0a48aa77e236eb609440db0944eada tada

xiankgx commented 2 years ago

Btw, this training run only covers the code from before the cond_drop_prob changes in https://github.com/lucidrains/DALLE2-pytorch/issues/38. There were too many changes to follow and port back manually.

You mentioned that text encodings must be padded to max length, and I see you are padding the text encodings in the code. Do you think it affects training if I just pad text_encodings at the dataset/dataloader level? I'm not really conditioning on text_encodings anyway.
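A dataset/dataloader-level padding scheme like the one described would be a collate function along these lines (a sketch; the per-sample text encoding is assumed to be a (seq_len, dim) tensor):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def collate(batch):
    images, text_encodings = zip(*batch)
    images = torch.stack(images)
    # zero-pad variable-length text encodings to the longest in the batch;
    # padding to a fixed max length would work the same way
    text_encodings = pad_sequence(text_encodings, batch_first = True)
    return images, text_encodings
```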

lucidrains commented 2 years ago

i think it's probably worth leaving out the text encoding conditioning altogether, and the image embedding conditioning too (with a conditional dropout of 1.), until we see that the cascading DDPM can generate things unconditionally (which it should). then i would ascertain that image embedding conditioning works, followed by validating that classifier free guidance improves things later in training (it probably wouldn't do much early on). i would save the text encoding for last, since i think it is not the novelty of the paper and was present even in GLIDE with poor results
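The conditional dropout mentioned here is the classifier-free guidance trick: with some probability, the conditioning embedding is swapped for a learned null embedding, so a dropout of 1. trains the unet fully unconditionally. A generic sketch (not the repo's exact implementation):

```python
import torch

def drop_conditioning(image_embed, null_embed, cond_drop_prob):
    # image_embed: (batch, dim), null_embed: (dim,) learned parameter
    if cond_drop_prob == 0.:
        return image_embed
    keep = torch.rand(image_embed.shape[0], device = image_embed.device) >= cond_drop_prob
    return torch.where(keep[:, None], image_embed, null_embed)
```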

lucidrains commented 2 years ago

@xiankgx i've added the ability to shortcut sampling at a certain unet in the cascade btw, just so you have an easier time https://github.com/lucidrains/DALLE2-pytorch/commit/8260fc933a9a5118e18f209314ecfad246a42454 thanks for being the first guinea pig lol
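Usage would presumably be something like the line below; `stop_at_unet_number` is my reading of the linked commit, so double-check the keyword against the current code.

```python
# sample only through the first unet of the cascade, skipping the upsampler
images_64 = decoder.sample(image_embed = image_embed, stop_at_unet_number = 1)
```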

xiankgx commented 2 years ago

@xiankgx i've added the ability to shortcut sampling at a certain unet in the cascade btw, just so you have an easier time 8260fc9 thanks for being the first guinea pig lol

Thank you, ya, this is exactly what I had in my code.

xiankgx commented 2 years ago

This is at 39k iterations; the top images are sampled and the bottom images are training images. I'm seeing the word "alamy" in the bottom-left generated image, in line with the corresponding training image. But I'm still seeing quite a lot of plain colored backgrounds.

[images: samples at 39k steps, and the corresponding training images]

lucidrains commented 2 years ago

@xiankgx when you save the image, are you unnormalizing it (back to the range of 0 to 1)? https://github.com/lucidrains/DALLE2-pytorch/blob/main/dalle2_pytorch/dalle2_pytorch.py#L101 your sampled images have a consistent color shift, and i'm not sure how to explain that

lucidrains commented 2 years ago

otherwise, the general shapes and structure look about right for DDPM-type training!

xiankgx commented 2 years ago

Ya, I'm doing inverse normalization with (t * std) + mean, with mean and std both (0.5, 0.5, 0.5).
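That inverse normalization, written out (assuming the training transform used mean = std = 0.5 per channel):

```python
import torch
from torchvision.utils import save_image

def unnormalize(t):
    # map from [-1, 1] back to [0, 1] and clamp for saving
    return (t * 0.5 + 0.5).clamp(0., 1.)

# save_image expects values in [0, 1]
# save_image(unnormalize(samples), 'samples.png')
```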

xiankgx commented 2 years ago

It depends on the specific step. Some steps are still like the earlier samples, where all images are orange, red, or blue, and all samples from the same step have the same color. Sometimes it is better, like the above.

lucidrains commented 2 years ago

@xiankgx ohh got it, it would be nice to see what the text is too, if you are conditioning on that. maybe it has latched onto some color in the sentence, if present

lucidrains commented 2 years ago

what is "alamy"? is it a watermark?

xiankgx commented 2 years ago

Ya, I think it is a watermark. alamy should be some stock photo site.

xiankgx commented 2 years ago

42k, 44k, 49k. A lot of them are like this.

[images: samples at 42k, 44k, and 49k steps]

lucidrains commented 2 years ago

when sampling, is each image being conditioned on a different text, or all the same?

xiankgx commented 2 years ago

The model is not conditioned on text. It is conditioned on the CLIP image embeddings of the training images in the same step, so they are not fixed.

lucidrains commented 2 years ago

yea, i think the exponential moving average should help, though what you are seeing is still a bit extreme. hoping there is not a bug in the unet lol

lucidrains commented 2 years ago

@xiankgx got it, i'm not sure what's going on then. do you have the sparse attention turned off? (with the grid attention) that was more of an experimental feature i have yet to test locally

lucidrains commented 2 years ago

@xiankgx i'll join you soon next week, as i port the cascading DDPM over and start doing local experiments on my machine with oxford dataset

xiankgx commented 2 years ago

@xiankgx got it, i'm not sure what's going on then. do you have the sparse attention turned off? (with the grid attention) that was more of an experimental feature i have yet to test locally

Ya, I'm just using the defaults.

lucidrains commented 2 years ago

@xiankgx ok, i have it turned off

i'm not sure what's going on! your batch sizes seem large enough. i guess if you start an exponential moving average run and still see the weird color shifting, i'll be certain there's some remaining bug to squash from the decoder

xiankgx commented 2 years ago

This is what I get from the OpenAI code after 1k steps, at 64x64 resolution, with the model conditioned on both text and image embeddings.

[image: samples from the OpenAI code at 1k steps]

xiankgx commented 2 years ago

I've looked through the code a few thousand times and I can't find anything either.

lucidrains commented 2 years ago

@xiankgx ahh ok, i think i spot the potential problem. so i changed the conditioning in the u-net over to a cross-attention based approach (in line with the latent diffusion paper), but it's probably best to leave the time conditioning out of it - let me offer an additional time conditioning the old fashioned way (MLP + sum to hiddens)
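The "old fashioned way" here is the standard DDPM recipe: a sinusoidal timestep embedding passed through a small MLP and broadcast-added to the hidden feature maps. A self-contained sketch (names are illustrative, not the repo's):

```python
import math
import torch
from torch import nn

class TimeConditioning(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        assert dim % 2 == 0, 'sinusoidal embedding dim must be even'
        self.dim = dim
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

    def forward(self, t, hiddens):
        # sinusoidal embedding of the timesteps t: (batch,) -> (batch, dim)
        half = self.dim // 2
        freqs = torch.exp(-math.log(10000) * torch.arange(half, device = t.device) / half)
        emb = torch.cat((torch.sin(t[:, None] * freqs), torch.cos(t[:, None] * freqs)), dim = -1)
        emb = self.mlp(emb)                     # (batch, hidden_dim)
        return hiddens + emb[:, :, None, None]  # add to (batch, hidden_dim, h, w) features
```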

xiankgx commented 2 years ago

In the other version, I'm doing both: projecting the image embed and adding it to the time conditioning/embedding, and projecting the image embed into 4 tokens and concatenating them with the text sequence.
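The "4 tokens" variant would look something like this sketch (module and argument names are hypothetical):

```python
import torch
from torch import nn

class ImageEmbedToTokens(nn.Module):
    def __init__(self, image_embed_dim, model_dim, num_tokens = 4):
        super().__init__()
        self.num_tokens = num_tokens
        self.proj = nn.Linear(image_embed_dim, model_dim * num_tokens)

    def forward(self, image_embed, text_tokens):
        # project (batch, image_embed_dim) -> (batch, num_tokens, model_dim)
        b = image_embed.shape[0]
        image_tokens = self.proj(image_embed).view(b, self.num_tokens, -1)
        # prepend the image tokens to the text token sequence for cross-attention
        return torch.cat((image_tokens, text_tokens), dim = 1)
```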

lucidrains commented 2 years ago

@xiankgx it doesn't make sense for time to have to fight for attention with the text. you always need to factor in time

lucidrains commented 2 years ago

@xiankgx ohh no way, and still seeing the same thing? ok, i've got nothing then

you should try the exponential moving average run, that's the only thing i can think of

lucidrains commented 2 years ago

@xiankgx let me add an extra conditioning for time for this sunday morning, since it just makes more sense

lucidrains commented 2 years ago

@xiankgx ok, should be all there. in addition, i may have spotted an error with the initial convnext block (it was actually conditioning in pixel space), so i've added an extra initial convolution to make sure that doesn't happen

lucidrains commented 2 years ago

@xiankgx i've added back the old resnet blocks from the original DDPM paper, just in case the convnext blocks are not a good fit for generations (i've had a few other researchers complain in other repositories)
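For context, the original-DDPM-style block is essentially GroupNorm -> SiLU -> Conv, with the time embedding injected between the two convolutions. A rough sketch (not the repo's exact block):

```python
import torch
from torch import nn

class ResnetBlock(nn.Module):
    def __init__(self, dim, dim_out, time_emb_dim, groups = 8):
        super().__init__()
        # dim and dim_out assumed divisible by `groups`
        self.time_mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out))
        self.block1 = nn.Sequential(nn.GroupNorm(groups, dim), nn.SiLU(), nn.Conv2d(dim, dim_out, 3, padding = 1))
        self.block2 = nn.Sequential(nn.GroupNorm(groups, dim_out), nn.SiLU(), nn.Conv2d(dim_out, dim_out, 3, padding = 1))
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb):
        h = self.block1(x)
        h = h + self.time_mlp(time_emb)[:, :, None, None]  # add time conditioning
        h = self.block2(h)
        return h + self.res_conv(x)
```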

lucidrains commented 2 years ago

@xiankgx it was the convnext blocks :( sorry about that - another researcher and I over at https://github.com/lucidrains/video-diffusion-pytorch finally got to the bottom of it

lucidrains commented 2 years ago

sigh, it looks like batchnorm and groupnorms will have to stick around for a bit longer 😢

xiankgx commented 2 years ago

sigh, it looks like batchnorm and groupnorms will have to stick around for a bit longer 😢

What about them? Do you have new findings on this problem? I've stopped training in the meantime.

lucidrains commented 2 years ago

@xiankgx yes, it was the convnext blocks; i jumped the gun in adopting them. it is better to stick with resnet blocks for generative work until some other paper proves convnext out. you're welcome to retry if you have the time! 🙏 otherwise i'll get into the weeds this weekend

chinoll commented 2 years ago

I have encountered the same problem as you. Have you solved it?

xiankgx commented 2 years ago

That run was done when the code was at an earlier stage, when ConvNext blocks were used instead of ResNet blocks. I suspected a problem somewhere in the U-Net was causing entire batches to have the same tint or color. I have not retried training with the newer ResNet blocks. Is your U-Net using the newer code with ResNet blocks?

egeozsoy commented 2 years ago

I am training the Decoder with two Unets with DecoderTraining (described in the README), using automatic mixed precision. My image size is 224x224, and I am only managing a batch size of 16 on 12 GB of VRAM. I was wondering if this is normal, as @xiankgx was using a batch size of 128 while working with float32.

xiankgx commented 2 years ago

I am training the Decoder with two Unets with DecoderTraining (described in the README), using automatic mixed precision. My image size is 224x224, and I am only managing a batch size of 16 on 12 GB of VRAM. I was wondering if this is normal, as @xiankgx was using a batch size of 128 while working with float32.

Hehe, I was running on 4 GPUs with 40 GB of memory each. No worries about that.
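If memory is the bottleneck, one common workaround is mixed precision plus gradient accumulation to emulate a larger effective batch. A hedged sketch (the decoder call, the `image_embed` and `unet_number` keywords, and the surrounding objects are assumptions, not the trainer's exact API):

```python
import torch

scaler = torch.cuda.amp.GradScaler()
accum_steps = 8  # 16 * 8 ~ effective batch of 128

for step, (images, image_embed) in enumerate(dataloader):
    with torch.cuda.amp.autocast():
        # assumed: the decoder forward returns the diffusion loss for the chosen unet
        loss = decoder(images, image_embed = image_embed, unet_number = 1)
    scaler.scale(loss / accum_steps).backward()
    if (step + 1) % accum_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
```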

lucidrains commented 2 years ago

[image: unconditional Oxford Flowers samples at 2.5k steps]

i do believe the decoder trainer is working! oxford flowers dataset at just 2.5k - training unconditionally

lucidrains commented 2 years ago

[image: samples at 4.5k steps]

4.5k, yea looks good, petals are taking shape

lucidrains commented 2 years ago

[image: samples at 7k steps]

7k steps, yea i think it is working :) @xiankgx your issue is likely from the convnext blocks + lack of EMA

lucidrains commented 2 years ago

@egeozsoy are you seeing any results? if you confirm it is working on your end, i'm closing this issue

egeozsoy commented 2 years ago

Yes, I am also getting some results, though for me the older learning rate (3e-4) was working better, I think. Just FYI.

lucidrains commented 2 years ago

@egeozsoy yup, i'll change it to what they had in the paper, which is around 1e-4. thank you for confirming!