LingxiaoYang2023 / DSG2024

Official PyTorch repository for “Guidance with Spherical Gaussian Constraint for Conditional Diffusion”

Apply DSG on DDPM #8

Chenrf1121 opened this issue 1 month ago

Chenrf1121 commented 1 month ago

Can we apply DSG to DDPM, for example by adding some code to this method:

@torch.no_grad()
def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
    # one ancestral DDPM step: draw x_{t-1} from N(model_mean, exp(model_log_variance))
    b, *_, device = *x.shape, x.device
    model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
    noise = noise_like(x.shape, device, repeat_noise)
    # no noise when t == 0
    nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
    # std = exp(0.5 * log variance); mean and noise are combined here, so guidance cannot see them separately
    return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

The method is in Non_linear_problems/SD_style/ldm/models/diffusion/ddpm.py, lines 243-250. If this is possible, could you provide an example?

LingxiaoYang2023 commented 1 month ago

Thanks for your interest in DSG. Our implementation is based on DDIM with eta=1, which is equivalent to DDPM sampling. If you need to use the code with DDPM, I recommend returning posterior_mean and the variance separately (as in class DDPM, lines 403-420 of Linear_Inverse_Problems/guided_diffusion/gaussian_diffusion.py) and then applying DSG to them (as in class DSG, lines 101-136 of Linear_Inverse_Problems/guided_diffusion/conditional_methods.py).
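A minimal sketch of the suggestion above, written in NumPy for a single sample so it runs on its own; it is not the repository's DSG class. The first two helpers check the noise-scale part of the "DDIM with eta=1 equals DDPM" remark: the DDIM sigma with eta=1 matches the DDPM posterior standard deviation. `dsg_step` is a hypothetical function based on our reading of the DSG paper: once the posterior mean and std are available separately, the sample is kept on the sphere of radius sqrt(n)·std around the mean (where a high-dimensional Gaussian concentrates), and its direction is mixed between the unguided noise direction and the steepest-descent direction of the guidance loss, weighted by a guidance rate. The loss gradient `grad` is assumed to be computed by the caller (for example with torch autograd through the model's x0 estimate); `guidance_rate` and `eps` are illustrative parameters.

```python
import numpy as np


def ddim_sigma(alpha_bar_t, alpha_bar_prev, eta=1.0):
    """DDIM noise scale: eta * sqrt((1 - a_prev) / (1 - a_t)) * sqrt(1 - a_t / a_prev)."""
    return eta * np.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar_t)) * np.sqrt(1 - alpha_bar_t / alpha_bar_prev)


def ddpm_posterior_std(alpha_bar_t, alpha_bar_prev):
    """sqrt of the DDPM posterior variance beta_tilde_t = (1 - a_prev) / (1 - a_t) * beta_t."""
    beta_t = 1 - alpha_bar_t / alpha_bar_prev
    return np.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar_t) * beta_t)


def dsg_step(mean, std, x_sample, grad, guidance_rate=0.1, eps=1e-8):
    """Hypothetical DSG-style update for one sample (a sketch, not the official code).

    mean, std : posterior mean and scalar std of p(x_{t-1} | x_t), returned separately
    x_sample  : unguided sample, mean + std * z
    grad      : gradient of the guidance loss, assumed precomputed by the caller
    """
    n = mean.size
    r = np.sqrt(n) * std  # radius where N(mean, std^2 I) concentrates
    d_star = -r * grad / (np.linalg.norm(grad) + eps)  # steepest descent, scaled to the sphere
    d_sample = x_sample - mean  # unguided random direction
    d_mix = d_sample + guidance_rate * (d_star - d_sample)  # interpolate the two directions
    return mean + r * d_mix / (np.linalg.norm(d_mix) + eps)  # project back onto the sphere


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    mean = rng.standard_normal((3, 8, 8))
    std = 0.1
    x_sample = mean + std * rng.standard_normal(mean.shape)
    grad = rng.standard_normal(mean.shape)
    x_prev = dsg_step(mean, std, x_sample, grad, guidance_rate=0.2)
    # distance from the mean and the sphere radius; these should be nearly equal
    print(np.linalg.norm(x_prev - mean), np.sqrt(mean.size) * std)
```

With `guidance_rate=0` the step only rescales the unguided sample onto the sphere; with `guidance_rate=1` it moves purely along the negative loss gradient. A batched PyTorch version would compute the norms per sample, for example with `torch.linalg.norm(..., dim=(1, 2, 3))`.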