AlexiaJM / score_sde_fast_sampling

Repository for the "Gotta Go Fast When Generating Data with Score-Based Models" paper

algorithm 2 #3

Closed: WN1695173791 closed this issue 2 years ago.

WN1695173791 commented 2 years ago

Is the procedure in Algorithm 2 of the paper implemented in the code?

AlexiaJM commented 2 years ago

No, but it's not that different if you use the Stratonovich formulation.

The main difference is that you only draw a new z when a step is accepted. In addition, if it's a forward-time SDE, the -hf(.) term (f being the drift) becomes +hf(.).
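A minimal scalar sketch of the accept/reject loop described above, written in plain Python rather than PyTorch so it runs without dependencies. It integrates in reverse time with an Euler proposal and a Heun correction that share the same noise z; averaging the two stages this way (stochastic Heun) approximates the Stratonovich solution. The function name `sample_adaptive`, the tolerances, the safety factor, the exponent `r` and the error scaling are illustrative assumptions, not the repository's values and not a verbatim transcription of the paper's Algorithm 2.

```python
import math
import random


def sample_adaptive(x, drift_fn, diff_fn, t_start=1.0, t_end=1e-3, h_init=0.01,
                    abs_tol=1e-2, rel_tol=1e-2, safety=0.9, r=0.9, seed=0):
    """Reverse-time adaptive Euler/Heun integration of a scalar SDE.

    Hypothetical sketch: the noise z is redrawn only after an accepted step,
    so a rejected step is retried with the same z and a smaller h.
    Returns (final x, number of accepted steps, number of rejected steps).
    """
    rng = random.Random(seed)
    t, h = t_start, h_init
    x_prev = x
    z = rng.gauss(0.0, 1.0)                   # first draw happens before the loop
    n_accept = n_reject = 0
    while t - t_end > 1e-12:
        h = min(h, t - t_end)                 # never step past t_end
        sqrt_h = math.sqrt(h)
        # Euler proposal, reverse time (t -> t - h), hence -h * drift.
        k1 = -h * drift_fn(x, t) + diff_fn(x, t) * sqrt_h * z
        # Heun stage at the Euler proposal, reusing the same z.
        k2 = -h * drift_fn(x + k1, t - h) + diff_fn(x + k1, t - h) * sqrt_h * z
        x_euler, x_heun = x + k1, x + 0.5 * (k1 + k2)
        # Error estimate scaled by a mixed absolute/relative tolerance.
        delta = max(abs_tol, rel_tol * max(abs(x_euler), abs(x_prev)))
        err = abs(x_heun - x_euler) / delta
        if err <= 1.0:
            x_prev, x, t = x, x_heun, t - h
            z = rng.gauss(0.0, 1.0)           # new noise ONLY after acceptance
            n_accept += 1
        else:
            n_reject += 1                     # keep z; retry with a smaller h
        # Grow or shrink h from the error (guard against err == 0).
        h = safety * h * max(err, 1e-12) ** (-r)
    return x, n_accept, n_reject


def drift(x, t):
    return x          # toy drift: each reverse-time step shrinks x by about e^-h


def zero_diffusion(x, t):
    return 0.0        # no noise: reduces to an adaptive Heun ODE solver


x_end, n_accept, n_reject = sample_adaptive(1.0, drift, zero_diffusion)
```

With zero diffusion the result should be close to e^-0.999, which makes the toy case easy to sanity-check before plugging in a real score-based drift.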

AlexiaJM commented 2 years ago

So in the torch version https://github.com/AlexiaJM/score_sde_fast_sampling/blob/main/sde_sampling_torch.py, you remove

z = torch.randn_like(x)

and add

z_new = torch.randn_like(x)
z = torch.where(accept, z_new, z)
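The `torch.where(accept, z_new, z)` line keeps the old noise for rejected samples and swaps in fresh noise for accepted ones. Because it reads the previous `z`, one draw still has to happen once before the loop. The per-sample analogue below is a dependency-free sketch; `refresh_noise` is a hypothetical helper, not part of the repository. If `accept` is a per-sample boolean of shape (batch,), PyTorch will likely need it reshaped (e.g. `accept.view(-1, 1, 1, 1)`) to broadcast against an image-shaped `z`; the actual shape of `accept` in the repository is not checked here.

```python
import random


def refresh_noise(z, accept, rng):
    """Per-sample analogue of torch.where(accept, z_new, z).

    z and accept are equal-length lists (one entry per sample). A fresh
    standard-normal draw replaces z[i] only where accept[i] is True;
    rejected samples keep their old noise for the retry.
    """
    z_new = [rng.gauss(0.0, 1.0) for _ in z]   # like torch.randn_like(z)
    return [zn if a else zo for zn, a, zo in zip(z_new, accept, z)]


rng = random.Random(0)
z = [rng.gauss(0.0, 1.0) for _ in range(4)]     # drawn once, before the loop
accept = [True, False, True, False]
z_next = refresh_noise(z, accept, rng)
```

Note that fresh noise is generated for every sample and then discarded for the rejected ones, which matches what `torch.randn_like` followed by `torch.where` does.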

If it's a forward-time SDE, you just change

K1_mean = -h_ * drift
drift_Heun, diffusion_Heun = my_rsde(x + K1, t - h)
K2_mean = -h_ * drift_Heun

to

K1_mean = h_ * drift
drift_Heun, diffusion_Heun = my_rsde(x + K1, t + h)
K2_mean = h_ * drift_Heun
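The sign change can be captured with a single direction flag. The sketch below is a scalar, dependency-free illustration; `heun_increments`, `integrate` and the toy drift dx/dt = -x are hypothetical stand-ins for `my_rsde` and the sampler loop, not code from the repository. Only the drift sign and the time of the second evaluation change; the diffusion term is untouched.

```python
import math


def heun_increments(x, t, h, z, drift_fn, diff_fn, forward=False):
    """K1/K2 increments of one Euler/Heun step for a scalar SDE.

    forward=False mirrors the reverse-time lines (-h * drift, second stage at t - h);
    forward=True mirrors the replacement lines (+h * drift, second stage at t + h).
    The diffusion term and the shared noise z are the same in both directions.
    """
    sign = 1.0 if forward else -1.0
    sqrt_h = math.sqrt(h)
    k1 = sign * h * drift_fn(x, t) + diff_fn(x, t) * sqrt_h * z
    t2 = t + sign * h
    k2 = sign * h * drift_fn(x + k1, t2) + diff_fn(x + k1, t2) * sqrt_h * z
    return k1, k2


def decay(x, t):
    return -x         # dx/dt = -x, so x(t) = x(0) * e^-t


def no_noise(x, t):
    return 0.0


def integrate(x, t, h, n_steps, forward):
    """Fixed-step Heun integration using heun_increments with zero noise."""
    for _ in range(n_steps):
        k1, k2 = heun_increments(x, t, h, 0.0, decay, no_noise, forward)
        x = x + 0.5 * (k1 + k2)
        t = t + h if forward else t - h
    return x


x1 = integrate(1.0, 0.0, 0.01, 100, forward=True)              # t: 0 -> 1, x ~ e^-1
x0 = integrate(math.exp(-1.0), 1.0, 0.01, 100, forward=False)   # t: 1 -> 0, x ~ 1
```

Running the same dynamics in both directions recovers the starting point up to the O(h^2) Heun error, which is a quick check that the drift sign and the t ± h evaluation point were flipped consistently.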