yang-song / score_sde_pytorch

PyTorch implementation for Score-Based Generative Modeling through Stochastic Differential Equations (ICLR 2021, Oral)
https://arxiv.org/abs/2011.13456
Apache License 2.0
1.58k stars, 295 forks

Likelihood estimation #59

Open mtailanian opened 1 month ago

mtailanian commented 1 month ago

Hello Dr. @yang-song , thank you very much for this work.

I'm trying to estimate the likelihood for a given sample. I understand I have to do something very similar to what you do for computing the bpd, here.

As I understand it, following eq. (39) in the paper, to obtain $\log(p_0(x_0))$ I have to "correct" $\log(p_T(x_T))$ by adding the integral of the divergence of the drift function: $\int_0^T \nabla \cdot \overset{-}{f}_\theta (x, t) \, dt$

To obtain a more accurate likelihood estimate with the Skilling-Hutchinson trace estimator, I take the $x$ and $t$ returned by the ODE solver, like this:

t = solution.t
x = solution.y[:-shape[0], :]

I then plug these values into the expression $\epsilon^T \nabla \overset{-}{f}_\theta (x, t) \epsilon$, sample many epsilons, and average the results to obtain an estimate of div_f.
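For reference, the averaging step described here can be sketched as follows. This is a minimal illustration, not the repository's code: `hutchinson_divergence` and `toy_drift` are hypothetical names, the probes are assumed to be Rademacher vectors, and the drift is assumed to act independently on each batch element.

```python
import torch

def hutchinson_divergence(fn, x, t, n_probes=16):
    """Estimate div_x fn(x, t) per batch element by averaging eps^T J eps."""
    x = x.detach().requires_grad_(True)
    estimates = []
    for _ in range(n_probes):
        # Rademacher probe: entries are +1 or -1 with equal probability
        eps = torch.randint_like(x, low=0, high=2) * 2.0 - 1.0
        with torch.enable_grad():
            fx = fn(x, t)
            # Gradient of sum(fx * eps) w.r.t. x is the vector-Jacobian product J^T eps
            vjp = torch.autograd.grad((fx * eps).sum(), x)[0]
        # eps^T J eps, summed over all non-batch dimensions
        estimates.append((vjp * eps).flatten(1).sum(dim=1))
    return torch.stack(estimates).mean(dim=0)

# Toy drift (hypothetical): f(x, t) = t * scale * x, so div f = t * sum(scale)
scale = torch.tensor([1.0, -2.0, 0.5])
toy_drift = lambda x, t: x * scale * t

x_batch = torch.randn(4, 3)
div_estimate = hutchinson_divergence(toy_drift, x_batch, 2.0, n_probes=4)
```

With a diagonal Jacobian, as in this toy drift, Rademacher probes give the exact divergence. For a real score network the estimate is noisy, and more probes reduce the variance.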

Finally, I just compute the integral in time, like this:

div_f_integral = torch.trapz(div_f, t, dim=-1)
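A quick way to sanity-check this integration step is a toy problem where everything is known in closed form. The sketch below assumes a linear drift $f(x, t) = a x$ (so the divergence is the constant $a D$) and a standard normal $p_T$. All names and values are illustrative and not taken from the repository.

```python
import math
import torch

a, T, D = -0.7, 1.0, 3            # toy linear drift f(x, t) = a * x
t = torch.linspace(0.0, T, 50)    # time grid, standing in for solution.t
x0 = torch.tensor([[0.3, -1.2, 0.8]])

# Exact ODE trajectory x(t) = x0 * exp(a t), shape (batch, D, n_times)
x = x0[:, :, None] * torch.exp(a * t)[None, None, :]

# Divergence along the trajectory (constant a * D here), shape (batch, n_times)
div_f = torch.full((x0.shape[0], t.numel()), a * D)
div_f_integral = torch.trapz(div_f, t, dim=-1)

def std_normal_logp(z):
    # Log-density of N(0, I), summed over the feature dimension
    return -0.5 * (z ** 2).sum(dim=1) - 0.5 * z.shape[1] * math.log(2 * math.pi)

# log p_0(x_0) = log p_T(x_T) + integral of the divergence
logp0 = std_normal_logp(x[:, :, -1]) + div_f_integral

# Closed form: if x_T ~ N(0, I), then x_0 = exp(-a T) x_T ~ N(0, exp(-2 a T) I)
sigma0 = math.exp(-a * T)
logp0_exact = torch.distributions.Normal(0.0, sigma0).log_prob(x0).sum(dim=1)
```

Here `logp0` and `logp0_exact` agree up to floating-point error, which confirms the sign and the direction of integration for this toy case.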

What do you think, is this correct?

The problem is that the result is not what I expect. When I compute $\log(p_T(x_T)) + \int_0^T \nabla \cdot \overset{-}{f}_\theta (x, t) \, dt$, I should obtain $\log(p_0(x_0))$, but I get values that look like nonsense, such as log-probs greater than 0...
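One thing worth checking is that $\log(p_0(x_0))$ is a log-density rather than a log-probability, so a positive value is not necessarily wrong for continuous data. The sketch below illustrates this and shows one possible conversion to bits/dim. `logp_to_bpd` is a hypothetical helper, and the 8-bit offset assumes 8-bit images rescaled to $[0, 1]$, which may not match a given setup.

```python
import math

def logp_to_bpd(logp_nats, n_dims, bits_offset=8.0):
    """Convert a log-density in nats (data assumed scaled to [0, 1]) to bits/dim."""
    return -logp_nats / (n_dims * math.log(2.0)) + bits_offset

# A 1-D Gaussian with std 0.1 has density > 1 at its mean,
# so its log-density there is positive.
sigma = 0.1
logp_at_mean = -math.log(sigma) - 0.5 * math.log(2.0 * math.pi)

# Under the scaling assumption above, a 3 bits/dim model on a 3x32x32 image
# corresponds to a large positive log-density.
n_dims = 3 * 32 * 32
logp_for_3bpd = (8.0 - 3.0) * n_dims * math.log(2.0)
bpd = logp_to_bpd(logp_for_3bpd, n_dims)
```

Under these assumptions, any model achieving fewer than 8 bits/dim assigns a positive log-density to the rescaled image.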

In summary, what can I do to obtain a more accurate likelihood estimation?

Many thanks in advance!

Any help or hint is very much appreciated.

daihuiao commented 1 month ago

Hello,

I am also studying the likelihood computation as described in the paper. However, I noticed that in the original code, the prior log probability (prior_logp) is calculated using this line. This formula calculates prior_logp based on the final output of the ODE solver (see line 101: zp = solution.y[:, -1]), which means the prior log probability is computed by inputting the final denoised image into a Gaussian distribution.

However, according to the formula in the paper, the final likelihood should be the likelihood of the Gaussian distribution plus an integral term. Am I misunderstanding something?
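To make the question concrete, here is a toy sketch of how a Gaussian term and an integral term can both come out of a single ODE solve, by integrating the accumulated divergence alongside the state. It assumes a linear drift $f(x, t) = a x$ with a standard normal prior at time $T$. The names are illustrative and this is not the repository's implementation.

```python
import numpy as np
from scipy import integrate

a, T, D = -0.7, 1.0, 3   # toy linear drift f(x, t) = a * x, so div f = a * D

def ode_func(t, state):
    x = state[:D]
    dx = a * x                   # drift of the toy ODE
    dlogp = np.array([a * D])    # exact divergence; a trace estimator would go here
    return np.concatenate([dx, dlogp])

x0 = np.array([0.3, -1.2, 0.8])
init = np.concatenate([x0, np.zeros(1)])   # [x(0), accumulated divergence = 0]
solution = integrate.solve_ivp(ode_func, (0.0, T), init,
                               rtol=1e-8, atol=1e-8, method="RK45")

final = solution.y[:, -1]                  # state at the final time T
z, delta_logp = final[:D], final[D]        # z = x(T), delta_logp = integral term

# Gaussian prior evaluated at the final state, plus the integral term
prior_logp = -0.5 * D * np.log(2 * np.pi) - 0.5 * np.sum(z ** 2)
logp0 = prior_logp + delta_logp

# Closed-form check: x(0) ~ N(0, exp(-2 a T) I) when x(T) ~ N(0, I)
sigma0 = np.exp(-a * T)
logp0_exact = (-0.5 * D * np.log(2 * np.pi) - D * np.log(sigma0)
               - 0.5 * np.sum(x0 ** 2) / sigma0 ** 2)
```

In this sketch the final column of `solution.y` holds the state at time $T$, so the Gaussian is evaluated at the end of the trajectory and the integral term is added separately.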

Any response would be greatly appreciated.