VSehwag / minimal-diffusion

A minimal yet resourceful implementation of diffusion models (along with pretrained models + synthetic images for nine datasets)
MIT License
241 stars · 38 forks

Image Color Problem #3

Closed · prajwalsingh closed this 1 week ago

prajwalsingh commented 1 year ago

Hi,

Thank you for providing the code. It is helping us a lot.

We trained the model on ImageNet Mini, a custom dataset containing 40 classes. However, the output we are getting has some color issues. We have checked the code many times but have not been able to localize the problem.

I would like to know if you ran into a similar problem or if you have any suggestions on what might be the problem.

[Attached sample image: UNet_cvpr_40-1000_steps-750-sampling_steps-class_condn_True_epoch_480]

VSehwag commented 1 year ago

Hi @prajwalsingh This is a perfectly natural phenomenon in the first few epochs, and it goes away after 5-10 epochs. I have noticed that the model starts out generating highly color-saturated images in the first 4-5 epochs. This happens more on challenging datasets (e.g., ImageNet) where the image distribution is not easy to learn. For example, it doesn't happen on MNIST, where you would see the model learning to generate digits in just 2-3 epochs. The color saturation is also more prevalent with very shallow networks (<1M parameters).
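One way to check whether this early-epoch artifact is actually fading over training is to measure the mean color saturation of generated samples each epoch and watch it drop toward the training-set value. A minimal sketch of such a metric (not part of this repo; the array shape and helper name are assumptions):

```python
import numpy as np

def mean_saturation(images):
    """Mean HSV-style saturation of a batch of RGB images.

    images: float array of shape (N, H, W, 3) with values in [0, 1].
    Per-pixel saturation = (max - min) / max over the RGB channels,
    defined as 0 where max == 0 (pure black pixels).
    """
    cmax = images.max(axis=-1)
    cmin = images.min(axis=-1)
    sat = np.where(cmax > 0, (cmax - cmin) / np.maximum(cmax, 1e-8), 0.0)
    return float(sat.mean())

# Grayscale images have zero saturation; a pure-color image has saturation 1.
gray = np.full((1, 8, 8, 3), 0.5)
red = np.zeros((1, 8, 8, 3))
red[..., 0] = 1.0
print(mean_saturation(gray))  # 0.0
print(mean_saturation(red))   # 1.0
```

Logging this alongside your loss curve would make it easy to confirm the saturation is a transient early-training artifact rather than a persistent bug.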

Overall I don't consider this an issue, but merely an artifact of the training trajectory in diffusion models. The phenomenon only occurs for a few early epochs, and the total number of training epochs for diffusion models tends to be around 500-1000. So by the end of training, the diffusion model learns to generate high-fidelity images.

Another subtle reason for color saturation is an altered optimizer. For example, when doing differentially private training with the DP-SGD optimizer, which clips and adds noise to gradients and thus makes optimization much harder, I've noticed a similar color saturation problem. There, DP-SGD adds high bias and noise to the gradients, making it hard to train the network well, so the generated images often become mono-saturated, i.e., highly saturated in a single color. I would double-check for any issue with optimization, e.g., whether you have changed the hyperparameters, in particular whether the learning rate is set too high or too low.
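A quick sanity check along these lines is to assert that the learning rate sits in the range commonly used for diffusion UNets before launching a long run. A sketch with illustrative values (the config keys and the lr band are assumptions, not this repo's actual defaults):

```python
# Hypothetical training config for illustration (not the repo's real flags).
config = {
    "optimizer": "adam",
    "lr": 2e-4,        # diffusion UNets are commonly trained with lr ~1e-4 to 2e-4
    "grad_clip": None, # aggressive clipping + gradient noise (as in DP-SGD) can
                       # bias optimization and produce mono-saturated samples
    "epochs": 500,
}

def lr_in_sane_range(lr, lo=1e-5, hi=1e-3):
    """Flag learning rates far outside the band typically used for diffusion models."""
    return lo <= lr <= hi

print(lr_in_sane_range(config["lr"]))  # True
print(lr_in_sane_range(0.1))           # False: likely too high for this setup
```

This obviously doesn't prove the optimizer is fine, but it catches the most common accidental change (a misplaced decimal in the learning rate) cheaply.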

prajwalsingh commented 1 year ago

@VSehwag Thank you so much for the detailed explanation and suggestions. I will look into the problem again with this information. I kept the hyperparameters the same and ran it for around 500 epochs.