@xiankgx thank you for testing it out! so based on my experience with DDPMs, the loss has to hit around 0.05 before sampled images come into view. how large are your batch sizes?
i'll be testing out the decoder training too, probably tomorrow evening-ish, as i'm planning to copy the cascading unet into the other ddpm repository i have (and test out training on oxford flowers dataset)
@xiankgx also, i have yet to code this up, but ideally during sampling, we use the exponentially moving averaged unets https://github.com/lucidrains/DALLE2-pytorch/blob/main/dalle2_pytorch/train.py#L94 still thinking about how to architect this
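For reference, a minimal sketch of what keeping exponentially moving averaged unet weights could look like (the class name and decay value here are illustrative, not the repo's actual implementation):

```python
import copy
import torch

class EMA:
    """Minimal exponential moving average over a unet's weights (illustrative sketch)."""
    def __init__(self, model, decay = 0.995):
        self.decay = decay
        # frozen copy whose weights lag behind the online (trained) weights
        self.ema_model = copy.deepcopy(model)
        self.ema_model.requires_grad_(False)

    @torch.no_grad()
    def update(self, model):
        # ema <- decay * ema + (1 - decay) * online
        for ema_p, online_p in zip(self.ema_model.parameters(), model.parameters()):
            ema_p.lerp_(online_p, 1. - self.decay)
```

Training would keep updating the online unet and calling `update` after each optimizer step, while sampling would use `ema.ema_model`.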
Batch size is 128, precision at fp32. I guess I was just not patient enough?
I think we can still use the non-EMA one for testing?
hmm, that sounds sufficient, but also very likely a bug or two remaining 🐞 that's the risk of being the first one to try a new repository 🤣
@xiankgx let me get the EMA-ed sampling code in there, just thought of a way to do it cleanly
This training, btw, is on the code from before all the cond_drop_prob changes in https://github.com/lucidrains/DALLE2-pytorch/issues/38. There are too many changes to follow and port back manually.
You mentioned that text encodings must be padded to max length, and I see you are padding the text encodings in the code. Do you think it affects training if I just pad text_encodings at the dataset/dataloader level? I'm not really conditioning on text_encodings anyway.
i think it's probably worth leaving out the text encoding conditioning altogether, and the image embedding conditioning too (with a conditional dropout of 1.), until we see that the cascading DDPM can generate things unconditionally (which it should). then i would ascertain that image embedding conditioning works, followed by validating that classifier free guidance improves things later in training (it probably wouldn't do much early on). i would save the text encoding for last, since i think this is not the novelty of the paper and was present even in GLIDE with poor results
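To make the "conditional dropout of 1." concrete, here is a rough sketch of how image-embedding dropout is usually done for classifier-free guidance (the function and argument names are hypothetical, not the repo's exact API):

```python
import torch

def drop_image_embed(image_embed, null_embed, cond_drop_prob):
    """Randomly swap real image embeddings for a learned null embedding.
    cond_drop_prob = 1. trains fully unconditionally; 0. always conditions."""
    batch = image_embed.shape[0]
    keep = torch.rand(batch, device = image_embed.device) >= cond_drop_prob
    # broadcast the per-sample keep mask over the embedding dimension;
    # null_embed would be e.g. a learned nn.Parameter of shape (1, dim)
    return torch.where(keep[:, None], image_embed, null_embed)
```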
@xiankgx i've added the ability to shortcut sampling at a certain unet in the cascade btw, just so you have an easier time https://github.com/lucidrains/DALLE2-pytorch/commit/8260fc933a9a5118e18f209314ecfad246a42454 thanks for being the first guinea pig lol
Thank you, ya, this is exactly what I had in my code.
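Something along these lines, assuming the keyword added in the linked commit is called `stop_at_unet_number` (double-check the current signature before relying on it; `decoder` and `image_embed` are assumed to already exist):

```python
import torch

# `decoder` holds the two-unet cascade; `image_embed` comes from CLIP image encodings
with torch.no_grad():
    samples = decoder.sample(
        image_embed = image_embed,
        stop_at_unet_number = 1  # stop after the first (base) unet, skipping the upsampler
    )
```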
This is at 39k iterations; the top images are sampled, the bottom images are training images. I'm seeing the word "alamy" in the bottom-left generated image, in line with the corresponding training image. But I'm still seeing quite a lot of plain colored backgrounds.
@xiankgx when you save the image, are you unnormalizing the image? (back to the range of 0 to 1)? https://github.com/lucidrains/DALLE2-pytorch/blob/main/dalle2_pytorch/dalle2_pytorch.py#L101 your sampled images have a consistent color shift, and i'm not sure how to explain that
otherwise, the general shapes and structure look about right for DDPM-type training!
Ya, I'm doing inverse normalization with (t * std) + mean with mean and std both (0.5, 0.5, 0.5).
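For reference, with mean = std = 0.5 that inverse normalization is just a map from [-1, 1] back to [0, 1], e.g.:

```python
import torch

def unnormalize_to_zero_to_one(t):
    # inverse of Normalize(mean = 0.5, std = 0.5): t * 0.5 + 0.5, clamped for safety before saving
    return (t * 0.5 + 0.5).clamp(0., 1.)
```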
It depends on the specific step. Some steps are still like the earlier samples, where the images are all orange, red, or blue, and all samples from the same step have the same color. Sometimes it is better, like the above.
@xiankgx ohh got it, it would be nice to see what the text is too, if you are conditioning on that. maybe it has latched onto some color in the sentence, if present
what is "alamy"? is it a watermark?
Ya, I think it is a watermark. alamy should be some stock photo site.
42k, 44k, 49k. A lot of them are like this.
when sampling, is each image being conditioned on a different text, or all the same?
The model is not conditioned on text. It is conditioned on CLIP image embeddings obtained from the training images in the same step, so the conditioning is not fixed.
yea, i think the exponential moving average should help, though what you are seeing is still a bit extreme. hoping there is not a bug in the unet lol
@xiankgx got it, i'm not sure what's going on then. do you have the sparse attention turned off? (with the grid attention) that was more of an experimental feature i have yet to test locally
@xiankgx i'll join you soon next week, as i port the cascading DDPM over and start doing local experiments on my machine with oxford dataset
Ya, I'm just using the defaults.
@xiankgx ok, i have it turned off
i'm not sure what's going on! your batch sizes seem large enough. i guess if you start an exponential moving average run and still see the weird color shifting, i'll be certain there's some remaining bug to squash from the decoder
This is what I have from the OpenAI code after 1k steps: 64x64 resolution, model conditioned on both text and image embeddings.
I've looked through the code a few thousand times and I can't find anything either.
@xiankgx ahh ok, i think i spotted the potential problem. so i changed the conditioning in the U-net over to being cross-attention based (in line with the latent diffusion paper), but it's probably best to leave the time conditioning out of it - let me offer an additional time conditioning done the old fashioned way (MLP + sum to hiddens)
In the other version, I'm doing both: projecting the image embed and adding it to the time conditioning/embedding, and also projecting the image embed into 4 tokens and concatenating them with the text sequence.
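Roughly, those two conditioning paths look like this (a sketch with made-up names, not the actual code of either version; `token_dim` is assumed to match the text token dimension):

```python
import torch
from torch import nn

class ImageEmbedConditioner(nn.Module):
    """Condition on a CLIP image embedding two ways:
    (1) project it and add it to the time embedding,
    (2) project it to a few extra tokens concatenated with the text sequence."""
    def __init__(self, image_embed_dim, time_dim, token_dim, num_tokens = 4):
        super().__init__()
        self.num_tokens = num_tokens
        self.to_time_cond = nn.Linear(image_embed_dim, time_dim)
        self.to_tokens = nn.Linear(image_embed_dim, token_dim * num_tokens)

    def forward(self, image_embed, time_emb, text_tokens):
        # path (1): additive conditioning on the time embedding
        time_emb = time_emb + self.to_time_cond(image_embed)
        # path (2): a handful of image tokens prepended to the text sequence for cross attention
        image_tokens = self.to_tokens(image_embed).view(image_embed.shape[0], self.num_tokens, -1)
        context = torch.cat((image_tokens, text_tokens), dim = 1)
        return time_emb, context
```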
@xiankgx it doesn't make sense for time to have to fight for attention with the text. you always need to factor in time
@xiankgx ohh no way, and still seeing the same thing? ok, i've got nothing then
you should try the exponential moving average run, that's the only thing i can think of
@xiankgx let me add an extra conditioning for time for this sunday morning, since it just makes more sense
@xiankgx ok, should be all there. in addition, i may have spotted an error with the initial convnext block (it was actually conditioning in pixel space), so i've added an extra initial convolution to make sure that doesn't happen
@xiankgx i've added back the old resnet blocks from the original DDPM paper, just in case the convnext blocks are not a good fit for generations (i've had a few other researchers complain in other repositories)
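The classic DDPM-style resnet block being referred to, with groupnorm and the "MLP + sum to hiddens" time conditioning, looks roughly like this (a sketch, not the repo's exact block; `dim` is assumed divisible by `groups`):

```python
import torch
from torch import nn

class ResnetBlock(nn.Module):
    """DDPM-style residual block: groupnorm + SiLU + conv, with the time embedding
    passed through a small MLP and summed onto the hidden feature map."""
    def __init__(self, dim, time_dim, groups = 8):
        super().__init__()
        self.time_mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_dim, dim))
        self.block1 = nn.Sequential(nn.GroupNorm(groups, dim), nn.SiLU(), nn.Conv2d(dim, dim, 3, padding = 1))
        self.block2 = nn.Sequential(nn.GroupNorm(groups, dim), nn.SiLU(), nn.Conv2d(dim, dim, 3, padding = 1))

    def forward(self, x, time_emb):
        h = self.block1(x)
        h = h + self.time_mlp(time_emb)[:, :, None, None]  # broadcast time conditioning over H, W
        h = self.block2(h)
        return h + x  # residual connection
```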
@xiankgx it was the convnext blocks :( sorry about that - another researcher and I over at https://github.com/lucidrains/video-diffusion-pytorch finally got to the bottom of it
sigh, it looks like batchnorm and groupnorms will have to stick around for a bit longer 😢
What about them? Do you have new findings on this problem? I've stopped training in the meantime.
@xiankgx yes, it was the convnext blocks, jumped the gun in adopting it. it is better to stick with resnet blocks for generative work until some other paper proves it out. you're welcome to retry if you have the time! :pray: otherwise i'll get into the weeds this weekend
I have encountered the same problem as you. Have you solved it?
That run was done when the code was in its earlier stage, when ConvNext blocks were used instead of ResNet blocks. I suspected it could be a problem somewhere in the U-Net causing entire batches to have the same tint or color. I have not retried training with the newer ResNet blocks. Is your U-Net using the newer code with ResNet blocks?
I am training the Decoder with two Unets with DecoderTraining (described in the README), using automatic mixed precision. My image size is 224x224, and I am only managing a batch size of 16 on 12 GB of VRAM. I was wondering if this is normal, as @xiankgx was using a batch size of 128 while working with float32.
Hehe, I was running on 4 GPUs with 40 GB RAM each. No worries on that.
i do believe the decoder trainer is working! oxford flowers dataset at just 2.5k - training unconditionally
4.5k, yea looks good, petals are taking shape
7k steps, yea i think it is working :) @xiankgx your issue is likely from the convnext blocks + lack of EMA
@egeozsoy are you seeing any results? if you confirm it is working on your end, i'm closing this issue
Yes, I am also getting some results. Though for me the older learning rate (3e-4) was working better I think. Just fyi
@egeozsoy yup, i'll change it to what they had in the paper, which is around 1e-4
thank you for confirming!
I'm training on CC3M. Is there anything wrong with my training? The loss seems to be going down way too fast, and despite the low training loss values, sampling doesn't seem to show that it is working. I'm sampling during training by just calling decoder.sample(), giving it the CLIP image embeddings of the minibatch training images. Since I'm training a decoder with two Unets and only training the first Unet for now, I'm breaking out after sampling from the first Unet.
These are the samples at the 0k, 5k, 13k, 16k, and 17k training steps.
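For what it's worth, the sampling-during-training setup described above would look something like this. The keyword names (image_embed, unet_number, stop_at_unet_number) are assumed from this thread and should be checked against the current API; dataloader, optimizer, and the sampling interval are placeholders:

```python
import torch

sample_every = 1000  # hypothetical interval

for step, (images, image_embeds) in enumerate(dataloader):
    # train only the first unet of the two-unet cascade
    loss = decoder(images, image_embed = image_embeds, unet_number = 1)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if step % sample_every == 0:
        with torch.no_grad():
            # condition on the minibatch's CLIP image embeddings and
            # stop after the first unet, mirroring the setup described above
            samples = decoder.sample(image_embed = image_embeds, stop_at_unet_number = 1)
```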