AlexiaJM / score_sde_fast_sampling

Repository for the "Gotta Go Fast When Generating Data with Score-Based Models" paper

Unconditional sampling #6

Open zqOuO opened 1 year ago

zqOuO commented 1 year ago

Hi Alexia! Thanks for sharing your code. I am trying to use your code for unconditional score-net sampling, but I found that all the provided checkpoints were trained conditionally, so I tried setting grad = score_fn(x, t) / sigmas[timestep]. It doesn't work, and I tried many SNR settings. Also, loading the state dict raises an error if I simply change model.conditional to False and run the code. How can I run your code with an unconditional score network? Should I set model.conditional to False and train another one? Thanks.

AlexiaJM commented 1 year ago

Hi @zqOuO,

Most people these days use conditional networks, so there is no unconditional checkpoint.

You would indeed need to retrain using model.conditional=False.

From looking at the code, it seems that only ncsnpp.py correctly removes the conditioning when conditional=False (see https://github.com/AlexiaJM/score_sde_fast_sampling/blob/5da8f3fe103ee5ac3c3a336f16cc06c9541f0ed9/models/ncsnpp.py#L87). I'm not sure why this is the case, but ncsnpp is the best network, so hopefully it will be fine for your use case. Otherwise you will need to modify https://github.com/AlexiaJM/score_sde_fast_sampling/blob/5da8f3fe103ee5ac3c3a336f16cc06c9541f0ed9/models/ddpm.py#L64 to allow for an unconditional DDPM.
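
If it helps, here is a minimal sketch of the config override; the exact config module path is a guess on my part, but any of the ml_collections-style configs in the repo should work the same way:

```python
# Sketch: retrain with conditioning disabled.
# NOTE: the config module below is an assumption; substitute whichever config you actually use.
from configs.ve import cifar10_ncsnpp_continuous as configs

config = configs.get_config()
config.model.conditional = False  # ncsnpp.py drops the conditioning branch when this is False
# ...then launch training with this config as usual.
```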

Alexia

zqOuO commented 1 year ago

Thank you Alexia! I think I should train another score network too. Besides, in your previous work at https://github.com/AlexiaJM/AdversarialConsistentScoreMatching/blob/9575592d3255a4c492728c794fa526dc242e70bc/models/__init__.py#L51, there is an additional modification to the grad that sets grad = grad_n / sigma, and you name the targets 'Gaussian' and 'dae'. What is the meaning of these two targets? The 'Gaussian' one looks very close to an unconditional score network, which uses s(x) = s(x, t) / sigma_t.

AlexiaJM commented 1 year ago

Hi @zqOuO,

This is because you can set the network to estimate various quantities, as you can always recover the score function from them. The simplest parametrization is for s(x) to estimate the score (-z/std). You can also estimate the noise (z), or the real data before noise (x0).

Since x(t) = mu(t)*x0 + std(t)*z (where mu(t) and std(t) depend on the forward process used), the score is -z/std(t) = (mu(t)*x0 - x(t))/std(t)^2, so you can recover the score from z, from x0, or from an estimated score directly. In practice, people estimate the score or z directly, because both tend to work better than estimating x0.
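
Concretely, here is a small sketch of the three conversions (my own illustration, not code from the repo):

```python
# Forward process: x_t = mu(t) * x0 + std(t) * z, with z ~ N(0, I).
# The score of the perturbation kernel is (mu_t * x0 - x_t) / std_t**2 = -z / std_t.

def score_from_z(z_hat, std_t):
    """Network predicts the noise z: score = -z / std(t)."""
    return -z_hat / std_t

def score_from_x0(x0_hat, x_t, mu_t, std_t):
    """Network predicts the clean data x0: score = (mu(t)*x0 - x_t) / std(t)^2."""
    return (mu_t * x0_hat - x_t) / std_t ** 2

def score_from_score(s_hat):
    """Network predicts the score directly; nothing to convert."""
    return s_hat
```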

https://github.com/AlexiaJM/AdversarialConsistentScoreMatching/blob/9575592d3255a4c492728c794fa526dc242e70bc/losses/dsm.py#L21

zqOuO commented 1 year ago

Thank you for your reply! I think I got the point. Am I right to understand it like this: if we set target = 'Gaussian', s(x) learns -z, which is (mu(t)*x0 - xt)/std(t), while the score is (mu(t)*x0 - xt)/std(t)^2, so we recover the score as s(x)/sigma, like the code grad = grad_n / sigma. Since that is your default target, all the provided checkpoints with target = 'Gaussian' recover the score via grad = grad_n / sigma. Otherwise, if I use another score network like NCSN, where s(x) is trained to estimate (mu(t)*x0 - xt)/std(t)^2 directly, I don't need to set grad = grad_n / sigma; grad = grad_n is enough to generate new images.
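
In code, my understanding would be something like this (a hypothetical wrapper, just to check I have it right):

```python
def get_score(net, x, t, sigma_t, target="gaussian"):
    """Recover the score from the network output, depending on the training target."""
    grad_n = net(x, t)
    if target == "gaussian":
        # net estimates -z, so the score is grad_n / sigma_t
        return grad_n / sigma_t
    # net (e.g. NCSN) already estimates the score directly
    return grad_n
```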

AlexiaJM commented 1 year ago

Hi @zqOuO,

Yes exactly! You got it right.

Alexia

zqOuO commented 1 year ago

Thank you!