pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Error when using obs_mask and predictive with different input shape #1847

Closed felipeangelimvieira closed 2 months ago

felipeangelimvieira commented 2 months ago

First of all, thank you for this amazing library.

I've found that Predictive raises an unexpected error when the model uses obs_mask. It happens when the data passed to Predictive has a different shape from the data used during SVI inference. It may be related to pyro-ppl/numpyro#1772.

Here is code to reproduce it:

import jax
import numpy as np
import numpyro
from numpyro.infer import SVI, Predictive, Trace_ELBO
from numpyro.infer.autoguide import AutoDelta
from numpyro.infer.initialization import init_to_mean
from numpyro.infer.svi import SVIRunResult

def model(y, x, obs_mask):
    # Priors for the intercept, slope, and noise scale.
    a = numpyro.sample('a', numpyro.distributions.Normal(0, 1))
    b = numpyro.sample('b', numpyro.distributions.Normal(0, 1))
    sigma = numpyro.sample('sigma', numpyro.distributions.HalfNormal(1))

    mu = a + b * x
    # Entries where obs_mask is False are treated as unobserved.
    numpyro.sample('y', numpyro.distributions.Normal(mu, sigma), obs=y, obs_mask=obs_mask)

x = np.random.normal(0, 1, 100)
y = 1 + 2 * x + np.random.normal(0, 1, 100)
obs_mask = np.ones_like(y, dtype=bool)
obs_mask[-20:] = False

guide_ = AutoDelta(model, init_loc_fn=init_to_mean())
svi_ = SVI(model, guide_, numpyro.optim.Adam(step_size=1e-4), loss=Trace_ELBO())
run_results_: SVIRunResult = svi_.run(
    rng_key=jax.random.PRNGKey(24), num_steps=1000, y=y, x=x, obs_mask=obs_mask
)

posterior_samples_ = guide_.sample_posterior(
    jax.random.PRNGKey(24), params=run_results_.params, y=y, x=x, obs_mask=obs_mask
)

predictive = numpyro.infer.Predictive(
    model,
    params=run_results_.params,
    guide=guide_,
    num_samples=1000,
)

# Predict on the last 50 points only; SVI above was run on all 100.
start_idx = 50
predictive_samples = predictive(
    rng_key=jax.random.PRNGKey(24),
    y=y[-start_idx:],
    x=x[-start_idx:],
    obs_mask=obs_mask[-start_idx:],
)

fehiepsi commented 2 months ago

This is expected. obs_mask introduces a local latent variable named foo_unobserved (here, y_unobserved) whose distribution will be inferred by SVI. Because the autoguide fits that variable at the shape of the training data, it cannot be reused for inputs of a different shape. Assume that you have a model $x_n \to z_n \to y_n$ and you use an autoguide to approximate $p(z_n | x_n, y_n)$. Such information does not allow you to predict $p(z'_n | x'_n)$ for new inputs. Instead, you might want to construct a custom guide for $q(z | x)$.

felipeangelimvieira commented 2 months ago

Oh I see, thank you for the explanation! I think I could use the mask handler directly.