zoubohao / DenoisingDiffusionProbabilityModel-ddpm-

This may be the simplest implementation of DDPM. You can run Main.py directly to train the UNet on the CIFAR-10 dataset and watch the denoising process.
MIT License
1.48k stars 156 forks

Is the UNet already trained? I can't find where the code trains the UNet #31

Open dream-in-night opened 9 months ago

dream-in-night commented 9 months ago

Why does the code only contain the noise loss for training? How does the UNet itself get trained?

`import torch
import torch.nn as nn
import torch.nn.functional as F


class GaussianDiffusionTrainer(nn.Module):
    def __init__(self, model, beta_1, beta_T, T):
        super().__init__()
        # beta_1 : 1e-4
        # beta_T : 0.02
        # T : 1000
        self.model = model
        self.T = T

        self.register_buffer(
            'betas', torch.linspace(beta_1, beta_T, T).double())
        # self.betas : [1000], from 1e-4 to 0.02, linearly, 1000 points, double type
        alphas = 1. - self.betas
        # alphas : [1000], from 0.9999 to 0.98, linearly, double type
        alphas_bar = torch.cumprod(alphas, dim=0)
        # alphas_bar : [1000], cumulative product of alphas, decreasing from 0.9999 toward 0, double type

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer(
            'sqrt_alphas_bar', torch.sqrt(alphas_bar))
        self.register_buffer(
            'sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar))

    def forward(self, x_0):
        """
        Algorithm 1.
        """
        # x_0 : [80, 3, 32, 32]
        t = torch.randint(self.T, size=(x_0.shape[0], ), device=x_0.device)
        # t : [80]
        noise = torch.randn_like(x_0)
        # extract (helper defined elsewhere in the repo, not shown here): presumably
        # picks the coefficient for each t and reshapes it to broadcast over x_0
        x_t = (
            extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 +
            extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise)
        x = self.model(x_t, t)
        # x: [80, 3, 32, 32]
        loss = F.mse_loss(x, noise, reduction='none')
        return loss
`

mileret commented 9 months ago

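This part of train.py is what trains the U-Net:

`optimizer.zero_grad()
x_0 = images.to(device)
loss = trainer(x_0).sum() / 1000.
loss.backward()`

The trainer holds the U-Net as self.model, so the noise loss is computed from the U-Net's output, and loss.backward() sends the gradients into the U-Net's weights.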
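To make the point concrete, here is a minimal sketch, not the repository's code, showing that backpropagating the noise loss updates the wrapped network. `TinyNoisePredictor` and `ToyTrainer` are hypothetical names made up for this illustration, a single convolution stands in for the U-Net, a random tensor stands in for a CIFAR-10 batch, and the `optimizer.step()` call is assumed to follow `loss.backward()` in a real training loop.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyNoisePredictor(nn.Module):
    """Hypothetical stand-in for the U-Net: maps (x_t, t) to a noise estimate."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)

    def forward(self, x_t, t):
        # t is ignored here for brevity; a real U-Net would embed the timestep.
        return self.conv(x_t)


class ToyTrainer(nn.Module):
    """Simplified noise-prediction loss, modelled on the snippet in the issue."""

    def __init__(self, model, beta_1=1e-4, beta_T=0.02, T=1000):
        super().__init__()
        self.model = model
        self.T = T
        alphas_bar = torch.cumprod(1.0 - torch.linspace(beta_1, beta_T, T), dim=0)
        self.register_buffer("sqrt_alphas_bar", torch.sqrt(alphas_bar))
        self.register_buffer("sqrt_one_minus_alphas_bar", torch.sqrt(1.0 - alphas_bar))

    def forward(self, x_0):
        t = torch.randint(self.T, size=(x_0.shape[0],), device=x_0.device)
        noise = torch.randn_like(x_0)
        # Pick the schedule value for each sample and broadcast over (C, H, W).
        a = self.sqrt_alphas_bar[t].view(-1, 1, 1, 1)
        b = self.sqrt_one_minus_alphas_bar[t].view(-1, 1, 1, 1)
        x_t = a * x_0 + b * noise
        return F.mse_loss(self.model(x_t, t), noise, reduction="none")


torch.manual_seed(0)
model = TinyNoisePredictor()
trainer = ToyTrainer(model)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

x_0 = torch.randn(8, 3, 32, 32)  # random stand-in for a batch of images
weights_before = model.conv.weight.detach().clone()

optimizer.zero_grad()
loss = trainer(x_0).sum() / 1000.0
loss.backward()   # gradients flow through trainer.model, i.e. the network
optimizer.step()  # the network's weights are updated here

grad_reached_model = model.conv.weight.grad is not None
weights_changed = not torch.equal(weights_before, model.conv.weight.detach())
print(grad_reached_model, weights_changed)
```

Because the optimizer is built from `model.parameters()` and the loss depends on `self.model(x_t, t)`, there is no separate "train the U-Net" step to look for: minimizing the noise loss is the U-Net training.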