lucidrains / denoising-diffusion-pytorch

Implementation of Denoising Diffusion Probabilistic Model in Pytorch
MIT License
7.79k stars 982 forks

DDIM sampler for Continuous time Gaussian diffusion? #245

Open markub3327 opened 1 year ago

markub3327 commented 1 year ago

Hello,

Is it possible to use the DDIM sampler when time is continuous (continuous_time_gaussian_diffusion.py)? Could you please provide a simple code example?

I tried to implement it like this (scheduling_ddim.py):


    def _get_variance(self, timestep, prev_timestep):
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

        return variance

    # DDIM sampler
    # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
    # Notation (<variable name> -> <name in paper>)
    # - pred_noise_t -> e_theta(x_t, t)
    # - pred_original_sample -> f_theta(x_t, t) or x_0
    # - std_dev_t -> sigma_t
    # - eta -> η
    # - pred_sample_direction -> "direction pointing to x_t"
    # - pred_prev_sample -> "x_t-1"
    def ddim_sample(
        self,
        unet,
        image,
        t,
        class_cond,
        cond_scale = 1.,
        eta = 0.0,
        variance_noise = None,
        t_next = None,
    ):
        pred = unet.call_with_cond_scale(
            image,
            self.log_snr(t),
            class_cond,
            cond_scale = cond_scale,
        )

        # 1. compute alphas, betas
        log_snr = self.log_snr(t)
        log_snr_next = self.log_snr(t_next)
        log_snr, log_snr_next = map(partial(right_pad_dims_to, image), (log_snr, log_snr_next))

        alpha, sigma = log_snr_to_alpha_sigma(log_snr)
        alpha_next, sigma_next = log_snr_to_alpha_sigma(log_snr_next)

        # alpha_prod_t = self.alphas_cumprod[timestep]
        # alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod

        # beta_prod_t = 1 - alpha_prod_t

        # 2. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        if unet.pred_objective == 'noise':
            pred_original_sample = self.predict_start_from_noise(image, t = t, noise = pred)
        elif unet.pred_objective == 'x_start':
            pred_original_sample = pred
        elif unet.pred_objective == 'v':
            pred_original_sample = self.predict_start_from_v(image, t = t, v = pred)
        else:
            raise ValueError(f'unknown objective {unet.pred_objective}')

        # 3. Clip or threshold "predicted x_0"
        pred_original_sample = tf.clip_by_value(pred_original_sample, -1.0, 1.0)

        # 4. compute variance: "sigma_t(η)" -> see formula (16)
        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
        # with alpha = sqrt(alpha_prod_t) and sigma = sqrt(1 - alpha_prod_t), formula (16) reads:
        std_dev_t = eta * (sigma_next / sigma) * tf.sqrt(1.0 - (alpha / alpha_next) ** 2)

        # 5. the pred_epsilon is always re-derived from the clipped x_0 in Glide
        pred_epsilon = (image - alpha * pred_original_sample) / sigma    # alpha_prod_t ** (0.5), beta_prod_t ** (0.5)

        # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_sample_direction = (1 - alpha_next**2 - std_dev_t**2) ** (0.5) * pred_epsilon     # alpha_prod_t_prev

        # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        prev_sample = alpha_next * pred_original_sample + pred_sample_direction     # alpha_prod_t_prev ** (0.5)

        if eta > 0:
            if variance_noise is None:
                variance_noise = tf.random.normal(
                    tf.shape(pred), dtype=pred.dtype
                )
            variance = std_dev_t * variance_noise

            prev_sample = prev_sample + variance

        return prev_sample, pred_original_sample

The only issue with this implementation is how to calculate self.alphas_cumprod. Or is this not a good way of thinking?

Thanks.
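One observation that may answer the `self.alphas_cumprod` question: in the continuous-time formulation, `log_snr(t) = log(alpha_bar(t) / (1 - alpha_bar(t)))`, so `alpha_bar(t) = sigmoid(log_snr(t))` and no discrete table is needed. A minimal sketch of this relation (the linear `log_snr` schedule below is purely illustrative, not the repo's actual schedule):

```python
import math

def log_snr_to_alpha_sigma(log_snr):
    # alpha_bar(t) = sigmoid(log_snr(t)); alpha and sigma are its square roots,
    # so alpha**2 plays the role of alphas_cumprod[timestep]
    alpha_bar = 1.0 / (1.0 + math.exp(-log_snr))  # sigmoid
    return math.sqrt(alpha_bar), math.sqrt(1.0 - alpha_bar)

# hypothetical linear log-SNR schedule, for illustration only
def log_snr(t, log_snr_max=10.0, log_snr_min=-10.0):
    return log_snr_max + t * (log_snr_min - log_snr_max)

alpha, sigma = log_snr_to_alpha_sigma(log_snr(0.5))
```

With this, `alpha ** 2` at any continuous `t` is the continuous analogue of `alphas_cumprod[timestep]`, and `alpha ** 2 + sigma ** 2 == 1` always holds.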

markub3327 commented 1 year ago

@lucidrains Hi, is the implementation above equivalent to a DDIM sampler for continuous time?

danbochman commented 3 months ago

Hi @markub3327 did you ever manage to implement DDIM for continuous time? I am also interested in this

markub3327 commented 3 months ago

@lucidrains Yes, I did. Why do you ask? Did you find out something new?

danbochman commented 3 months ago

@markub3327 It would be great if you could share the implementation, because I also tried something similar but was unable to make it work. Or did what you posted here work well for you?

I also tried squaring alpha and alpha_next, but the results are still poor...

markub3327 commented 3 months ago

@danbochman Hi, I don't have any implementation other than the one above, and with it I get better results using discrete timesteps. Here is a DDIM sampler with continuous time which may work well: https://keras.io/examples/generative/ddim/
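For anyone still stuck on this: the DDIM update itself doesn't care whether time is discrete or continuous; it only needs alpha and sigma at the current and next time, which the log-SNR gives directly. A deterministic (eta = 0) sketch of the formula (12) step, with a toy oracle noise prediction in place of a real network (all names here are illustrative, not from the repo):

```python
import math

def log_snr_to_alpha_sigma(log_snr):
    alpha_bar = 1.0 / (1.0 + math.exp(-log_snr))  # sigmoid
    return math.sqrt(alpha_bar), math.sqrt(1.0 - alpha_bar)

def ddim_step(x_t, pred_noise, log_snr_t, log_snr_next):
    # deterministic DDIM update (formula (12) with eta = 0), driven by log-SNR
    alpha, sigma = log_snr_to_alpha_sigma(log_snr_t)
    alpha_next, sigma_next = log_snr_to_alpha_sigma(log_snr_next)
    x0_hat = (x_t - sigma * pred_noise) / alpha   # "predicted x_0"
    direction = sigma_next * pred_noise           # "direction pointing to x_t"
    return alpha_next * x0_hat + direction

# toy sanity check: with a perfect noise prediction, one DDIM step should
# land exactly on the forward marginal at the next (higher-SNR) time
x0, eps = 0.7, -0.3
log_snr_t, log_snr_next = -2.0, 1.0               # SNR increases as sampling proceeds
alpha, sigma = log_snr_to_alpha_sigma(log_snr_t)
alpha_next, sigma_next = log_snr_to_alpha_sigma(log_snr_next)
x_t = alpha * x0 + sigma * eps
x_next = ddim_step(x_t, eps, log_snr_t, log_snr_next)
```

The testable property: with oracle noise, `x_next` equals `alpha_next * x0 + sigma_next * eps` exactly, which is a quick way to validate any continuous-time DDIM step before plugging in a trained network.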