pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

numpyro.deterministic static on infer.Predictive #1772

Closed. AkiroSR closed this issue 2 months ago.

AkiroSR commented 3 months ago

For some reason, after fitting the model, the shape of the numpyro.deterministic site stays fixed: trying to predict with a different input shape throws a shape error.

Example in lightweight-mmm:

```python
# extra_features.shape = (10, 3)  / trying to predict 10 new time periods

extra_features_effect = numpyro.deterministic(
    name="extra_features_effect",
    value=jnp.einsum(
        extra_features_einsum, extra_features, coef_extra_features
    ),
)

# extra_features_effect.shape = (30, 3)  / the output keeps the shape used at fit time (30 periods)
```

This throws a size error, see: https://github.com/google/lightweight_mmm/issues/309 and https://github.com/google/lightweight_mmm/issues/308

fehiepsi commented 3 months ago

Sorry for the breakage! Could you try to use the dev branch of lightweight mmm? I will ping a dev there for a release if it works.

AkiroSR commented 3 months ago

I think it's related to numpyro. The problem function is numpyro.deterministic; everything else works. I'll have a look, but I reckon it's related to the Meridian release.

fehiepsi commented 3 months ago

Do you mean that pip install --upgrade git+https://github.com/google/lightweight_mmm.git does not resolve the issue?

nikisix commented 2 months ago

@fehiepsi saw your fix on lightweight_mmm, Change-Id: I7c0658b0a13506c319fd3e6e00cdf2791d64e26f.

I believe the long-term fix here is two-fold:

  1. Keep returning deterministic sites in posterior_samples (MCMC saves deterministic sites in its samples, accessible via mcmc.get_samples()).
  2. Have Predictive always pop deterministic sites.

If these are infeasible for deeper reasons, then at least mention the pop trick here: https://num.pyro.ai/en/v0.2.0/utilities.html, as the current behavior is a bit counterintuitive.
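
Concretely, the pop trick is just dropping the deterministic sites before handing the samples to Predictive. A minimal sketch, assuming an existing `mcmc` run whose model has a deterministic site named `eta`:

```python
# Workaround sketch: drop deterministic sites from the posterior samples so
# Predictive recomputes them from the new inputs ("eta" is a placeholder name).
samples = mcmc.get_samples()
samples.pop("eta", None)
predictive = Predictive(model, posterior_samples=samples)
```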

kylejcaron commented 2 months ago

I'm running into the same issue, here's a reproducible example:

```python
import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive
from jax import random

X = np.random.normal(0, 1, size=1000)
y = 5 + 1.2*X + np.random.normal(size=1000)

def model(X, y=None):
    alpha = numpyro.sample("alpha", dist.Normal(0, 10))
    beta = numpyro.sample("beta", dist.Normal(0, 1))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    with numpyro.plate("data", len(X)):
        eta = numpyro.deterministic("eta", alpha + beta*X)
        obs = numpyro.sample("obs", dist.Normal(eta, sigma), obs=y)

# Run NUTS.
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
mcmc.run(random.PRNGKey(0), X=X, y=y)

# Make predictions where X is a different shape
posterior_samples = mcmc.get_samples()
# posterior_samples.pop("eta")  # uncommenting this fixes the issue
pred_func = Predictive(model, posterior_samples=posterior_samples)
preds = pred_func(random.PRNGKey(1), X=X[:200], y=None)  # raises the shape error below
```

Traceback:

```python
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
[... skipping hidden 1 frame]

File .../site-packages/jax/_src/util.py:290, in cache.<locals>.wrap.<locals>.wrapper(*args, **kwargs)
--> 290 return cached(config.trace_context(), *args, **kwargs)

File .../site-packages/jax/_src/util.py:283, in cache.<locals>.wrap.<locals>.cached(_, *args, **kwargs)
--> 283 return f(*args, **kwargs)

File .../site-packages/jax/_src/lax/lax.py:155, in _broadcast_shapes_cached(*shapes)
--> 155 return _broadcast_shapes_uncached(*shapes)

File .../site-packages/jax/_src/lax/lax.py:171, in _broadcast_shapes_uncached(*shapes)
--> 171 raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")

ValueError: Incompatible shapes for broadcasting: shapes=[(200,), (1000,)]

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
Cell In[1], line 26
     24 # Make predictions where X is a different shape
     25 pred_func = Predictive(model, posterior_samples=mcmc.get_samples())
---> 26 preds = pred_func(random.PRNGKey(1), X=X[:200], y=None)

File .../site-packages/numpyro/infer/util.py:1011, in Predictive.__call__(self, rng_key, *args, **kwargs)
   1010 if self.batch_ndims == 0 or self.params == {} or self.guide is None:
-> 1011     return self._call_with_params(rng_key, self.params, args, kwargs)

File .../site-packages/numpyro/infer/util.py:988, in Predictive._call_with_params(self, rng_key, params, args, kwargs)
    987 model = substitute(self.model, self.params)
--> 988 return _predictive(
    989     rng_key,
    990     model,
    991     posterior_samples,
    992     self._batch_shape,
    993     return_sites=self.return_sites,
    994     infer_discrete=self.infer_discrete,
    995     parallel=self.parallel,
    996     model_args=args,
    997     model_kwargs=kwargs,
    998 )

File .../site-packages/numpyro/infer/util.py:825, in _predictive(rng_key, model, posterior_samples, batch_shape, return_sites, infer_discrete, parallel, model_args, model_kwargs)
    824 chunk_size = num_samples if parallel else 1
--> 825 return soft_vmap(
    826     single_prediction, (rng_key, posterior_samples), len(batch_shape), chunk_size
    827 )

File .../site-packages/numpyro/util.py:419, in soft_vmap(fn, xs, batch_ndims, chunk_size)
    417 fn = vmap(fn)
--> 419 ys = lax.map(fn, xs) if num_chunks > 1 else fn(xs)

[... skipping hidden 12 frame]

File .../site-packages/numpyro/infer/util.py:798, in _predictive.<locals>.single_prediction(val)
--> 798 model_trace = trace(
    799     seed(substitute(masked_model, samples), rng_key)
    800 ).get_trace(*model_args, **model_kwargs)

File .../site-packages/numpyro/handlers.py:171, in trace.get_trace(self, *args, **kwargs)
--> 171 self(*args, **kwargs)
    172 return self.trace

File .../site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
    104 with self:
--> 105     return self.fn(*args, **kwargs)

[... skipping similar frames: Messenger.__call__ at line 105 (2 times)]

Cell In[1], line 17, in model(X, y)
     15 with numpyro.plate("data", len(X)):
     16     eta = numpyro.deterministic("eta", alpha + beta*X)
---> 17     obs = numpyro.sample("obs", dist.Normal(eta, sigma), obs=y)

File .../site-packages/numpyro/primitives.py:222, in sample(name, fn, obs, rng_key, sample_shape, infer, obs_mask)
--> 222 msg = apply_stack(initial_msg)
    223 return msg["value"]

File .../site-packages/numpyro/primitives.py:47, in apply_stack(msg)
     46 for pointer, handler in enumerate(reversed(_PYRO_STACK)):
---> 47     handler.process_message(msg)

File .../site-packages/numpyro/primitives.py:546, in plate.process_message(self, msg)
    545 trailing_shape = expected_shape[overlap_idx:]
--> 546 broadcast_shape = lax.broadcast_shapes(
    547     trailing_shape, tuple(dist_batch_shape)
    548 )
    549 batch_shape = expected_shape[:overlap_idx] + broadcast_shape

[... skipping hidden 1 frame]

File .../site-packages/jax/_src/lax/lax.py:171, in _broadcast_shapes_uncached(*shapes)
    170 if result_shape is None:
--> 171     raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")

ValueError: Incompatible shapes for broadcasting: shapes=[(200,), (1000,)]
```

I get that inputting samples for a deterministic site would lead to the model expecting a certain shape, but it does seem a bit awkward that the typical workflow with predictions requires some extra work if deterministics are involved.

I wonder if something like this is possible? https://github.com/pyro-ppl/numpyro/blob/2f1bccdba2fc7b0a6ec235ca1bd5ce2417a0635c/numpyro/infer/mcmc.py#L714C61-L714C62

fehiepsi commented 2 months ago

Hi @nikisix and @kylejcaron, really sorry for the breakage! I think a good course of action is to introduce `exclude_deterministic=True` to Predictive. This rolls the behavior back to the pre-0.14 releases. I'm less worried that new users will want to use deterministic sites in Predictive. What do you think, @martinjankowiak?
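
From the user side that would look roughly like the sketch below; `exclude_deterministic` here is just the proposed argument (not a released flag at the time of writing), and `X_new` is a placeholder for inputs with a new shape:

```python
# Sketch of the proposed argument: deterministic sites in posterior_samples
# would be skipped during substitution, so the model recomputes them for the
# new input shape instead of reusing the fit-time values.
predictive = Predictive(
    model,
    posterior_samples=mcmc.get_samples(),
    exclude_deterministic=True,  # proposed argument
)
preds = predictive(random.PRNGKey(1), X=X_new, y=None)
```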

martinjankowiak commented 2 months ago

something like that sounds reasonable. the change in behavior was probably a mistake...

kylejcaron commented 2 months ago

@fehiepsi @martinjankowiak should AutoGuide.sample_posterior() be changed as well? It seems more difficult to fix, since many sample_posterior implementations are specific to individual autoguides.

For example, the following workflow has the same problem:

```python
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal

guide = AutoNormal(model)
svi = SVI(model, guide, optim=numpyro.optim.Adam(0.01), loss=Trace_ELBO())
svi_result = svi.run(random.PRNGKey(0), 10000, X=X, y=y)

params = guide.sample_posterior(random.PRNGKey(0), params=svi_result.params)
pred_func = Predictive(model, params=params, num_samples=100)
preds = pred_func(random.PRNGKey(1), X=X[:250], y=None)
```

The solution for this seems to be just including the guide and using the SVI params instead, but I imagine some people may be using the pattern above:

```python
pred_func = Predictive(model, guide=guide, params=svi_result.params, num_samples=100)
preds = pred_func(random.PRNGKey(1), X[:n_preds])['eta']
```

kylejcaron commented 2 months ago

I think this pattern could be used with an exclude_deterministic arg in the AutoGuide classes.

fehiepsi commented 2 months ago

@kylejcaron I think we can fix this in Predictive. The breakage happens because we allow substituting deterministic sites in the substitute handler. We can create a subclass of substitute which skips processing deterministic sites and use it in Predictive.
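
Roughly, a sketch of that idea (the class name is a placeholder, not the final implementation):

```python
from numpyro.handlers import substitute


class _substitute_latent_sites(substitute):
    """Sketch: a substitute variant that leaves deterministic sites alone,
    so the model recomputes them from the (possibly re-shaped) inputs."""

    def process_message(self, msg):
        if msg["type"] == "deterministic":
            return  # skip substitution; the model will recompute this site
        super().process_message(msg)
```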

kylejcaron commented 2 months ago

> @kylejcaron I think we can fix this in Predictive. The breakage happens because we allow substituting deterministic sites in the substitute handler. We can create a subclass of substitute which skips processing deterministic sites and use it in Predictive.

Got it, that makes sense to me. It seems like it'd involve just replacing the substitute call in this line and at L987, but let me know if I'm missing anything.

I'm happy to make an attempt at this. Any name recommendations for the new effect handler?

fehiepsi commented 2 months ago

The substitute logic is at this line. You can change

```python
substitute(model, data)
```

to something like

```python
substitute(
    model,
    substitute_fn=lambda msg: (
        data[msg["name"]]
        if msg["name"] in data and msg["type"] != "deterministic"
        else None
    ),
)
```

kylejcaron commented 2 months ago

> The substitute logic is at this line. You can change
>
> ```python
> substitute(model, data)
> ```
>
> to something like
>
> ```python
> substitute(
>     model,
>     substitute_fn=lambda msg: (
>         data[msg["name"]]
>         if msg["name"] in data and msg["type"] != "deterministic"
>         else None
>     ),
> )
> ```

Nice idea with the `substitute_fn`, just added a PR!