Closed AkiroSR closed 2 months ago
Sorry for the breakage! Could you try to use the dev branch of lightweight mmm? I will ping a dev there for a release if it works.
I think it's related to numpyro. The problem function is numpyro.deterministic. Everything else works. I'll have a look but I reckon it's related to the meridian release
Do you mean that `pip install --upgrade git+https://github.com/google/lightweight_mmm.git` does not resolve the issue?
@fehiepsi saw your fix on lightweight Change-Id: I7c0658b0a13506c319fd3e6e00cdf2791d64e26f.
I believe the long-term fix here is 2-fold:
If these are unfeasible for deeper reasons, then at least mention the pop trick here: https://num.pyro.ai/en/v0.2.0/utilities.html
As the current behavior is a bit counterintuitive.
I'm running into the same issue, here's a reproducible example:
```python
import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive
from jax import random

X = np.random.normal(0, 1, size=1000)
y = 5 + 1.2 * X + np.random.normal(size=1000)


def model(X, y=None):
    alpha = numpyro.sample("alpha", dist.Normal(0, 10))
    beta = numpyro.sample("beta", dist.Normal(0, 1))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    with numpyro.plate("data", len(X)):
        eta = numpyro.deterministic("eta", alpha + beta * X)
        obs = numpyro.sample("obs", dist.Normal(eta, sigma), obs=y)


# Run NUTS.
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
mcmc.run(random.PRNGKey(0), X=X, y=y)

# Make predictions where X is a different shape
posterior_samples = mcmc.get_samples()
# posterior_samples.pop("eta")  # this fixes the issue
pred_func = Predictive(model, posterior_samples=posterior_samples)
preds = pred_func(random.PRNGKey(1), X=X[:250], y=None)
```

```
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
[... skipping hidden 1 frame]
File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/jax/_src/util.py:290, in cache.
```
I get that inputting samples for a deterministic site would lead to the model expecting a certain shape, but it does seem a bit awkward that the typical workflow with predictions requires some extra work if deterministics are involved.
I wonder if something like this is possible? https://github.com/pyro-ppl/numpyro/blob/2f1bccdba2fc7b0a6ec235ca1bd5ce2417a0635c/numpyro/infer/mcmc.py#L714C61-L714C62
Hi @nikisix and @kylejcaron, really sorry for the breakage! I think a good action is to introduce `exclude_deterministic=True` to Predictive. This rolls the behavior back to the pre-0.14 release. I'm less worried that new users will want to use deterministic sites in Predictive. What do you think, @martinjankowiak?
something like that sounds reasonable. the change in behavior was probably a mistake...
@fehiepsi @martinjankowiak should `AutoGuide.sample_posterior()` be changed as well? It seems more difficult to fix since many `sample_posterior` functions are unique to auto guides.
For example, the following workflow has the same problem:
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
guide = AutoNormal(model)
svi = SVI(model, guide, optim=numpyro.optim.Adam(0.01), loss=Trace_ELBO())
svi_result = svi.run(random.PRNGKey(0), 10000, X=X, y=y)
params = guide.sample_posterior(random.PRNGKey(0), params=svi_result.params)
pred_func = Predictive(model, params=params, num_samples=100)
preds = pred_func(random.PRNGKey(1), X=X[:250], y=None)
The solution for this seems to be just to include the guide and use the SVI params instead, but I imagine some may be using the pattern above.
pred_func = Predictive(model, guide=guide, params=svi_result.params, num_samples=100)
preds = pred_func(random.PRNGKey(1),X[:n_preds])['eta']
I think this pattern could be used with an `exclude_deterministic` arg in AutoGuide's
@kylejcaron I think we can fix this in Predictive. The breakage happens because we allow substituting deterministic sites in the substitute handler. We can create a subclass of `substitute` that skips processing deterministic sites and use it in Predictive.
Got it, that makes sense to me. It seems like it'd involve just replacing the substitute call in this line and L987, but let me know if I'm missing anything.
I'm happy to make an attempt at this. Any name recommendations for the new effect handler?
The substitute logic is at this line. You can change
substitute(model, data)
to something like
substitute(model, substitute_fn=lambda msg: data[msg["name"]] if msg["name"] in data and msg["type"] != "deterministic" else None)
nice idea with the substitute_fn, just added a PR!
For some reason, after fitting the model, the numpyro.deterministic shape remains static; trying to predict with a different shape then throws a shape error.
Example in lightweight-mmm:
This throws a size error, see: https://github.com/google/lightweight_mmm/issues/309 and https://github.com/google/lightweight_mmm/issues/308