pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

`AutoNormal`, `AutoDelta`, and `AutoGuideList` do not support subsamples of variable size. #1739

Open tillahoffmann opened 5 months ago

tillahoffmann commented 5 months ago

`AutoNormal`, `AutoDelta`, and `AutoGuideList` raise an exception in SVI when the subsample size varies across `log_density` evaluations. Here is an example reproducing the issue (run on master).

```python
import numpyro
from jax import numpy as jnp

def model(n, x=None, subsample_size=None):
    mu = numpyro.sample("mu", numpyro.distributions.Normal())
    with numpyro.plate("n", n, subsample_size=subsample_size):
        numpyro.sample("x", numpyro.distributions.Normal(mu, 1), obs=x)

def demo(guide_cls):
    n = 10
    x_obs = jnp.zeros(n)
    guide = guide_cls(model)

    with numpyro.handlers.seed(rng_seed=0):
        # Initialize the guide with the full dataset, get a trace, and replay against
        # the model.
        guide(n, x_obs)
        guide_trace = numpyro.handlers.trace(guide).get_trace()
        replayed = numpyro.handlers.replay(model, guide_trace)

        print("evaluate log density for full data")
        numpyro.infer.util.log_density(replayed, (n, x_obs), {}, {})

        print("evaluate log density for subsampled data")
        numpyro.infer.util.log_density(replayed, (n, x_obs[:3], 3), {}, {})

        print("done")

# This works just fine.
demo(numpyro.infer.autoguide.AutoDiagonalNormal)
# This raises an error (see traceback below).
demo(numpyro.infer.autoguide.AutoNormal)
```

The traceback for the failed call is as follows.

```
evaluate log density for full data
evaluate log density for subsampled data
.../numpyro/playground/test.py:7: UserWarning: subsample_size does not match len(subsample), 3 vs 10. Did you accidentally use different subsample_size in the model and guide?
  with numpyro.plate("n", n, subsample_size=subsample_size):
Traceback (most recent call last):
  File ".../numpyro/venv/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 151, in broadcast_shapes
    return _broadcast_shapes_cached(*shapes)
  File ".../numpyro/venv/lib/python3.10/site-packages/jax/_src/util.py", line 287, in wrapper
    return cached(config.config._trace_context(), *args, **kwargs)
  File ".../numpyro/venv/lib/python3.10/site-packages/jax/_src/util.py", line 280, in cached
    return f(*args, **kwargs)
  File ".../numpyro/venv/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 157, in _broadcast_shapes_cached
    return _broadcast_shapes_uncached(*shapes)
  File ".../numpyro/venv/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 173, in _broadcast_shapes_uncached
    raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
ValueError: Incompatible shapes for broadcasting: shapes=[(3,), (10,)]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File ".../numpyro/numpyro/infer/util.py", line 80, in log_density
    broadcast_shapes(guide_shape, model_shape)
  File ".../numpyro/venv/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 153, in broadcast_shapes
    return _broadcast_shapes_uncached(*shapes)
  File ".../numpyro/venv/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 173, in _broadcast_shapes_uncached
    raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
ValueError: Incompatible shapes for broadcasting: shapes=[(3,), (10,)]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File ".../numpyro/playground/test.py", line 32, in <module>
    demo(numpyro.infer.autoguide.AutoNormal)
  File ".../numpyro/playground/test.py", line 27, in demo
    numpyro.infer.util.log_density(replayed, (n, x_obs[:3], 3), {}, {})
  File ".../numpyro/numpyro/infer/util.py", line 82, in log_density
    raise ValueError(
ValueError: Model and guide shapes disagree at site: 'x': (10,) vs (3,)
```

I think the issue is that these guides use `_create_plates`, which in turn uses the prototype trace to determine the subsample size.

https://github.com/pyro-ppl/numpyro/blob/aec6bd58b4cf0c2d81b96e62d4d0cf7af3744885/numpyro/infer/autoguide.py#L108-L113

The prototype trace is only created on the first invocation, so the subsample size it records is fixed from then on and disagrees with the model whenever a different mini-batch size is used. Guides inheriting from `AutoContinuous` do not call `_create_plates` and do not use plates in their `__call__` method. I couldn't quite figure out why some guides do and some guides don't.
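To make the failure mode concrete, here is a toy sketch in plain Python (no NumPyro involved). `CachedPlateGuide`, `model_plate_shape`, and the pure-Python `broadcast_shapes` are hypothetical stand-ins: the guide fixes its plate shape on the first call, much as the prototype trace does, while the model uses whatever subsample size it receives. Comparing the two shapes mirrors the `broadcast_shapes(guide_shape, model_shape)` check that appears in the traceback.

```python
def broadcast_shapes(*shapes):
    """Pure-Python version of the NumPy/JAX broadcasting rule for shape tuples."""
    ndim = max(len(shape) for shape in shapes)
    result = []
    for i in range(1, ndim + 1):
        # Compare dimensions from the right; size-1 dimensions broadcast freely.
        sizes = {shape[-i] for shape in shapes if len(shape) >= i} - {1}
        if len(sizes) > 1:
            raise ValueError(
                f"Incompatible shapes for broadcasting: shapes={list(shapes)}"
            )
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))


class CachedPlateGuide:
    """Hypothetical guide that records the plate size on its first call only."""

    def __init__(self):
        self.prototype_size = None

    def plate_shape(self, n, subsample_size=None):
        if self.prototype_size is None:
            # Like a prototype trace: only the first invocation sets the size.
            self.prototype_size = n if subsample_size is None else subsample_size
        return (self.prototype_size,)


def model_plate_shape(n, subsample_size=None):
    """The model always uses the subsample size passed on the current call."""
    return (n if subsample_size is None else subsample_size,)


guide = CachedPlateGuide()
# Initialize with the full dataset (size 10), then evaluate a mini-batch of 3.
assert guide.plate_shape(10) == model_plate_shape(10) == (10,)
try:
    broadcast_shapes(guide.plate_shape(10, 3), model_plate_shape(10, 3))
except ValueError as err:
    print(err)  # Incompatible shapes for broadcasting: shapes=[(10,), (3,)]
```

The guide keeps reporting `(10,)` after initialization, so the second comparison fails exactly where the real `log_density` check does.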

fehiepsi commented 5 months ago

This is a good point. I guess a better check is to make sure that there are no latent variables under the subsample plates. When that is the case, there is no need to specify the `create_plates` argument.
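A minimal sketch of the check suggested above, assuming a simplified trace represented as a plain dict in which each site lists the names of its enclosing plates. This is not NumPyro's actual trace format, and `latent_sites_in_plates` is a hypothetical helper: it reports unobserved sample sites that sit inside any of the given subsampled plates. An empty result would mean the guide has nothing to place under those plates.

```python
def latent_sites_in_plates(trace, subsample_plates):
    """Return names of latent (unobserved) sample sites inside any of the
    given subsampled plates."""
    return sorted(
        name
        for name, site in trace.items()
        if site["type"] == "sample"
        and not site["is_observed"]
        and set(site["plates"]) & set(subsample_plates)
    )


# Simplified stand-in for the model in the reproduction: a global latent `mu`
# outside the plate and an observed `x` inside plate "n".
model_trace = {
    "mu": {"type": "sample", "is_observed": False, "plates": []},
    "x": {"type": "sample", "is_observed": True, "plates": ["n"]},
}
print(latent_sites_in_plates(model_trace, {"n"}))  # []

# A hypothetical local latent `z` under the same plate would be flagged.
local_trace = dict(model_trace)
local_trace["z"] = {"type": "sample", "is_observed": False, "plates": ["n"]}
print(latent_sites_in_plates(local_trace, {"n"}))  # ['z']
```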

fehiepsi commented 2 months ago

@tillahoffmann, sorry, my last comment was misleading. For subsampling, the usage is:

```python
# create_plates takes the same arguments as the model and returns the plate(s).
create_plates = lambda n, x, subsample_size=None: numpyro.plate("n", n, subsample_size=subsample_size)
AutoNormal(..., create_plates=create_plates)
```
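Presumably this pattern avoids the mismatch because `create_plates` is called with the model's arguments each time the guide runs, so the subsample size is re-derived per call rather than read from the prototype trace. The toy sketch below illustrates only that idea in plain Python. `PerCallPlateGuide` is hypothetical, and its `create_plates` returns a `(size, subsample_size)` pair instead of a real `numpyro.plate`.

```python
class PerCallPlateGuide:
    """Hypothetical guide that rebuilds its plate from the model arguments on
    every call, in the spirit of passing `create_plates`."""

    def __init__(self, create_plates):
        self.create_plates = create_plates

    def plate_shape(self, *args, **kwargs):
        # The callable sees the same arguments as the model on each call.
        size, subsample_size = self.create_plates(*args, **kwargs)
        return (size if subsample_size is None else subsample_size,)


def create_plates(n, x, subsample_size=None):
    # Stand-in for numpyro.plate("n", n, subsample_size=subsample_size).
    return n, subsample_size


guide = PerCallPlateGuide(create_plates)
x_obs = [0.0] * 10
print(guide.plate_shape(10, x_obs))  # (10,)
print(guide.plate_shape(10, x_obs[:3], 3))  # (3,)
```

Unlike a size cached on the first call, the second call here reflects the new mini-batch size.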