dome272 / Diffusion-Models-pytorch

Pytorch implementation of Diffusion Models (https://arxiv.org/pdf/2006.11239.pdf)
Apache License 2.0

Confused about the 'Sample Function' #42

Open ankan8145 opened 5 months ago

ankan8145 commented 5 months ago
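def sample(self, model, n):
        logging.info(f"Sampling {n} new images....")
        model.eval()
        with torch.no_grad():
            # Start from pure Gaussian noise: n images, 3 channels, img_size x img_size.
            x = torch.randn((n, 3, self.img_size, self.img_size)).to(self.device)
            # Walk the time steps backwards: noise_steps - 1, ..., 1.
            for i in tqdm(reversed(range(1, self.noise_steps)), position=0):
                t = (torch.ones(n) * i).long().to(self.device)
                # The model estimates the noise contained in x at step t.
                predicted_noise = model(x, t)
                # Schedule values for step t, reshaped to broadcast over (n, 3, H, W).
                alpha = self.alpha[t][:, None, None, None]
                alpha_hat = self.alpha_hat[t][:, None, None, None]
                beta = self.beta[t][:, None, None, None]
                # Add fresh noise on every step except the last one.
                if i > 1:
                    noise = torch.randn_like(x)
                else:
                    noise = torch.zeros_like(x)
                # Remove part of the predicted noise, then add the scaled fresh noise.
                x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise
        model.train()
        # Map from [-1, 1] to [0, 255] and convert to uint8 pixel values.
        x = (x.clamp(-1, 1) + 1) / 2
        x = (x * 255).type(torch.uint8)
        return x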

Can anyone explain this function? In the line 'x = torch.randn((n, 3, self.img_size, self.img_size)).to(self.device)', you create a random image, and then you predict the noise from that image (i.e., predicted_noise = model(x, t)). Are you trying to create an image from a random tensor?

randomaccess2023 commented 4 months ago

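@ankan8145 Yes. x = torch.randn((n, 3, self.img_size, self.img_size)).to(self.device) is a batch of pure Gaussian noise with n images (probably n=12), 3 channels, and self.img_size=64 (probably). After the DDPM has been trained, sampling starts from this pure noise and walks the time steps in reverse: the loop runs i from self.noise_steps - 1 down to 1 (i.e., from 999 down to 1 if noise_steps is 1000). At each step the model predicts the noise contained in the current x, and the update removes part of that noise and, on every step except the last, adds a small amount of fresh noise.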
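
For reference, the update on the last line of the loop in sample appears to be the sampling step from Algorithm 2 of the DDPM paper linked in the repository description, with the variance choice \sigma_t^2 = \beta_t. The symbol mapping here is my reading of the code, not something the author has confirmed: alpha is \alpha_t, alpha_hat is \bar{\alpha}_t, beta is \beta_t, predicted_noise is \epsilon_\theta(x_t, t), and noise is z.

```latex
x_{t-1} = \frac{1}{\sqrt{\alpha_t}}
          \left( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}}\,
                 \epsilon_\theta(x_t, t) \right)
          + \sqrt{\beta_t}\, z,
\qquad
z \sim \mathcal{N}(0, I) \ \text{if } t > 1, \quad z = 0 \ \text{otherwise}
```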

The outcome is realistic-looking images obtained from pure noise only.
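
To see why starting from pure noise can still end at something data-like, here is a minimal scalar sketch of the same loop using only the Python standard library. It is not the repository's code. The linear schedule values (1e-4 to 0.02 over 1000 steps) are assumed defaults, X0 is a hypothetical one-point "dataset", and oracle is a hypothetical stand-in for a trained model: when every training sample equals X0, the noise in x at step i can be computed exactly.

```python
import math
import random


def linear_schedule(noise_steps=1000, beta_start=1e-4, beta_end=0.02):
    """Return beta, alpha = 1 - beta, and alpha_hat = running product of alpha."""
    step = (beta_end - beta_start) / (noise_steps - 1)
    beta = [beta_start + k * step for k in range(noise_steps)]
    alpha = [1.0 - b for b in beta]
    alpha_hat, running = [], 1.0
    for a in alpha:
        running *= a
        alpha_hat.append(running)
    return beta, alpha, alpha_hat


def sample_scalar(predict_noise, noise_steps=1000, rng=None):
    """Scalar version of the reverse loop: start from noise, denoise step by step."""
    rng = rng or random.Random()
    beta, alpha, alpha_hat = linear_schedule(noise_steps)
    x = rng.gauss(0.0, 1.0)  # pure Gaussian noise, the scalar analogue of torch.randn
    for i in reversed(range(1, noise_steps)):  # i = noise_steps - 1, ..., 1
        predicted_noise = predict_noise(x, i, alpha_hat)
        noise = rng.gauss(0.0, 1.0) if i > 1 else 0.0  # no fresh noise on the last step
        x = (x - (1 - alpha[i]) / math.sqrt(1 - alpha_hat[i]) * predicted_noise) / math.sqrt(alpha[i])
        x += math.sqrt(beta[i]) * noise
    return x


X0 = 0.5  # hypothetical "dataset": every training sample equals 0.5


def oracle(x, i, alpha_hat):
    """Hypothetical perfect model: the exact noise in x when all data equals X0."""
    return (x - math.sqrt(alpha_hat[i]) * X0) / math.sqrt(1 - alpha_hat[i])


samples = [sample_scalar(oracle, rng=random.Random(seed)) for seed in range(5)]
print(samples)  # every run starts from different noise but ends close to X0
```

The random tensor is only the starting point. All of the information about what an image looks like comes from the noise predictor, which is what training provides.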