dome272 / Diffusion-Models-pytorch

Pytorch implementation of Diffusion Models (https://arxiv.org/pdf/2006.11239.pdf)
Apache License 2.0
1.11k stars 256 forks

I don't understand the function def noise_images(self, x, t) #24

Open colinsctsang opened 1 year ago

colinsctsang commented 1 year ago

Assume that we have img_0, img_1, ..., img_T, which are obtained by adding noise iteratively. I understand that img_t is given by the formula "sqrt_alpha_hat * img_0 + sqrt_one_minus_alpha_hat * Ɛ".

However, I don't understand the function "def noise_images(self, x, t)" in [ddpm.py].

It returns Ɛ, where Ɛ = torch.randn_like(x). So this is just a noise signal drawn directly from the normal distribution. I suppose this random noise is not related to the input image? This is because randn_like() returns a tensor with the same size as the input x, filled with random numbers from a normal distribution with mean 0 and variance 1.

In training, the predicted noise is compared to this Ɛ (line 80 in [ddpm.py]).

Why are we predicting this random noise? Shouldn't we predict the noise added at time t, i.e. "img_t - img_{t-1}"?
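For concreteness, here is a minimal sketch of what a `noise_images`-style function does, based on the formula quoted above (the function signature and the precomputed `alpha_hat` buffer are my assumptions, not necessarily the repo's exact code):

```python
import torch

def noise_images(x, t, alpha_hat):
    """Closed-form forward process: jump from x_0 straight to x_t.
    x: batch of clean images (B, C, H, W)
    t: batch of integer timesteps (B,)
    alpha_hat: cumulative product of (1 - beta) over all timesteps."""
    sqrt_alpha_hat = torch.sqrt(alpha_hat[t])[:, None, None, None]
    sqrt_one_minus_alpha_hat = torch.sqrt(1 - alpha_hat[t])[:, None, None, None]
    eps = torch.randn_like(x)  # fresh Gaussian noise, N(0, I), same shape as x
    x_t = sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * eps
    return x_t, eps            # eps is returned so it can serve as the training target
```

Note that `eps` is indeed independent of the image; it only becomes the "noise in the image" through the mixing formula, which is why it is a valid prediction target.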

pushkarjajoria commented 1 year ago

Hey @colinsctsang, the answer to your question is in the math video "Diffusion Models | Paper Explanation | Math Explained" by the same author. I'll try to explain but this is not my area of expertise and I will skip many steps. Again, watch the math video if you want a comprehensive tutorial on this.

We are interested in the forward process

q(x_t | x_{t-1}) = N(x_t; √(alpha_t) x_{t-1}, beta_t I), where alpha_t = 1 - beta_t

We do not have x_{t-1} but only x_0 (i.e. the initial image). We can keep expanding x_{t-1} into x_{t-2}, ..., x_0, and we land on the equation

q(x_t | x_0) = N(x_t; √(alpha_hat_t) x_0, (1 - alpha_hat_t) I)

Using the "reparameterization trick", we can sample from this distribution, where ε is drawn from a unit normal distribution:

x_t = √(alpha_hat_t) x_0 + √(1 - alpha_hat_t) ε

This is a sample from the distribution of images after t timesteps of adding noise.
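As a quick numerical sanity check (my own sketch, not from the repo), iterating the one-step forward process accumulates exactly the variance 1 - alpha_hat_t that the closed form promises:

```python
import torch

betas = torch.linspace(1e-4, 0.02, 1000)  # a common linear beta schedule
alphas = 1.0 - betas
alpha_hat = torch.cumprod(alphas, dim=0)

# One noising step x_t = sqrt(alpha_t) x_{t-1} + sqrt(beta_t) eps_t scales the
# old variance by alpha_t and adds beta_t of fresh noise variance.
var = torch.tensor(0.0)  # variance of x_t around its mean, starting from x_0
for alpha_t, beta_t in zip(alphas, betas):
    var = alpha_t * var + beta_t

print(torch.isclose(var, 1 - alpha_hat[-1]))  # tensor(True)
```

So sampling x_t in one shot with the closed form is statistically identical to running the iterative process for t steps.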

Now we can train a network to learn the reverse/denoising process: given x_t and t, it predicts the noise in the image at timestep t. Sampling still proceeds iteratively, but the math above lets us construct training examples (x_t, ε) from x_0 in a single step. Constructing them iteratively should also work, but it would be very slow.

The final objective, once the math works out, reduces to computing the MSE between this epsilon and the predicted epsilon (i.e. the model output).
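A single training step therefore looks roughly like the following sketch (the stand-in `Conv2d` model, shapes, and schedule are my assumptions; the real repo uses a UNet that also takes t as input):

```python
import torch
import torch.nn as nn

model = nn.Conv2d(3, 3, 3, padding=1)  # placeholder for the UNet
opt = torch.optim.Adam(model.parameters())
mse = nn.MSELoss()

betas = torch.linspace(1e-4, 0.02, 1000)
alpha_hat = torch.cumprod(1 - betas, dim=0)

x0 = torch.randn(4, 3, 8, 8)            # a batch of "images"
t = torch.randint(0, 1000, (4,))        # random timestep per image
eps = torch.randn_like(x0)              # the noise that becomes the target
x_t = (torch.sqrt(alpha_hat[t])[:, None, None, None] * x0
       + torch.sqrt(1 - alpha_hat[t])[:, None, None, None] * eps)

pred = model(x_t)                       # the real UNet is called as model(x_t, t)
loss = mse(pred, eps)                   # MSE between predicted and actual noise
loss.backward()
opt.step()
```

The key point is that the target of the regression is the very `eps` that was just sampled and mixed into `x_t`.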

colinsctsang commented 1 year ago

Thank you very much for your answer. I appreciate it!

I am sorry that I am still confused about this:

In the function "def sample", we can see that the variable "t" is looped backward step by step (i.e., T -> T-1 -> ... -> 1). And this "t" is used as an input to the trained U-Net. Hence, the U-Net is used to recover the image iteratively. In other words, the U-Net is an approximation for the noise in p(x_{t-1} | x_t, t).

However, according to your explanation of the training stage, the U-Net is trained to approximate the noise in p(x_0 | x_t, t). So I suppose the U-Net could go directly to the initial state (i.e., the final output). Why do we need to do it iteratively in the "def sample" function when we are trying to generate an output?

To conclude, I am trying to understand why we use the U-Net to approximate p(x_0 | x_t, t) in the training phase, but it is used as p(x_{t-1} | x_t, t) in the testing phase.

Once again, thank you for answering me!

pushkarjajoria commented 1 year ago

We have trained a model which predicts epsilon given the noised image and the timestep. "And this "t" is used as an input to the trained U-Net... for the noise in p(x_{t-1} | x_t, t)." Correct. The math notation may be incorrect but I am no expert.

"So, I suppose U-Net can go directly to the initial state (i.e., the final output)." Incorrect.

The model, at both the training stage and the sampling stage, is used to predict epsilon given the image x_t and the timestep t. Epsilon itself is independent of t (it is sampled from a normal distribution), but the noised image does depend on t. At training time, we directly compute the MSE between the predicted epsilon and the actual epsilon. At sampling time, we use x_t, the predicted epsilon, and the betas to compute x_{t-1} with the formula stated below.

x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise

That's my understanding anyways. The 'whys' you'll have to understand from the video.
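Put together, the sampling loop looks roughly like this sketch (my own assumption of the structure: `model(x, t)` is taken to predict epsilon, and the real ddpm.py will differ in details such as clamping and schedules):

```python
import torch

@torch.no_grad()
def sample(model, shape, betas):
    """Hypothetical DDPM sampling loop using the update quoted above."""
    alphas = 1.0 - betas
    alpha_hat = torch.cumprod(alphas, dim=0)
    x = torch.randn(shape)  # start from pure noise x_T
    for i in reversed(range(len(betas))):
        t = torch.full((shape[0],), i, dtype=torch.long)
        predicted_noise = model(x, t)
        alpha, beta, a_hat = alphas[i], betas[i], alpha_hat[i]
        # No fresh noise is added on the very last step (t = 0).
        noise = torch.randn_like(x) if i > 0 else torch.zeros_like(x)
        x = (1 / torch.sqrt(alpha)
             * (x - ((1 - alpha) / torch.sqrt(1 - a_hat)) * predicted_noise)
             + torch.sqrt(beta) * noise)
    return x
```

This makes the asymmetry concrete: the network only ever predicts epsilon, and the loop around it is what turns that prediction into one small step from x_t to x_{t-1}.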

colinsctsang commented 1 year ago

Thank you very much for your time. I think I understand now.