zoubohao / DenoisingDiffusionProbabilityModel-ddpm-

This may be the simplest implementation of DDPM. You can run Main.py directly to train the UNet on the CIFAR-10 dataset and watch the amazing denoising process.
MIT License
1.48k stars, 156 forks

What does "Algorithm 2" in the comment on ``forward`` in ``GaussianDiffusionSampler`` refer to? #37

Closed: Yonggie closed this issue 7 months ago

Yonggie commented 7 months ago

Hi, what does "Algorithm 2" in the docstring of `forward` in `GaussianDiffusionSampler` refer to? Could you explain it, or point me to a reference?

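```python
# In file Diffusion.py (excerpt; __init__, p_mean_variance, etc. omitted)
import torch
import torch.nn as nn


class GaussianDiffusionSampler(nn.Module):
    # ...
    def forward(self, x_T):
        """
        Algorithm 2.
        """
        x_t = x_T
        # walk the timesteps backwards: T-1, T-2, ..., 0
        for time_step in reversed(range(self.T)):
            print(time_step)
            # one timestep index per image in the batch
            t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long) * time_step
            mean, var = self.p_mean_variance(x_t=x_t, t=t)
            # no noise when t == 0
            if time_step > 0:
                noise = torch.randn_like(x_t)
            else:
                noise = 0
            # sample x_{t-1} = mean + sqrt(var) * noise
            x_t = mean + torch.sqrt(var) * noise
            assert torch.isnan(x_t).int().sum() == 0, "nan in tensor."
        x_0 = x_t
        # keep the output in the [-1, 1] data range
        return torch.clip(x_0, -1, 1)
```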
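The "Algorithm 2" in the docstring matches the numbering of the sampling procedure in the DDPM paper (Ho et al., 2020), where Algorithm 1 is training and Algorithm 2 is sampling. As a framework-free sketch of the same loop shape, the block below swaps tensors for plain Python lists. The `p_mean_variance` argument is a hypothetical callable standing in for the repo's method (it is assumed to return the mean and variance of the next, less noisy sample); this is an illustration, not the repo's code.

```python
import math
import random
from typing import Callable, List, Tuple

# Hypothetical signature: given x_t (a list of floats) and a 0-indexed step,
# return the mean and variance of p_theta(x_{t-1} | x_t).
MeanVarFn = Callable[[List[float], int], Tuple[List[float], float]]


def sample_algorithm2(x_T: List[float], T: int, p_mean_variance: MeanVarFn,
                      rng: random.Random) -> List[float]:
    """Ancestral sampling loop in the shape of DDPM Algorithm 2 (sketch)."""
    x_t = list(x_T)
    # 0-indexed steps T-1, ..., 0 correspond to the paper's t = T, ..., 1
    for time_step in reversed(range(T)):
        mean, var = p_mean_variance(x_t, time_step)
        if time_step > 0:
            # paper: z ~ N(0, I) whenever t > 1
            noise = [rng.gauss(0.0, 1.0) for _ in x_t]
        else:
            # paper: z = 0 at the last step, so x_0 is just the mean
            noise = [0.0 for _ in x_t]
        x_t = [m + math.sqrt(var) * z for m, z in zip(mean, noise)]
        assert not any(math.isnan(v) for v in x_t), "nan in tensor."
    # clip to [-1, 1], like torch.clip(x_0, -1, 1)
    return [min(max(v, -1.0), 1.0) for v in x_t]
```

For example, with a toy model that halves its input and has zero variance, `sample_algorithm2([4.0, -8.0], 3, lambda x, s: ([0.5 * v for v in x], 0.0), random.Random(0))` returns `[0.5, -1.0]`: the first value is halved three times, and the second reaches -1.0 after three halvings and is left unchanged by the clip.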
Yonggie commented 7 months ago

[Image] It seems to be this one, from the blog post.
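For reference, Algorithm 2 ("Sampling") in the DDPM paper (Ho et al., 2020) is written below, with $\alpha_t = 1 - \beta_t$ and $\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s$. The image mentioned above presumably shows this algorithm. In the `forward` code, `mean` presumably plays the role of the first term of the update, `var` plays the role of $\sigma_t^2$, and the 0-indexed check `time_step > 0` corresponds to the paper's $t > 1$.

```latex
\begin{aligned}
&x_T \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) \\
&\textbf{for } t = T, \dots, 1: \\
&\quad z \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) \text{ if } t > 1, \text{ else } z = \mathbf{0} \\
&\quad x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_\theta(x_t, t)\right) + \sigma_t z \\
&\textbf{return } x_0
\end{aligned}
```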

In `self.p_mean_variance(x_t=x_t, t=t)`, does `p_mean_variance` mean "previous mean variance"?
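In DDPM notation, $q$ denotes the fixed forward (noising) process and $p_\theta(x_{t-1} \mid x_t)$ the learned reverse process. The "p" in `p_mean_variance` therefore most likely refers to $p_\theta$: the function returns the mean and variance of that reverse distribution, not a "previous" mean. The sketch below shows one way to compute them from a noise prediction. It rests on several assumptions: a linear $\beta$ schedule from 1e-4 to 0.02 (the DDPM paper's setting), the choice $\sigma_t^2 = \beta_t$ (one of the two variances the paper discusses), and a hypothetical `eps_model` standing in for the UNet. It is not the repo's implementation.

```python
import math
from typing import Callable, List, Tuple


def linear_betas(T: int, beta_1: float = 1e-4, beta_T: float = 0.02) -> List[float]:
    """Linear beta schedule from beta_1 to beta_T (requires T >= 2)."""
    return [beta_1 + (beta_T - beta_1) * i / (T - 1) for i in range(T)]


def p_mean_variance(x_t: List[float], step: int, betas: List[float],
                    eps_model: Callable[[List[float], int], List[float]]
                    ) -> Tuple[List[float], float]:
    """Mean and variance of p_theta(x_{t-1} | x_t), assuming sigma_t^2 = beta_t.

    `step` is 0-indexed, so betas[step] is the paper's beta_t with t = step + 1.
    `eps_model` is a hypothetical noise predictor epsilon_theta(x_t, t).
    """
    beta_t = betas[step]
    alpha_t = 1.0 - beta_t
    # alpha_bar_t = product of alpha_s for s = 1..t
    alpha_bar_t = math.prod(1.0 - b for b in betas[: step + 1])
    eps = eps_model(x_t, step)
    # (1 - alpha_t) / sqrt(1 - alpha_bar_t), with 1 - alpha_t == beta_t
    coef = beta_t / math.sqrt(1.0 - alpha_bar_t)
    mean = [(x - coef * e) / math.sqrt(alpha_t) for x, e in zip(x_t, eps)]
    return mean, beta_t
```

With a model that always predicts zero noise, the mean reduces to $x_t / \sqrt{\alpha_t}$ and the variance is $\beta_t$, which makes the pieces easy to check by hand.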

Yonggie commented 7 months ago

Got it now. Thanks.