yang-song / score_sde_pytorch

PyTorch implementation for Score-Based Generative Modeling through Stochastic Differential Equations (ICLR 2021, Oral)
https://arxiv.org/abs/2011.13456
Apache License 2.0

PC sampler mismatched? #50

Open mh-nguyen712 opened 10 months ago

mh-nguyen712 commented 10 months ago

Hello, thanks for your interesting work!

I have a question about your implementation of the PC sampler:

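import torch

# Excerpt: sde, shape, device, eps, denoise, n_steps, inverse_scaler and the
# *_update_fn callables are captured from the enclosing function (not shown).
def pc_sampler(model):
    with torch.no_grad():
        # Initial sample
        x = sde.prior_sampling(shape).to(device)
        timesteps = torch.linspace(sde.T, eps, sde.N, device=device)

        for i in range(sde.N):
            t = timesteps[i]
            vec_t = torch.ones(shape[0], device=t.device) * t
            # At each timestep the corrector runs before the predictor
            x, x_mean = corrector_update_fn(x, vec_t, model=model)
            x, x_mean = predictor_update_fn(x, vec_t, model=model)

        # Second value: number of function evaluations
        return inverse_scaler(x_mean if denoise else x), sde.N * (n_steps + 1)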

Why do you start with the corrector instead of the predictor, as in Algorithm 1 of your original paper? Is there a reason for this? Thank you very much!

LiChenda commented 7 months ago

I have the same question after reviewing these lines of code. @NguyenHai7120 Have you figured it out?

Long-louis commented 6 months ago

Is it because after the prior sampling, you need to run a corrector first?

PhilippHoellmer commented 4 months ago

Same question here. I guess it makes some sense to apply the corrector to the prior sample. But even then, isn't the corrector never applied to the final sample?
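A minimal sketch comparing the two update orders. It models the sampler on abstract integer noise levels (N for the prior down to 0) rather than the continuous `timesteps` in the code, assumes `n_steps` corrector steps per level, and assumes each update costs one model evaluation. The names `Op`, `code_schedule` and `paper_schedule` are hypothetical and are not part of the repository.

```python
from typing import List, Tuple

# Hypothetical encoding of one update: ("C", k) is a corrector step at noise
# level k; ("P", k) is a predictor step moving the sample from level k to k - 1.
Op = Tuple[str, int]


def code_schedule(n: int, n_steps: int = 1) -> List[Op]:
    """Update order of the pc_sampler loop quoted above: corrector, then predictor."""
    ops: List[Op] = []
    level = n  # x starts as a prior sample, i.e. at the highest noise level
    for _ in range(n):
        ops.extend([("C", level)] * n_steps)  # corrector at the current level
        ops.append(("P", level))              # predictor: level -> level - 1
        level -= 1
    return ops


def paper_schedule(n: int, n_steps: int = 1) -> List[Op]:
    """Update order as written in Algorithm 1: predictor, then corrector."""
    ops: List[Op] = []
    level = n
    for _ in range(n):
        ops.append(("P", level))              # predictor: level -> level - 1
        level -= 1
        ops.extend([("C", level)] * n_steps)  # corrector at the new level
    return ops


if __name__ == "__main__":
    print(code_schedule(3))
    # [('C', 3), ('P', 3), ('C', 2), ('P', 2), ('C', 1), ('P', 1)]
    print(paper_schedule(3))
    # [('P', 3), ('C', 2), ('P', 2), ('C', 1), ('P', 1), ('C', 0)]
```

Under this simplified model, dropping the leading corrector steps from the code's schedule and the trailing ones from the paper's schedule gives identical sequences. The loop is therefore Algorithm 1 shifted by one corrector block: it corrects the prior sample and skips the corrector at the final level, which matches the observation in the last comment. Both schedules contain n * (n_steps + 1) updates, the same count the snippet returns.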