pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Allow for calculation of log_likelihood by chain? #645

Closed vanAmsterdam closed 4 years ago

vanAmsterdam commented 4 years ago

When checking inference, inspecting log-likelihoods per chain could provide some insight, e.g. whether different chains have settled in different posterior modes. Currently, passing samples obtained with group_by_chain=True to log_likelihood leads to broadcasting errors. Maybe we can add something like this:

import jax, numpyro
from jax import numpy as np, random, vmap
from numpyro import distributions as dist, sample
from numpyro.infer.mcmc import MCMC, NUTS
from numpyro.infer.util import log_likelihood
from functools import partial
# Expose 2 CPU devices so both chains can run in parallel;
# this must be called before JAX runs any computation.
numpyro.set_host_device_count(2)

rkeys = random.split(random.PRNGKey(0), 2)
X     = random.normal(rkeys[0], shape=(100,))
Y     = 0.5 * X + random.normal(rkeys[1], shape=(100,))

def m(X, Y=None):
    a = sample('a', dist.Normal())
    b = sample('b', dist.Normal())
    with numpyro.plate('obs', X.shape[0]):
        mu_hat = a + b * X
        sample('obs_Y', dist.Normal(mu_hat, 1), obs=Y)

mcmckey = random.PRNGKey(1)

mcmc = MCMC(NUTS(m), num_warmup=100, num_samples=250, num_chains=2)
mcmc.run(mcmckey, X, Y)

smps  = mcmc.get_samples(group_by_chain=True)  # each site: (num_chains, num_samples)
# lls  = log_likelihood(m, smps, X, Y) # -> throws an error

def log_likelihood2(model, posterior_samples, grouped_by_chain=False, *args, **kwargs):
    if grouped_by_chain:
        # vmap over the leading (chain) axis of every array in posterior_samples;
        # model and keyword arguments are bound by partial and are not mapped.
        # NOTE: positional *args are not forwarded correctly here.
        llfun = lambda samples, model, *args, **kwargs: log_likelihood(model, samples, *args, **kwargs)
        return vmap(partial(llfun, model=model, *args, **kwargs))(posterior_samples)
    else:
        return log_likelihood(model, posterior_samples, *args, **kwargs)

lls = log_likelihood2(m, smps, grouped_by_chain=True, X=X, Y=Y)

print(lls['obs_Y'].shape) # -> (2, 250, 100)

The only thing I am still struggling with is keeping the ability to pass positional *args through to the model call...

fehiepsi commented 4 years ago

Thanks, @vanAmsterdam! I guess the easiest workaround is to call smps = mcmc.get_samples(group_by_chain=False), compute the likelihoods, then reshape the output to num_chains x num_samples x .... This would be easier if we added an argument like batch_ndim, so that users just set batch_ndim=1 for the current behavior, batch_ndim=2 for per-chain likelihoods, and batch_ndim=0 for the likelihood of a single sample.

@neerajprad It is a bit inconvenient for users to deal with the reshaping and to carry around both chained and non-chained samples, because some utilities work with chained samples (e.g. diagnostics) while others work with the non-chained version (log_likelihood, predictive). WDYT?

fehiepsi commented 4 years ago

@vanAmsterdam You can now compute log_likelihood per chain by specifying batch_ndims=2.