lucidrains / denoising-diffusion-pytorch

Implementation of Denoising Diffusion Probabilistic Model in Pytorch
MIT License

Technical question about sampling function #53

Open dome272 opened 2 years ago

dome272 commented 2 years ago

Hi,

I was wondering why every diffusion model implementation uses this particular sampling procedure. When I look at the DDPM paper, the sampling algorithm is shown as:

[Image: sampling algorithm from the DDPM paper]
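For reference, this is Algorithm 2 (Sampling) of the DDPM paper (Ho et al., 2020) in the paper's notation. Here α_t = 1 − β_t, ᾱ_t is the cumulative product of α_1, …, α_t, ε_θ is the noise-prediction network, and σ_t is the per-step noise scale:

$$
\begin{aligned}
&\mathbf{x}_T \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) \\
&\textbf{for } t = T, \dots, 1: \\
&\quad \mathbf{z} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) \text{ if } t > 1, \text{ else } \mathbf{z} = \mathbf{0} \\
&\quad \mathbf{x}_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t - \frac{1-\alpha_t}{\sqrt{1-\bar\alpha_t}}\,\boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)\right) + \sigma_t \mathbf{z} \\
&\textbf{return } \mathbf{x}_0
\end{aligned}
$$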

However, no implementation seems to follow it. Instead, they take a rather complicated route: first predict the noise, then compute x_0 from it, then the mean and log-variance, and finally construct x_{t-1} from those.
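A minimal, free-standing sketch of that route for a single reverse step, not the repository's own code. The function name `reverse_step_via_x0` and its arguments are hypothetical. The formulas are the DDPM Gaussian posterior q(x_{t-1} | x_t, x_0), with `t` as a 0-indexed Python int. Clipping the x_0 estimate to [-1, 1] is included as an option because it is one thing the detour through x_0 makes possible.

```python
import torch


def reverse_step_via_x0(x_t, eps_pred, t, betas, clip_x0=True, noise=None):
    """One reverse step: predicted noise -> x_0 -> posterior mean/log-variance -> x_{t-1}.

    t is a 0-indexed Python int (0 is the final denoising step); betas is a 1-D tensor.
    """
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    ab_t = alphas_cumprod[t]
    ab_prev = alphas_cumprod[t - 1] if t > 0 else torch.ones((), dtype=betas.dtype)

    # 1) Invert q(x_t | x_0) to estimate x_0 from x_t and the predicted noise.
    x0_pred = (x_t - torch.sqrt(1.0 - ab_t) * eps_pred) / torch.sqrt(ab_t)
    if clip_x0:
        x0_pred = x0_pred.clamp(-1.0, 1.0)  # keep the x_0 estimate in the data range

    # 2) Mean and log-variance of the Gaussian posterior q(x_{t-1} | x_t, x_0).
    coef_x0 = torch.sqrt(ab_prev) * betas[t] / (1.0 - ab_t)
    coef_xt = torch.sqrt(alphas[t]) * (1.0 - ab_prev) / (1.0 - ab_t)
    mean = coef_x0 * x0_pred + coef_xt * x_t
    var = betas[t] * (1.0 - ab_prev) / (1.0 - ab_t)
    log_var = torch.log(var.clamp(min=1e-20))  # variance is exactly 0 at t == 0

    # 3) Sample x_{t-1}; the final step returns the mean without noise.
    if t == 0:
        return mean
    if noise is None:
        noise = torch.randn_like(x_t)
    return mean + torch.exp(0.5 * log_var) * noise
```

Calling this for t = T-1, ..., 0 with the network's noise prediction at each step gives the full sampling loop.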

I implemented the above algorithm on top of your codebase:

```python
@torch.no_grad()
def my_sample(self, n):
    x = torch.randn((n, 3, self.image_size, self.image_size)).to(self.device)
    for i in tqdm(reversed(range(1, self.num_timesteps)), position=0):
        t = (torch.ones(n) * i).long().to(self.device)
        predicted_noise = self.denoise_fn(x, t)
        beta = extract(self.betas, t, x.shape)
        alpha_hat = extract(self.alphas_cumprod, t, x.shape)
        alpha = 1. - beta
        if i > 1:
            noise = torch.randn_like(x)
        else:
            noise = torch.zeros_like(x)
        x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise
        x = x.clamp(-1., 1.)
    return x.add(1).mul(0.5)
```
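For comparison, here is a self-contained sketch of Algorithm 2 written as a free function. The names `sample_algorithm2` and `eps_model` are hypothetical, and the linear β schedule in the test is an assumption. It differs from `my_sample` in two places. First, the loop visits every 0-indexed timestep down to 0, whereas `range(1, self.num_timesteps)` never uses index 0 of `self.betas` / `self.alphas_cumprod`. Second, only the final sample is clamped: x_T is standard normal, so clamping every intermediate x_t to [-1, 1] truncates a distribution that Algorithm 2 never clamps. Whether either difference explains the gray images is not settled in this thread.

```python
import torch


@torch.no_grad()
def sample_algorithm2(eps_model, shape, betas, generator=None):
    """Algorithm 2 of the DDPM paper, with 0-indexed timesteps T-1, ..., 0.

    eps_model(x, t) is a hypothetical noise predictor that takes a batch x and
    a LongTensor t of timesteps. The added noise uses sigma_t = sqrt(beta_t).
    """
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    x = torch.randn(shape, generator=generator, dtype=betas.dtype)  # x_T ~ N(0, I)
    for i in reversed(range(len(betas))):  # visits every step, including i == 0
        t = torch.full((shape[0],), i, dtype=torch.long)
        eps = eps_model(x, t)
        if i > 0:
            z = torch.randn(shape, generator=generator, dtype=betas.dtype)
        else:
            z = torch.zeros(shape, dtype=betas.dtype)  # no noise on the last step
        coef = betas[i] / torch.sqrt(1.0 - alphas_cumprod[i])
        x = (x - coef * eps) / torch.sqrt(alphas[i]) + torch.sqrt(betas[i]) * z
        # x is deliberately NOT clamped here; intermediate x_t exceed [-1, 1].
    # Clamp only the final sample, then map [-1, 1] to [0, 1] as my_sample does.
    return (x.clamp(-1.0, 1.0) + 1.0) / 2.0
```

With a trained model, `eps_model` would be the noise-prediction network (`self.denoise_fn` in the snippet above) and `betas` its schedule.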

But the results are just gray images with a bit of shape and colour (top: the normal sampling from your code; bottom: the sampling function above):

[Image: samples from the default sampling (top) and from my_sample (bottom)]

Do you have any idea why this kind of sampling does not work?

askerlee commented 2 years ago

This is weird. The route (predicted noise -> x_0 -> x_{t-1}) uses Eq. 9, and your implementation uses Eq. 10. I've verified that they are mathematically equivalent.
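The equivalence can be checked by substituting the x_0 estimate into the posterior mean. The notation is the DDPM paper's: α_t = 1 − β_t and ᾱ_t is the cumulative product of α_1, …, α_t.

$$
\begin{aligned}
\hat{\mathbf{x}}_0 &= \frac{1}{\sqrt{\bar\alpha_t}}\left(\mathbf{x}_t - \sqrt{1-\bar\alpha_t}\,\boldsymbol{\epsilon}_\theta\right), \\
\tilde{\boldsymbol{\mu}}_t(\mathbf{x}_t, \hat{\mathbf{x}}_0) &= \frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1-\bar\alpha_t}\,\hat{\mathbf{x}}_0 + \frac{\sqrt{\alpha_t}\,(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}\,\mathbf{x}_t \\
&= \frac{\beta_t + \alpha_t(1-\bar\alpha_{t-1})}{\sqrt{\alpha_t}\,(1-\bar\alpha_t)}\,\mathbf{x}_t - \frac{\beta_t}{\sqrt{\alpha_t}\,\sqrt{1-\bar\alpha_t}}\,\boldsymbol{\epsilon}_\theta \\
&= \frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,\boldsymbol{\epsilon}_\theta\right)
\end{aligned}
$$

The steps use √ᾱ_{t−1} / √ᾱ_t = 1 / √α_t and β_t + α_t(1 − ᾱ_{t−1}) = 1 − ᾱ_t. Since β_t = 1 − α_t, the last line is the mean in the sampling algorithm from the paper.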

I'd suggest checking the first few iterations (largest t) to see whether the two routines produce very similar numbers.
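One way to run that check without a trained model is sketched below. `posterior_means` is a hypothetical helper, random tensors stand in for x_t and the predicted noise, and the linear β schedule (1e-4 to 0.02 over 1000 steps) is an assumption. The sketch computes the mean of x_{t-1} both ways for the largest 0-indexed t. Only the means are compared, because the two routes may also use a different σ_t (β_t versus the posterior variance), and the DDPM paper reports that both choices give similar results. The `clip_x0` flag shows that clipping the x_0 estimate, a common option in the x_0 route, makes the two routes disagree.

```python
import torch


def posterior_means(x_t, eps, t, betas, clip_x0=False):
    """Return (mean via x_0, mean via the direct formula) for 0-indexed step t."""
    alphas = 1.0 - betas
    ab = torch.cumprod(alphas, dim=0)
    ab_prev = ab[t - 1] if t > 0 else torch.ones((), dtype=betas.dtype)

    # Route 1: predicted noise -> x_0 estimate -> posterior mean.
    x0 = (x_t - torch.sqrt(1.0 - ab[t]) * eps) / torch.sqrt(ab[t])
    if clip_x0:
        x0 = x0.clamp(-1.0, 1.0)
    mean_x0 = (torch.sqrt(ab_prev) * betas[t] / (1.0 - ab[t])) * x0 \
        + (torch.sqrt(alphas[t]) * (1.0 - ab_prev) / (1.0 - ab[t])) * x_t

    # Route 2: the direct mean from the sampling algorithm.
    mean_direct = (x_t - betas[t] / torch.sqrt(1.0 - ab[t]) * eps) / torch.sqrt(alphas[t])
    return mean_x0, mean_direct
```

In float64 the two means agree within `torch.allclose` tolerances at t = 999, 998 and 997, and stop agreeing once x_0 is clipped.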

malekinho8 commented 2 years ago

Hey @dome272, I'm not sure why the code here doesn't work with the equation you referenced from the paper, and I haven't had time to look into it in depth. However, someone else made a really nice Google Colab notebook that implements DDPM step by step, similar to this codebase, and it does use the equation from the algorithm you referenced above. I have tested their code myself and it gives good-looking outputs, so this could indicate that some detail of this repository's implementation of the DDPM paper is incorrect. The Colab notebook is linked below; feel free to try it and experiment with it yourself.

https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/annotated_diffusion.ipynb