dome272 / Diffusion-Models-pytorch

Pytorch implementation of Diffusion Models (https://arxiv.org/pdf/2006.11239.pdf)
Apache License 2.0

the model learns very badly on cifar10 32*32 #23

Open RowenaHe opened 1 year ago

RowenaHe commented 1 year ago

Hello! Firstly, thank you for the awesome video and code which explain so well how the diffusion models are implemented in code.

I ran into some issues while reimplementing your code, and I'm wondering if you could give me some advice on how to make it work. I tried to use your code to generate conditional CIFAR-10 samples (32*32 resolution), but so far the results look rather bad. For training, I changed the UNet input and output size to 32, adjusted the corresponding resolutions in the self-attention blocks, and set the batch size to 256. The number of down and up blocks, as well as the bottleneck layers, was kept the same as in your original setting. After training for 400 epochs, the generated images were almost pure color: [image]

Then I tried adding a warmup learning-rate schedule that ramps up to 10 times the original lr (1e-4) over the first 1/10 of the epochs, followed by cosine annealing for the remaining epochs, and trained for 1800 epochs. But the final results still look the same as before: [image]
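
A schedule like the one described above can be sketched with PyTorch's built-in schedulers; the numbers below are taken from this description rather than from the actual training script, and `model` stands for the UNet being trained:

from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR

epochs = 1800
warmup_epochs = epochs // 10                    # first 1/10 of training
optimizer = AdamW(model.parameters(), lr=1e-3)  # peak lr = 10 * 1e-4
# linear warmup from 1e-4 (0.1 * peak) up to the peak lr
warmup = LinearLR(optimizer, start_factor=0.1, total_iters=warmup_epochs)
cosine = CosineAnnealingLR(optimizer, T_max=epochs - warmup_epochs)
scheduler = SequentialLR(optimizer, schedulers=[warmup, cosine],
                         milestones=[warmup_epochs])
# call scheduler.step() once per epoch after the training loop body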

Do you have any ideas about what's wrong with my reimplementation? I would really appreciate any insights.

RowenaHe commented 1 year ago

Hi, I found a way to make this work. The problem is in the sampling process. During sampling, if I clamp x_t to [-1, 1] at every recurrent timestep, the results look better (at least now they look like photos):

[image]

RowenaHe commented 1 year ago

But there might still be some other problems. The pictures look very hazy, as if covered by a neon tint or something.

potatoker commented 1 year ago

@RowenaHe Same problem here. Did you mean fixing it with something like this?

def noise_images(self, x, t):
    sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
    sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
    Ɛ = torch.randn_like(x)
    # proposed change: clamp the noised image x_t to [-1, 1]
    return (sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * Ɛ).clamp(-1, 1), Ɛ
RowenaHe commented 1 year ago

@potatoker Sorry for taking so long to reply.

No, that's not how I fixed the problem. I made no change to the training procedure; the only modification was in the sampling process. It goes something like this:

x = torch.randn(b, c, h, w)
for t in tqdm(reversed(range(opt.timesteps))):
    predicted_noise = model(x, t)
    x = ...  # the usual DDPM update computed from x and predicted_noise
    x = x.clamp(-1, 1)   # this line is my only modification
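
For context, this is roughly what the loop above expands to when the standard DDPM sampling update (Algorithm 2 in the DDPM paper) is written out; the buffer names alpha, alpha_hat, beta and noise_steps are assumptions about the surrounding code rather than an exact copy of the repo's sampler, and the only addition relative to the usual loop is the per-step clamp:

import torch
from tqdm import tqdm

@torch.no_grad()
def sample(model, n, img_size, alpha, alpha_hat, beta, noise_steps, device="cuda"):
    model.eval()
    x = torch.randn((n, 3, img_size, img_size), device=device)
    for i in tqdm(reversed(range(1, noise_steps))):
        t = (torch.ones(n, device=device) * i).long()
        predicted_noise = model(x, t)
        a = alpha[t][:, None, None, None]
        a_hat = alpha_hat[t][:, None, None, None]
        b = beta[t][:, None, None, None]
        noise = torch.randn_like(x) if i > 1 else torch.zeros_like(x)
        # standard DDPM posterior step
        x = 1 / torch.sqrt(a) * (x - ((1 - a) / torch.sqrt(1 - a_hat)) * predicted_noise) \
            + torch.sqrt(b) * noise
        x = x.clamp(-1, 1)  # the extra per-step clamp
    model.train()
    return (x + 1) / 2  # map from [-1, 1] back to [0, 1] for visualization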
RowenaHe commented 1 year ago

I output the intermediate data to look for the reason, and I think the cause is that, for some unknown reason, the trained model cannot keep the generated data around/within [-1, 1] during the recurrent sampling process.

Without clamping in the recurrent sampling process, after 1000 inference steps almost every single sample is out of the [-1, 1] range, so simply clamping once after the whole sampling process is over inevitably gives a picture of pure color.
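
One quick way to see this drift is to log the range of x inside the sampling loop, right after each update (a throwaway diagnostic, not part of the repo):

# inside the sampling loop, right after x is updated:
if i % 100 == 0:
    print(f"step {i}: x.min()={x.min().item():.2f}, x.max()={x.max().item():.2f}")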

volcanik commented 1 year ago

I am experiencing the same type of issue with 64x64 landscapes. Half of the images are beautiful, but the other half are solid colors, mostly black and white. I believe this runaway effect, where the model can't keep the values between [-1, 1], is actually simply due to the choice of hyperparameters, and the model may also be overfitting. Maybe increase the weight decay and decrease the learning rate? The hyperparameters that worked for the 64x64 model may not necessarily work for your 32x32 model.

RowenaHe commented 1 year ago

@volcanik Thank you for your insight! I will let you know if I manage to solve this problem in my upcoming experiments.

RowenaHe commented 1 year ago

Hi! I found that using the more complex, 'residual' sampling procedure (the one OpenAI uses in the diffusers samplers), which is supposed to give the same results as the simple one-line sampling, in fact gives better results! @potatoker @volcanik @dome272

One-line sampling result: [image]

OpenAI DDPM 'residual' sampler result (from the same model): [image]

volcanik commented 1 year ago

@RowenaHe That is very interesting! I have also done a few small experiments on CIFAR10 64x64. I only trained for around 50 epochs, but I found that increasing the weight decay of AdamW to 0.1 and decreasing the learning rate to 1e-4 completely removed the solid-color images with the regular sampling algorithm. I think this runaway effect is most likely caused by excessively large weights. As for the learning rate, I'm not really sure whether it has made any difference. Could you tell me where exactly you got this new sampling procedure? BTW, your results look really good now!
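
For concreteness, the optimizer change described above amounts to something like this (values taken from the comment; `model` stands for the UNet being trained):

from torch.optim import AdamW

# weight decay raised to 0.1, learning rate lowered to 1e-4
optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=0.1)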

RowenaHe commented 1 year ago

@volcanik There is a public diffusers library with several plug-in schedulers available, including DDPMScheduler, DDIMScheduler, etc. You can simply run "pip install diffusers" to get it. To use a scheduler, you only need "from diffusers import DDPMScheduler" and "scheduler = DDPMScheduler()", and then replace the sampling with another loop like this: [image]
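
The code in that screenshot isn't reproduced here, but a diffusers-based loop of that kind typically looks roughly like the sketch below; model, n, img_size and device are placeholders for whatever the surrounding script defines, and for the conditional model discussed further down you would additionally pass the class labels y to the model call:

import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(1000)

x = torch.randn((n, 3, img_size, img_size), device=device)
for t in scheduler.timesteps:
    t_batch = torch.full((n,), int(t), device=device, dtype=torch.long)
    with torch.no_grad():
        predicted_noise = model(x, t_batch)
    # step() computes x_{t-1}; with the default clip_sample=True it also
    # clamps the predicted x_0 to [-1, 1] internally
    x = scheduler.step(predicted_noise, t, x).prev_sample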

There are some great Jupyter notebooks demonstrating how to use diffusers; I'll put the link here for you to check out: https://github.com/huggingface/diffusion-models-class

thomasnicolet commented 1 year ago

@RowenaHe Thank you for your suggestions. We are also trying to get this DDPM version to work on our own dataset, and we are getting very bright colors. We will try clamping with x = x.clamp(-1, 1) at the end of each timestep (i.e. applying it to x_{t-1}, which we obtained from x_t).

The scheduler you mention works on a conditional model. Is that the DDPM_conditional from this GitHub repo? I don't see where y comes from, for example; is it the conditional text_embedding? It looks like a global variable, or am I mistaken?

Dart120 commented 8 months ago

You guys are so awesome! Was having all of the above issues and you helped me solve them ❤️ @RowenaHe @thomasnicolet

Dart120 commented 8 months ago

> @RowenaHe Thank you for your suggestions. We are also trying to get this DDPM version to work on our own dataset, and we are getting very bright colors. We will try clamping with x = x.clamp(-1, 1) at the end of each timestep (i.e. applying it to x_{t-1}, which we obtained from x_t).
>
> The scheduler you mention works on a conditional model. Is that the DDPM_conditional from this GitHub repo? I don't see where y comes from, for example; is it the conditional text_embedding? It looks like a global variable, or am I mistaken?

I think y is the conditional label.

javiersgjavi commented 7 months ago

@Dart120, which solution was the final one for you? I have already tried the AdamW weight decay and learning rate suggested by @volcanik. Still, while I have seen an improvement in the general quality of the generated samples, I keep obtaining some images that are entirely one colour.

Dart120 commented 7 months ago

> @Dart120, which solution was the final one for you? I have already tried the AdamW weight decay and learning rate suggested by @volcanik. Still, while I have seen an improvement in the general quality of the generated samples, I keep obtaining some images that are entirely one colour.

Hiya! What worked for me was clamping the output to [-1, 1] at each timestep when using the cosine schedule. This produces the greyish images that were posted earlier in this thread. Then, to fix the greyish images, I used code from the Hugging Face diffusers package, which is also a technique posted about in this thread.

I would post my code, but it's part of coursework that's currently being marked, so I can't for obvious reasons. Doing the above two things really helped, though 😁

javiersgjavi commented 7 months ago

Thank you very much for your answer, @Dart120; I'll try to apply this information to check if I can improve my results. Good luck with your coursework!

javiersgjavi commented 6 months ago

I have finally implemented what you said, @Dart120, and it works as expected. Thank you again!

For anyone reaching this thread in the future: the scheduler from the diffusers library seems to avoid the problem of values falling outside the [-1, 1] range by handling it in a more sophisticated way when you call the step() function.

https://github.com/huggingface/diffusers/blob/d2fc5ebb958ad8c967752dba5b23c43563cb3159/src/diffusers/schedulers/scheduling_ddpm.py#L462-L468
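
Concretely, the linked lines clamp the predicted original sample x_0 (rather than the noisy x_t itself) whenever the scheduler's clip_sample option is enabled, which it is by default:

from diffusers import DDPMScheduler

# with clip_sample=True (the default), step() clamps the predicted
# original sample x_0 before computing x_{t-1}
scheduler = DDPMScheduler(num_train_timesteps=1000, clip_sample=True)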