lucidrains / denoising-diffusion-pytorch

Implementation of Denoising Diffusion Probabilistic Model in Pytorch

"pred_x0" preforms better than "pred_noise" #58

Open TheSunWillRise opened 2 years ago

TheSunWillRise commented 2 years ago

Hi, guys. I have conducted experiments on MNIST and FashionMNIST. In my results, training with "objective == pred_x0" achieves better results than "objective == pred_noise" on both datasets, which is contrary to the findings in the original paper [1] (they found that predicting the noise performs best). Has anyone encountered a similar situation? Is there any paper discussing the difference between the objectives?

[1] Ho, J., Jain, A., & Abbeel, P. (2020). Denoising Diffusion Probabilistic Models. ArXiv, abs/2006.11239.
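
For context, the objective being compared here is just a constructor argument of GaussianDiffusion in this repo. Below is a minimal sketch of switching between the two, modeled on README-style usage (argument names may differ slightly across versions of the library):

```python
import torch
from denoising_diffusion_pytorch import Unet, GaussianDiffusion

model = Unet(dim = 64, dim_mults = (1, 2, 4, 8))   # pass channels = 1 for single-channel data like MNIST

diffusion = GaussianDiffusion(
    model,
    image_size = 32,          # e.g. MNIST / FashionMNIST padded or resized to 32x32
    timesteps = 1000,
    objective = 'pred_x0'     # versus the 'pred_noise' objective
)

training_images = torch.rand(8, 3, 32, 32)  # dummy 3-channel batch in [0, 1]
loss = diffusion(training_images)
loss.backward()
```

The training loop is identical in both cases; only the regression target the U-Net is trained against changes.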

lucidrains commented 2 years ago

@TheSunWillRise this is an interesting finding

are you able to share your experiment results, through Weights & Biases or other means? how are you measuring 'better'?

lucidrains commented 2 years ago

@TheSunWillRise DALLE2 ended up using the predict x0 objective for the diffusion prior, but that is working in CLIP latent embedding space

malekinho8 commented 2 years ago

Hey @TheSunWillRise, I have also been having trouble training diffusion models with the pred_noise objective, and I wasn't sure whether it was something on my end or an issue with the code setup. If you could share some of the metrics or figures that led you to this conclusion, that would be greatly appreciated.

Personally, I have noticed that for a given noise prediction from a model trained on an image dataset scaled to [-1, 1], the output of predict_start_from_noise() falls far outside that range. I am not sure whether this is expected, but I imagine that training a model to predict x_0 directly would not produce such an extreme range:

[image: histograms of the predicted x_0 (orange) and the noise x_T at timestep 999]

Note that this output is for timestep 999 (i.e. the last step in the sequence); the orange distribution has a much larger variance than the noise at step 999 (x_T). This is just one result, but I thought I would include it to show that something may be slightly off with the pred_noise objective. I have not trained a model using pred_x0, but I might do so to see whether I get a similar result to @TheSunWillRise.
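
For readers following along, the conversion discussed above is the standard DDPM identity x_0 = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t). A small self-contained sketch (assuming a linear beta schedule; the repo's predict_start_from_noise computes the same identity via precomputed buffers) shows why errors in the predicted noise are strongly amplified at late timesteps:

```python
import torch

# assumed linear beta schedule; the repo also offers other schedules
timesteps = 1000
betas = torch.linspace(1e-4, 0.02, timesteps)
alphas_cumprod = torch.cumprod(1.0 - betas, dim = 0)

def predict_start_from_noise(x_t, t, noise):
    # x_0 = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t)
    a = alphas_cumprod[t]
    return (x_t - (1.0 - a).sqrt() * noise) / a.sqrt()

# at t = 999 the factor 1 / sqrt(alpha_bar_t) is roughly 150 with this schedule,
# so a small error in the predicted noise becomes a large error in x_0,
# consistent with the very wide orange distribution in the histogram above
print((1.0 / alphas_cumprod[-1]).sqrt())
```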

malekinho8 commented 2 years ago

Just as an update, I ran a test to see what the data range of the x_0 prediction would be if I changed the objective from pred_noise to pred_x0. For this test, I used the MNIST handwritten digits dataset, scaled to [-1, 1]:

[image: histogram of the predicted x_0 under the pred_x0 objective, concentrated near -1]

Now the distribution is mostly -1 (as expected), given that the digit itself typically occupies only a small region of the input space. So there is a significant difference between using pred_x0 and pred_noise as the objective, and I would argue that the output from the pred_x0 objective is better for training a diffusion model as the code currently stands. Any insight or updates on whether this is an expected result would be greatly appreciated @lucidrains.

guozhiyao commented 2 years ago

I have the same problem. Have you solved it?

malekinho8 commented 2 years ago

@guozhiyao, unfortunately not. You have to implement clamping when using the pred_noise objective, though no clamping is necessary with the pred_x0 objective. Overall, I still found that with clamping, the pred_noise objective gave visually higher-quality results.
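
For anyone wondering what "clamping" refers to here: the usual fix is to clamp the recovered x_0 back into the data range during sampling, before it is used to form the posterior mean for x_{t-1}. A minimal sketch of that step (an assumption about the wiring; the repo exposes a similar switch, something like clip_denoised, on its sampling path):

```python
import torch

def clamped_x_start(x_t, t, pred_noise, alphas_cumprod):
    # recover x_0 from the predicted noise, then clamp it back into the
    # [-1, 1] range the training images were scaled to
    a = alphas_cumprod[t]
    x_start = (x_t - (1.0 - a).sqrt() * pred_noise) / a.sqrt()
    return x_start.clamp(-1.0, 1.0)
```

With the pred_x0 objective the network output is already regressed against values in [-1, 1], which matches the observation above that its predictions stay near the data range without clamping.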

return-sleep commented 9 months ago

@TheSunWillRise Can you please share the training logs on the MNIST dataset, e.g. the loss at convergence? I've been trying this experiment recently, but I found that the prediction loss oscillates around 0.01, and I'm not sure whether this is reasonable. Thanks for your help.