tqch / ddpm-torch

Unofficial PyTorch Implementation of Denoising Diffusion Probabilistic Models (DDPM)
MIT License

Potential Bug in likelihood calculation #20

Closed: HirahTang closed this issue 3 months ago

HirahTang commented 3 months ago

In ddpm_torch/diffusion.py, within the function GaussianDiffusion.calc_all_bpd, the timestep t is defined before the loop and never changes afterwards.

This seems incorrect. Shouldn’t t be updated in each iteration to reflect the current timestep? Here’s the proposed fix:

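```python
import torch

# Method of GaussianDiffusion; flat_mean is the repo's helper that averages over non-batch dims.
def calc_all_bpd(self, denoise_fn, x_0, clip_denoised=True):
    B, T = x_0.shape[0], self.timesteps  # batch size, number of diffusion steps

    # Keep all buffers on x_0's device so the final sum does not mix CPU and GPU tensors.
    losses = torch.zeros([B, T], dtype=torch.float32, device=x_0.device)
    mses = torch.zeros([B, T], dtype=torch.float32, device=x_0.device)

    # Walk the chain backwards from t = T - 1 down to t = 0.
    for i in range(T - 1, -1, -1):
        # Fresh timestep tensor for this iteration: every sample in the batch is at step i.
        t = torch.full([B], i, dtype=torch.int64, device=x_0.device)
        x_t = self.q_sample(x_0, t=t)
        loss, pred_x_0 = self._loss_term_bpd(
            denoise_fn, x_0, x_t=x_t, t=t, clip_denoised=clip_denoised, return_pred=True)
        losses[:, i] = loss
        mses[:, i] = flat_mean((pred_x_0 - x_0).pow(2))

    # Total bpd = sum of the per-timestep terms + the prior term.
    prior_bpd = self._prior_bpd(x_0)
    total_bpd = torch.sum(losses, dim=1) + prior_bpd
    return total_bpd, losses, prior_bpd, mses
```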
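For context on the prior_bpd term added at the end of the function, here is a minimal NumPy sketch of how such a term is typically computed: the KL divergence between q(x_T | x_0) and a standard normal, converted from nats to bits and averaged over the non-batch dimensions. It assumes the linear beta schedule from the DDPM paper (1e-4 to 0.02) and the closed-form Gaussian KL; the names linear_betas and prior_bpd are hypothetical, and ddpm-torch's actual _prior_bpd may differ in details.

```python
import numpy as np


def linear_betas(T, beta_start=1e-4, beta_end=0.02):
    # Assumed linear noise schedule (the DDPM paper's default), not read from the repo.
    return np.linspace(beta_start, beta_end, T, dtype=np.float64)


def prior_bpd(x_0, betas):
    """Hypothetical sketch: KL(q(x_T | x_0) || N(0, I)) in bits per dimension, per sample."""
    alphas_bar = np.cumprod(1.0 - betas)
    a_T = alphas_bar[-1]
    mean = np.sqrt(a_T) * x_0  # mean of q(x_T | x_0)
    var = 1.0 - a_T            # variance of q(x_T | x_0), same for every dimension
    # Closed-form KL between N(mean, var) and N(0, 1), per dimension, in nats.
    kl = 0.5 * (-1.0 - np.log(var) + var + mean ** 2)
    # Average over all non-batch dimensions, then convert nats to bits.
    return kl.reshape(x_0.shape[0], -1).mean(axis=1) / np.log(2.0)


rng = np.random.default_rng(0)
x = rng.uniform(-1.0, 1.0, size=(4, 3, 8, 8))  # toy batch scaled to [-1, 1]
print(prior_bpd(x, linear_betas(1000)))
```

With T = 1000 this term is very close to zero for data scaled to [-1, 1], because alpha_bar_T is tiny and q(x_T | x_0) is almost exactly N(0, I); nearly all of total_bpd then comes from the per-timestep terms in losses.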

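This proposed change ensures that t holds the current timestep i on every iteration, so each column losses[:, i] (and mses[:, i]) is computed at the timestep it is stored under. That matters for the bits-per-dimension (bpd) result, because total_bpd is the sum of those per-timestep terms plus the prior term.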
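The difference between the two loop patterns can be checked without a model. The sketch below is a hypothetical NumPy stand-in: toy_loss_term replaces _loss_term_bpd and simply returns the timestep it was given, so each column of the result shows which t produced it. all_timestep_terms rebuilds t on every iteration as in the proposed fix; all_timestep_terms_frozen_t reproduces the pattern described in the issue, with t set once before the loop (fixed at T - 1 here purely for illustration).

```python
import numpy as np


def toy_loss_term(x_0, t):
    # Hypothetical stand-in for _loss_term_bpd: returns each sample's timestep as its "loss".
    return t.astype(np.float32)


def all_timestep_terms(x_0, T, loss_term):
    """Proposed-fix pattern: a fresh t for every iteration, column i filled at timestep i."""
    B = x_0.shape[0]
    losses = np.zeros((B, T), dtype=np.float32)
    for i in range(T - 1, -1, -1):
        t = np.full((B,), i, dtype=np.int64)  # every sample in the batch is at step i
        losses[:, i] = loss_term(x_0, t)
    return losses


def all_timestep_terms_frozen_t(x_0, T, loss_term):
    """Pattern described in the issue: t set once before the loop and never updated."""
    B = x_0.shape[0]
    t = np.full((B,), T - 1, dtype=np.int64)  # illustrative initial value
    losses = np.zeros((B, T), dtype=np.float32)
    for i in range(T - 1, -1, -1):
        losses[:, i] = loss_term(x_0, t)  # same t on every iteration
    return losses


x = np.zeros((2, 5))  # toy batch: B = 2 samples with 5 features each
ok = all_timestep_terms(x, 4, toy_loss_term)
bad = all_timestep_terms_frozen_t(x, 4, toy_loss_term)
```

With B = 2 and T = 4, the first version returns rows equal to [0, 1, 2, 3], while the frozen version fills every entry with 3, so every column of losses would describe the same timestep and the summed bpd would be wrong.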