pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0
2.18k stars 238 forks source link

Predictive not working for SVI #1676

Closed dbobrovskiy closed 11 months ago

dbobrovskiy commented 11 months ago

numpyro.infer.Predictive returns the same values independent of the model when running it on SVI results.

import numpy as np
import jax
import jax.numpy as jnp
import optax
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO, MCMC, NUTS, Predictive
from numpyro.infer.autoguide import AutoNormal

class Regression():
    def __init__(self, x, y):
        self.x = jnp.array(x)
        self.y = jnp.array(y)
        self.n = x.shape[0]

    def model(self):
        w = numpyro.sample("w", dist.Normal(0, 1))
        sigma = numpyro.sample("sigma", dist.HalfCauchy(5))
        with numpyro.plate("data", self.n):
            y = numpyro.sample("y", dist.Normal(w * self.x, sigma), obs=self.y)

    def fit_svi(self, n_iter=100):
        self.guide = AutoNormal(self.model)
        optimizer = optax.adam(0.01)
        svi = SVI(self.model, self.guide, optimizer, loss=Trace_ELBO())
        svi_state = svi.init(jax.random.PRNGKey(0))
        for i in range(n_iter):
            svi_state, loss = svi.update(svi_state)
        self.svi_result = svi.get_params(svi_state)

X = np.random.normal(size=(100, 1))
beta1, beta2 = np.random.normal(size=(1, 1)), np.random.normal(size=(1, 1))
Y1 = X @ beta1 + np.random.normal(size=(100, 1))
Y2 = X @ beta2 + np.random.normal(size=(100, 1))

beta1, beta2

(array([[-0.58427902]]), array([[-0.1831061]]))

model_svi_1 = Regression(X.ravel(), Y1.ravel())
model_svi_2 = Regression(X.ravel(), Y2.ravel())
model_svi_1.fit_svi()
model_svi_2.fit_svi()

model_svi_1.svi_result["w_auto_loc"], model_svi_2.svi_result["w_auto_loc"]

(Array(-0.5395397, dtype=float32), Array(-0.3524328, dtype=float32))

predict_svi_1 = Predictive(model=model_svi_1.model, guide=model_svi_1.guide, num_samples=100,
                     return_sites=["y", "w", "sigma"])(jax.random.PRNGKey(0))
predict_svi_2 = Predictive(model=model_svi_2.model, guide=model_svi_2.guide, num_samples=100,
                     return_sites=["y", "w", "sigma"])(jax.random.PRNGKey(0))

predict_svi_1["w"].mean(0).squeeze(), predict_svi_2["w"].mean(0).squeeze()

(Array(0.22423664, dtype=float32), Array(0.22423664, dtype=float32))

Running Predictive in a similar way on the MCMC results, replacing guide and num_samples with posterior_samples, works perfectly fine. I am using the latest Numpyro stable version, 0.13.12.

What could be the problem? Am I using SVI in Numpyro in a wrong way?

fehiepsi commented 11 months ago

You need to set obs to None. Pls see the example in the SVI docs.

dbobrovskiy commented 11 months ago

Maybe I didn't get it right, but this produces the same result:

class Regression():
    def __init__(self, x, y):
        self.x = jnp.array(x)
        self.y = jnp.array(y)
        self.n = x.shape[0]

    def model(self, y=None):
        w = numpyro.sample("w", dist.Normal(0, 1))
        sigma = numpyro.sample("sigma", dist.HalfCauchy(5))
        with numpyro.plate("data", self.n):
            y = numpyro.sample("y", dist.Normal(w * self.x, sigma), obs=y)

    def fit_svi(self, n_iter=100):
        self.guide = AutoNormal(self.model)
        optimizer = optax.adam(0.01)
        svi = SVI(self.model, self.guide, optimizer, loss=Trace_ELBO())
        svi_state = svi.init(jax.random.PRNGKey(0), self.y)
        for i in range(n_iter):
            svi_state, loss = svi.update(svi_state, self.y)
        self.svi_result = svi.get_params(svi_state)

--- same code in between ---

predict_svi_1 = Predictive(model=model_svi_1.model, guide=model_svi_1.guide, num_samples=100,
                     return_sites=["y", "w", "sigma"])(jax.random.PRNGKey(0), None)
predict_svi_2 = Predictive(model=model_svi_2.model, guide=model_svi_2.guide, num_samples=100,
                     return_sites=["y", "w", "sigma"])(jax.random.PRNGKey(0), None)

predict_svi_1["w"].mean(0).squeeze(), predict_svi_2["w"].mean(0).squeeze()

(Array(0.22423664, dtype=float32), Array(0.22423664, dtype=float32))

fehiepsi commented 11 months ago

Sorry, I misunderstood the question. For SVI you need to specify params in Predictive. Otherwise, init params will be used.

dbobrovskiy commented 11 months ago

Oh, I see, thank you so much!