pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Problems Computing Log Likelihoods in a Scan-Based VAR Model #1820

Closed: richardmkit closed this issue 1 week ago

richardmkit commented 2 weeks ago

I wrote a vector autoregression (VAR) model as shown below. MLE works, but when I use the optimal parameters to compute the log likelihoods needed for the Hessian matrix, numpyro.infer.log_likelihood does not work for this scan-based model. Details below:

import numpy as np
import pandas as pd
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import numpyro
import numpyro.distributions as dists
import numpyro.distributions.constraints as constraints

from jax import random,grad,hessian
from numpyro.contrib.control_flow import scan
from numpyro.handlers import trace,seed
from numpyro.infer import MCMC,NUTS,SVI,Trace_ELBO,log_likelihood
from numpyro.optim import Adam

# One VAR(1) step: Y_t ~ MultivariateNormal(mu + Phi @ Y_{t-1}, Sigma), observed at Y_curr
def transition(carry,Y_curr,Phi,mu,Sigma):
    Y_prev=carry
    shift=jnp.dot(Phi,Y_prev).flatten()
    mu=mu+shift
    Y=numpyro.sample('Y',dists.MultivariateNormal(mu,Sigma),obs=Y_curr.flatten())
    return Y_curr,Y

def model(Y=None):
    Phi=numpyro.param('Phi',random.normal(random.PRNGKey(0),(2,2)))

    mu=numpyro.param('mu',random.normal(random.PRNGKey(1),(2,)))

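    # Symmetric 2x2 initial value with a non-negative diagonal. Note: this construction is
    # not guaranteed to be positive definite, and the param below has no constraint.
    Sigma=random.normal(random.PRNGKey(2),(2,2))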
    Sigma=Sigma.at[0,0].set(jnp.abs(Sigma[0,0]))
    Sigma=Sigma.at[1,1].set(jnp.abs(Sigma[1,1]))
    lower=jnp.tril(Sigma)
    diag=jnp.diag(jnp.diagonal(Sigma))
    Sigma=lower+lower.T-diag
    Sigma=numpyro.param('Sigma',Sigma)

    # Condition on the first observation; scan scores Y[1], ..., Y[-1] one transition at a time
    Y_init=Y[0]

    scan(lambda carry,Y_curr:transition(carry,Y_curr,Phi,mu,Sigma),Y_init,Y[1:],len(Y[1:]))

def guide(Y=None):
    pass

# Y: simulated data of shape (10001, 2, 1); the code that generated it is not shown
numpyro.render_model(model,model_args=(Y,),render_params=True)

optimizer=Adam(0.1)
svi=SVI(model,guide,optimizer,loss=Trace_ELBO())
results=svi.run(random.PRNGKey(0),10000,Y)
params=results.params

log_likelihood(model,params,Y)['Y']
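The issue does not include the code that simulated Y. For anyone reproducing it, here is a minimal sketch of a VAR(1) simulator under the assumed dynamics Y_t = mu + Phi Y_{t-1} + eps_t with eps_t ~ N(0, Sigma), returning the (T + 1, 2, 1) layout the model expects. simulate_var1 and the parameter values in the usage lines are hypothetical, not the author's.

```python
import jax
import jax.numpy as jnp
from jax import random


def simulate_var1(key, Phi, mu, Sigma, T):
    """Hypothetical VAR(1) simulator: Y_t = mu + Phi @ Y_{t-1} + eps_t, eps_t ~ N(0, Sigma).

    Returns shape (T + 1, k, 1), the layout the model above indexes with Y[0] and Y[1:].
    """
    k = mu.shape[0]
    L = jnp.linalg.cholesky(Sigma)                 # Sigma = L @ L.T
    init_key, noise_key = random.split(key)
    y0 = random.normal(init_key, (k,))             # arbitrary starting point
    eps = random.normal(noise_key, (T, k)) @ L.T   # each row ~ N(0, Sigma)

    def step(y_prev, e):
        y = mu + Phi @ y_prev + e
        return y, y

    _, ys = jax.lax.scan(step, y0, eps)
    return jnp.concatenate([y0[None], ys])[..., None]


# Made-up example values; Phi has spectral radius < 1, so the process is stationary.
Phi_true = jnp.array([[0.5, 0.1], [-0.2, 0.3]])
mu_true = jnp.array([1.0, -0.5])
Sigma_true = jnp.array([[1.0, 0.3], [0.3, 0.5]])
Y = simulate_var1(random.PRNGKey(42), Phi_true, mu_true, Sigma_true, T=10000)
```

Drawing all innovations up front and running jax.lax.scan over them keeps the simulator jittable and mirrors the scan structure of the model itself.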

The model works: I can recover the parameters (the data are simulated, so I know the estimates are correct). But when I try to compute the log likelihoods, it returns:

Array([[nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan]], dtype=float32)

The result has shape (2, 10000).

All the parameters are correct, and Sigma is positive definite.

Y has shape (10001, 2, 1). Excluding the initial point, if NumPyro's built-in log_likelihood understood the model, it should return a vector of shape (10000,) rather than (2, 10000): at each time point, Y_t follows a 2D MultivariateNormal, so the log density of the observed Y_t is a scalar. Is combining scan with log_likelihood supported at all?
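The one-scalar-per-time-step expectation can be checked outside NumPyro. Here is a minimal sketch, assuming the same VAR(1) form as transition (Y_t ~ MultivariateNormal(mu + Phi Y_{t-1}, Sigma)). var1_step_loglik is a hypothetical helper built on jax.scipy.stats.multivariate_normal, and the random Y is placeholder data with the issue's (10001, 2, 1) layout. It returns one log density per transition, i.e. shape (10000,), which can serve as a reference for log_likelihood.

```python
import jax.numpy as jnp
from jax import random
from jax.scipy.stats import multivariate_normal


def var1_step_loglik(Y, Phi, mu, Sigma):
    """Per-step log density log N(Y_t | mu + Phi @ Y_{t-1}, Sigma), for t = 1..T.

    Y has shape (T + 1, k, 1) as in the issue; the result has shape (T,).
    """
    y = Y[..., 0]                  # (T + 1, k): drop the trailing singleton axis
    prev, curr = y[:-1], y[1:]     # Y_{t-1} and Y_t, each (T, k)
    means = mu + prev @ Phi.T      # row t is mu + Phi @ Y_{t-1}
    return multivariate_normal.logpdf(curr, means, Sigma)


T = 10000
Y = random.normal(random.PRNGKey(0), (T + 1, 2, 1))  # placeholder data, not the issue's
ll = var1_step_loglik(Y, 0.5 * jnp.eye(2), jnp.zeros(2), jnp.eye(2))
```

Summing ll gives the total log likelihood of Y[1:] given Y[0], which is the quantity the SVI run maximizes when the guide is empty.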

I'd appreciate a reply and, ideally, a solution that just uses the built-in log_likelihood function.

Best

fehiepsi commented 2 weeks ago

I'm not sure what the issue is, but I guess you want log_likelihood(..., batch_ndims=0). In practice, posterior_samples are different parameters, so it's more "pyroic" to use

from numpyro import handlers
log_likelihood(handlers.substitute(model, params), {}, Y, batch_ndims=0)
fehiepsi commented 1 week ago

Closed. Please use our forum https://forum.pyro.ai/ for questions.