CDCgov / multisignal-epi-inference

Python package for statistical inference and forecast of epi models using multiple signals
https://cdcgov.github.io/multisignal-epi-inference/

Planning for handling/tracking of deterministic nodes #168

Open damonbayer opened 3 weeks ago

damonbayer commented 3 weeks ago

So far, when we have wanted to track a "generated quantity" (a quantity that is not sampled directly but depends on quantities that are), we have littered numpyro.deterministic calls throughout the model code. This can be confusing because the name supplied to numpyro.deterministic may not correspond to the variable with that name in other parts of the code. For example, there may be a numpyro site called Rt, but later in the model the variable Rt is padded. The padded Rt will not be present in the posterior samples, even though it would be easier to use in post-processing than the unpadded Rt.

Perhaps @dylanhmorris can tell us whether there is a more "correct" way to achieve this, but I propose adding a generated_quantities flag to the model arguments, collecting all of the numpyro.deterministic calls at the end of the model, and only making those calls if generated_quantities=True.

Ex:

import numpy as np
import jax
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive

# Define the model
def linear_regression(X, y=None, gq=False):
    # gq=True additionally records generated quantities as deterministic sites
    # Priors for unknown parameters
    alpha = numpyro.sample('alpha', dist.Normal(0, 10))
    beta = numpyro.sample('beta', dist.Normal(0, 10))
    sigma = numpyro.sample('sigma', dist.Exponential(1))

    # Linear model
    mean = alpha + beta * X

    # Likelihood (sampling distribution) of observations
    with numpyro.plate('data', X.shape[0]):
        obs = numpyro.sample('obs', dist.Normal(mean, sigma), obs=y)
    # Record 'mean' as a site only when generated quantities are requested
    if gq:
        numpyro.deterministic('mean', mean)
    return mean

# Generate synthetic data
np.random.seed(0)
N = 100
X = np.random.randn(N)
y = 1.0 + 2.0 * X + np.random.normal(0, 1.0, size=N)

# Fit the model with NUTS (gq defaults to False, so 'mean' is not stored)
nuts_kernel = NUTS(linear_regression)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
mcmc.run(jax.random.PRNGKey(0), X, y)
posterior_samples = mcmc.get_samples()
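print(posterior_samples.keys())  # only the sampled parameters: alpha, beta, sigma

# Posterior predictive sampling: re-run the model with gq=True so 'mean' is recorded
predictive = Predictive(linear_regression, posterior_samples)
predictions = predictive(jax.random.PRNGKey(1), X, gq=True)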

This keeps the MCMC samples lean and useful for diagnostics, since generated quantities are not stored during sampling, while still letting us produce generated quantities in a centralized location.

damonbayer commented 3 weeks ago

After discussion with @dylanhmorris, we think it would be better for now to continue tracking deterministic quantities as we do now (or make them DeterministicVariables), but the models should be revised so that the tracked sites are the quantities actually used downstream (e.g. the padded Rt rather than the unpadded one).