dome272 / Diffusion-Models-pytorch

Pytorch implementation of Diffusion Models (https://arxiv.org/pdf/2006.11239.pdf)
Apache License 2.0
1.11k stars 256 forks

model generating bad random images #33

Open NITHISHM2410 opened 1 year ago

NITHISHM2410 commented 1 year ago

I trained my diffusion model in TensorFlow based on this implementation, and after training for 450 epochs on the landscape dataset my loss (MSE) was around 0.015. I generated a few samples, and the generated images were very bad or random. Below are the generated images for 1000 time steps.

I just want to know: is this a training issue, i.e. does my model need more training to further reduce the loss (currently 0.015), or is the problem in the sampling technique? Can anyone help me, please?

[attached image: generated samples]

Tianchong-Jiang commented 9 months ago

Hi, I encountered the same issue.

randomaccess2023 commented 4 months ago
def sample(self, model, n):
    logging.info(f"Sampling {n} new images....")
    model.eval()
    with torch.no_grad():
        # start from pure Gaussian noise
        x = torch.randn((n, 3, self.img_size, self.img_size)).to(self.device)
        # reverse diffusion: step from t = noise_steps - 1 down to t = 1
        for i in tqdm(reversed(range(1, self.noise_steps)), position=0):
            t = (torch.ones(n) * i).long().to(self.device)
            predicted_noise = model(x, t)
            alpha = self.alpha[t][:, None, None, None]
            alpha_hat = self.alpha_hat[t][:, None, None, None]
            beta = self.beta[t][:, None, None, None]
            if i > 1:
                noise = torch.randn_like(x)
            else:
                # no noise is added at the final step
                noise = torch.zeros_like(x)
            # DDPM posterior mean plus scaled noise (Algorithm 2 in the DDPM paper)
            x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise
    model.train()
    # map from [-1, 1] back to [0, 255] uint8 images
    x = (x.clamp(-1, 1) + 1) / 2
    x = (x * 255).type(torch.uint8)
    return x
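For context, the `sample` method above reads `self.beta`, `self.alpha`, and `self.alpha_hat`, which are assumed to come from the usual DDPM linear noise schedule. A minimal sketch of that setup (the `beta_start`/`beta_end` values are the common DDPM defaults, not necessarily this repo's exact settings):

```python
import torch

# Hypothetical schedule setup matching what sample() expects.
# beta_start/beta_end are the standard DDPM defaults (assumption).
noise_steps = 1000
beta_start, beta_end = 1e-4, 0.02

beta = torch.linspace(beta_start, beta_end, noise_steps)  # linear variance schedule
alpha = 1.0 - beta                                        # per-step signal retention
alpha_hat = torch.cumprod(alpha, dim=0)                   # cumulative product over steps

# alpha_hat decays from ~1 toward ~0, so x_t approaches pure noise at large t
print(alpha_hat[0].item(), alpha_hat[-1].item())
```

If the schedule used at sampling time does not match the one used during training, the model receives inputs at noise levels it never saw, which also produces random-looking samples.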

Make sure you use the same variable name (the variable x here) consistently across these lines:

x = torch.randn((n, 3, self.img_size, self.img_size)).to(self.device)
predicted_noise = model(x, t)
noise = torch.randn_like(x)
noise = torch.zeros_like(x)
x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise
x = (x.clamp(-1, 1) + 1) / 2
x = (x * 255).type(torch.uint8)
return x

If you don't, the loop never actually updates the tensor being denoised, and you will get noise instead of meaningful images.
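The point generalizes: the reverse loop must overwrite the same tensor each iteration so that step t denoises the output of step t+1. A minimal sketch (the halving update is a stand-in for the real denoising step, not the actual formula):

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 3)     # stands in for the initial pure-noise tensor
x_start = x.clone()

# Correct pattern: write the update back into x, so each iteration
# refines the previous iteration's result. If the update were assigned
# to a different variable instead, every step would see x_start again.
for _ in range(5):
    x = 0.5 * x           # placeholder for the denoising update

# x has actually been refined away from the initial noise
assert not torch.equal(x, x_start)
```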