WN1695173791 closed this issue 2 years ago.
No, but it's not that different if you use Stratonovich.
The main difference is that you only draw a new z when a step is accepted. And if it's a forward-time SDE, the -h f(.) becomes +h f(.).
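Here is a minimal scalar sketch of that noise rule in plain Python (no torch). It covers the reverse-time case, with the -h f(.) sign. It assumes an Euler-Maruyama / Improved Euler (Heun) pair for the local error estimate and a simple `h * err**(-1/2)` step-size controller. The names `adaptive_heun_reverse`, `tol` and `safety`, and the error scaling, are illustrative assumptions. This is not the repo's code or the exact controller from the paper's algorithm.

```python
import math
import random


def adaptive_heun_reverse(x, drift, diffusion, t_start=1.0, t_end=1e-3,
                          h_init=0.01, tol=1e-2, safety=0.9, seed=0):
    """Integrate a scalar reverse-time SDE from t_start down to t_end.

    drift(x, t) and diffusion(x, t) are hypothetical callables standing in
    for my_rsde. Each step uses x - h*drift + diffusion*sqrt(h)*z.
    """
    rng = random.Random(seed)
    t, h = t_start, h_init
    z = rng.gauss(0.0, 1.0)  # initial draw, made once before the loop
    accepted = rejected = 0
    while t - t_end > 1e-12:
        h = min(h, t - t_end)  # never step past t_end
        k1 = -h * drift(x, t) + diffusion(x, t) * math.sqrt(h) * z
        x1, t1 = x + k1, t - h
        k2 = -h * drift(x1, t1) + diffusion(x1, t1) * math.sqrt(h) * z
        x_em = x + k1                 # Euler-Maruyama (lower order)
        x_heun = x + 0.5 * (k1 + k2)  # Improved Euler / Heun (higher order)
        # Error scaled by tol * max(1, |x|): absolute below 1, relative above
        scale = tol * max(1.0, abs(x_em), abs(x_heun))
        err = abs(x_heun - x_em) / scale
        if err <= 1.0:
            x, t = x_heun, t1
            accepted += 1
            z = rng.gauss(0.0, 1.0)  # fresh noise only after an accepted step
        else:
            rejected += 1  # keep x, t and z; retry with a smaller h
        # Shrink h after a rejection, grow it (at most 2x) after an easy accept
        h = h * min(2.0, safety * max(err, 1e-10) ** -0.5)
    return x, accepted, rejected
```

On a rejection only `h` changes, so the retried step reuses the same `z`. A new `z` is drawn only once a step has been accepted.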
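So, in the torch version at https://github.com/AlexiaJM/score_sde_fast_sampling/blob/main/sde_sampling_torch.py, you remove
```python
z = torch.randn_like(x)
```
from inside the sampling loop (z still needs one initial draw before the loop starts), and add the following after `accept` has been computed:
```python
z_new = torch.randn_like(x)
z = torch.where(accept, z_new, z)
```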
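In the batched torch code, `torch.where(accept, z_new, z)` picks values elementwise. Samples whose step was accepted get fresh noise, and rejected samples keep their old `z`. In the torch code, `accept` must broadcast against `x`, and it may need reshaping depending on its shape there. The sketch below mimics that selection with plain Python lists. `refresh_noise` is a hypothetical helper, and `accept` is assumed to hold one boolean per sample.

```python
import random


def refresh_noise(z, accept, rng):
    """Elementwise analogue of torch.where(accept, z_new, z).

    Accepted samples get a fresh N(0, 1) draw; rejected samples keep their
    previous z, so their retried step reuses the same noise.
    """
    # Draw z_new for every sample (like torch.randn_like(x)), then select
    z_new = [rng.gauss(0.0, 1.0) for _ in z]
    return [zn if a else zo for zn, zo, a in zip(z_new, z, accept)]


rng = random.Random(0)
z = [rng.gauss(0.0, 1.0) for _ in range(4)]  # initial draw before the loop
accept = [True, False, True, False]           # per-sample accept/reject mask
z_next = refresh_noise(z, accept, rng)
```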
If the SDE runs forward in time, you just replace
```python
K1_mean = -h_ * drift
K2_mean = -h_ * drift_Heun
drift_Heun, diffusion_Heun = my_rsde(x + K1, t - h)
```
with
```python
K1_mean = h_ * drift
K2_mean = h_ * drift_Heun
drift_Heun, diffusion_Heun = my_rsde(x + K1, t + h)
```
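The reverse-time and forward-time lines differ only in the sign in front of h and in the time at which the second drift is evaluated. A single direction parameter can therefore cover both cases. The sketch below is a scalar version in plain Python. `heun_increments`, `sde` and `direction` are hypothetical names. It also assumes K1 and K2 are built as the mean term plus `diffusion * sqrt(h) * z`.

```python
import math


def heun_increments(x, t, h, sde, z, direction):
    """Return (K1, K2) for one Heun step of a scalar SDE.

    direction = -1 matches the reverse-time lines (-h * drift, second
    evaluation at t - h); direction = +1 matches the forward-time lines
    (+h * drift, second evaluation at t + h). sde(x, t) -> (drift, diffusion)
    is a hypothetical stand-in for my_rsde.
    """
    drift, diffusion = sde(x, t)
    k1 = direction * h * drift + diffusion * math.sqrt(h) * z
    drift_heun, diffusion_heun = sde(x + k1, t + direction * h)
    k2 = direction * h * drift_heun + diffusion_heun * math.sqrt(h) * z
    return k1, k2
```

Folding the sign into one argument lets a reverse-time sampler and a forward-time sampler share one code path. The accept/reject and noise logic stay unchanged.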
Is the procedure of Algorithm 2 in the paper implemented in the code?