explainingai-code / StableDiffusion-PyTorch

This repo implements a Stable Diffusion model in PyTorch with all the essential components.

Losses for conditional diffusion models #15

Open vinayak-sharan opened 1 month ago

vinayak-sharan commented 1 month ago

Hey, thanks for the videos and code. I am experimenting with conditional LDMs.

Do you happen to have loss plots or logs from training? I have a feeling that the loss is decreasing very slowly, or not decreasing at all. Could you let me know whether you saw a similar loss curve? Here is a screenshot for reference.

[Screenshot: training loss plot, 2024-05-20]
explainingai-code commented 1 month ago

Hi @vinayak-sharan, could you tell me which dataset you are training on and what kind of conditioning (text/mask/class) you are using?

vinayak-sharan commented 1 month ago

I am training on CelebHQ with both masks and texts as conditions.

explainingai-code commented 1 month ago

I don't have any logs, but for this case (CelebHQ conditioned on masks and texts) you should get decent generation output by 50 epochs. Have you generated samples from the currently trained model? Also, how was the autoencoder output: were you able to train the autoencoder part well enough that the reconstructions look decent?
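One quick way to put a number on "decent reconstructions" is PSNR between input images and autoencoder outputs. This `reconstruction_psnr` helper is hypothetical, not part of the repo; the random tensors stand in for real image batches.

```python
import torch

def reconstruction_psnr(x, x_recon, max_val=1.0):
    """PSNR (dB) between an input batch and its reconstruction."""
    mse = torch.mean((x - x_recon) ** 2)
    return 10 * torch.log10(max_val ** 2 / mse)

torch.manual_seed(0)
x = torch.rand(4, 3, 64, 64)                      # stand-in for an image batch
fake_recon = x + 0.05 * torch.randn_like(x)       # stand-in for VQ-VAE output
psnr = reconstruction_psnr(x, fake_recon)
print(f"PSNR: {psnr:.2f} dB")
```

As a rough rule of thumb, reconstructions above ~25 dB usually look acceptable; if the autoencoder is much below that, the LDM is learning on latents that cannot be decoded well anyway.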

vinayak-sharan commented 1 month ago

Hey Tushar, I trained the LDM for 200 epochs and plotted the loss. The VQ-VAE samples are quite good, but the LDM samples are not what I expected :D

Looking at the loss over the epochs, I noticed that it starts increasing again after about 100 epochs. I am surprised: since this is the training loss, I would expect it to keep decreasing (if anything, to overfit).

[Image: loss plot over 200 epochs (epoch_199)]

VQ-VAE samples: [Image: current_autoencoder_sample_781]

LDM samples: [Image: x0_0]

vinayak-sharan commented 1 month ago

Here are the checkpoints in case you are interested: https://drive.google.com/drive/folders/1N2lRCFKz-fshPs3hzIV7ym_gs9kkYmTT?usp=sharing

explainingai-code commented 1 month ago

I was never able to train for more than 100 epochs (because of compute limitations), but the increase in loss should be reduced by adding learning-rate decay, so maybe try that. More importantly, the overall loss decrease is very small; with mask conditioning alone I was able to get higher-quality outputs with fewer epochs of training (at least for common poses like the first and last images of your sample). Could you train a model with just mask conditioning, and during training double-check that the mask fed by the data loader is actually the correct one by comparing the input image and its mask? Also, if you made any modifications to the code/config, could you share those as well?
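The learning-rate decay suggestion can be sketched with a standard PyTorch scheduler. This is a minimal illustration, not the repo's training script: the `optimizer` wraps a dummy parameter just to make the snippet self-contained, and the `gamma` value is an arbitrary choice you would tune.

```python
import torch

# Dummy parameter so the optimizer/scheduler have something to manage.
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=1e-4)

# Exponential decay: the lr is multiplied by gamma once per epoch.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98)

lrs = []
for epoch in range(200):
    # ... run one training epoch here, calling optimizer.step() per batch ...
    scheduler.step()
    lrs.append(optimizer.param_groups[0]["lr"])

print(f"lr after 200 epochs: {lrs[-1]:.3e}")
```

With `gamma=0.98` the learning rate falls to about 1.8% of its initial value over 200 epochs, which tends to damp the late-training loss drift; `CosineAnnealingLR` is a common alternative.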