Closed dbobrovskiy closed 11 months ago
Maybe I didn't get it right, but this produces the same result:
class Regression():
def __init__(self, x, y):
self.x = jnp.array(x)
self.y = jnp.array(y)
self.n = x.shape[0]
def model(self, y=None):
w = numpyro.sample("w", dist.Normal(0, 1))
sigma = numpyro.sample("sigma", dist.HalfCauchy(5))
with numpyro.plate("data", self.n):
y = numpyro.sample("y", dist.Normal(w * self.x, sigma), obs=y)
def fit_svi(self, n_iter=100):
self.guide = AutoNormal(self.model)
optimizer = optax.adam(0.01)
svi = SVI(self.model, self.guide, optimizer, loss=Trace_ELBO())
svi_state = svi.init(jax.random.PRNGKey(0), self.y)
for i in range(n_iter):
svi_state, loss = svi.update(svi_state, self.y)
self.svi_result = svi.get_params(svi_state)
--- same code in between ---
predict_svi_1 = Predictive(model=model_svi_1.model, guide=model_svi_1.guide, num_samples=100,
return_sites=["y", "w", "sigma"])(jax.random.PRNGKey(0), None)
predict_svi_2 = Predictive(model=model_svi_2.model, guide=model_svi_2.guide, num_samples=100,
return_sites=["y", "w", "sigma"])(jax.random.PRNGKey(0), None)
predict_svi_1["w"].mean(0).squeeze(), predict_svi_2["w"].mean(0).squeeze()
(Array(0.22423664, dtype=float32), Array(0.22423664, dtype=float32))
Sorry, I misunderstood the question. For SVI you need to specify params in Predictive. Otherwise, init params will be used.
Oh, I see, thank you so much!
numpyro.infer.Predictive
returns the same values independent of the model when running it on SVI results.Running
Predictive
in a similar way on the MCMC results, replacingguide
andnum_samples
withposterior_samples
, works perfectly fine. I am using the latest Numpyro stable version, 0.13.12.What could be the problem? Am I using SVI in Numpyro in a wrong way?