lucidrains / denoising-diffusion-pytorch

Implementation of Denoising Diffusion Probabilistic Model in Pytorch
MIT License
7.96k stars 996 forks

How to use guidance to guide ddpm? #182

Open chagelo opened 1 year ago

chagelo commented 1 year ago

I have trained a DDPM for a long time. I then select an image from the training set, add some noise or apply some transform to it, and use it to guide the DDPM, but the result is not good.

def ddim_sample(self, shape, guide=None, mask=None, clip_denoised = True):
        batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective

        times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1)   # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
        times = list(reversed(times.int().tolist()))
        time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]

        img = torch.randn(shape, device = device)

        x_start = None

        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
            time_cond = torch.full((batch,), time, device=device, dtype=torch.long)
            self_cond = x_start if self.self_condition else None
            pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond, clip_x_start = clip_denoised)

            if time_next < 0:
                img = x_start
                continue

            alpha = self.alphas_cumprod[time]
            alpha_next = self.alphas_cumprod[time_next]

            sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
            c = (1 - alpha_next - sigma ** 2).sqrt()

            noise = torch.randn_like(img)

            img = x_start * alpha_next.sqrt() + \
                  c * pred_noise + \
                  sigma * noise

            # here use the guidance
            if guide is not None:
                # use time_next, not time: the guide must be diffused to the
                # same noise level as img before blending
                time_next_cond = torch.full((batch,), time_next, device=device, dtype=torch.long)
                guide_t_1 = self.q_sample(guide, time_next_cond)
                img = img * mask + guide_t_1 * (1 - mask)

        img = unnormalize_to_zero_to_one(img)
        return img
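The masked-guidance step in the loop above can be isolated and checked on its own. Below is a minimal NumPy sketch (the function name is my own, not from this repository) of the blend at noise level `t_next`. The key detail is that the guide must be forward-diffused with the cumulative alpha at `t_next`, the same noise level `img` is currently at:

```python
import numpy as np

def masked_guidance_blend(img, guide, mask, alpha_bar_next, rng):
    """Blend the current DDIM iterate with a noised copy of the guide.

    img            : current sample, already at noise level t_next
    guide          : clean guide image (same normalization as training data)
    mask           : 1.0 where the model generates freely, 0.0 where the
                     guide is kept (the convention used in the loop above)
    alpha_bar_next : cumulative product of alphas at t_next; using the value
                     at t instead mismatches the two noise levels
    """
    noise = rng.standard_normal(guide.shape)
    # forward-diffuse the clean guide to t_next: q(x_{t_next} | x_0)
    guide_t_next = np.sqrt(alpha_bar_next) * guide + np.sqrt(1.0 - alpha_bar_next) * noise
    return img * mask + guide_t_next * (1.0 - mask)
```

With `mask == 1` everywhere the blend is a no-op; with `alpha_bar_next == 1` (no noise) the masked-out region is exactly the clean guide.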

Some results are shown below: the left is the original image, the middle is the transformed image used as guidance, and the right is the DDPM result.

[sample1 images: original / guidance / DDPM result]

As you can see, the DDPM result on the right is not good, but I am not sure why. I see two possible causes:

  1. There is a problem in my code.
  2. My DDPM itself is not good:
    1. Not enough training, although I have already trained for 2 days (500,000 steps, 5,000 images, batch_size=12, image_size=256).
    2. The training data is not good.
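A third possibility worth noting: simply overwriting the masked region at every step, as the loop above does, is known to leave the generated region poorly harmonized with the guide. RePaint-style resampling addresses this by periodically jumping back up the noise schedule and denoising again, so the generated pixels can adapt to the freshly noised guide. A sketch of such a jump schedule (function name and defaults are my own, not code from this repository):

```python
def repaint_jump_schedule(total_timesteps, jump_len=10, jump_n=5):
    """Return a sequence of timesteps that mostly descends from
    total_timesteps - 1 down to 0, but at regular jump points goes back up
    jump_len steps (jump_n extra passes per point) before descending again.
    Consecutive entries always differ by exactly 1, so each move is either
    one reverse (denoising) step or one forward (re-noising) step.
    """
    # jump points and how many extra up-down passes remain at each
    jumps = {t: jump_n for t in range(0, total_timesteps - jump_len, jump_len)}
    t = total_timesteps
    schedule = []
    while t > 0:
        t -= 1
        schedule.append(t)
        if jumps.get(t, 0) > 0:
            jumps[t] -= 1
            for _ in range(jump_len):
                t += 1          # re-noise one step: q(x_{t+1} | x_t)
                schedule.append(t)
    return schedule
```

In the sampler, an ascending pair (t, t+1) would mean adding one step of forward noise to `img`, and a descending pair the usual denoising update with the masked blend.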
Runist commented 1 year ago

Hi, I trained for over 10,000 steps on a flower classification dataset and still only obtain noise maps, and my FID score is always around 150. Could you tell me how you got your results?
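For comparison, the repository README trains with roughly the following configuration; 10,000 steps is typically far too few to get past noise-like samples. This is a sketch assuming a recent version of denoising-diffusion-pytorch, and it requires a real image folder, so treat it as a configuration fragment rather than a runnable test:

```python
from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer

model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8)
)

diffusion = GaussianDiffusion(
    model,
    image_size = 128,
    timesteps = 1000                  # number of diffusion steps
)

trainer = Trainer(
    diffusion,
    'path/to/your/images',            # folder of training images
    train_batch_size = 32,
    train_lr = 8e-5,
    train_num_steps = 700000,         # total training steps
    gradient_accumulate_every = 2,    # gradient accumulation steps
    ema_decay = 0.995,                # exponential moving average decay
    amp = True                        # mixed precision
)

trainer.train()
```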