pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0
2.18k stars 238 forks source link

Request for AutoGuideList in numpyro #1638

Closed dbobrovskiy closed 12 months ago

dbobrovskiy commented 1 year ago

It would be great if AutoGuideList (https://docs.pyro.ai/en/stable/infer.autoguide.html#autoguidelist) were added to NumPyro.

I've seen this discussed on the forum (https://forum.pyro.ai/t/using-a-combination-of-autoguides-for-a-single-model/5328), and I would really appreciate having this feature in NumPyro so that more of my models could be translated easily from Pyro. Until it is implemented, it would be helpful to have an example somewhere in the tutorials showing how different AutoGuides can currently be combined for different variables.

dbobrovskiy commented 12 months ago

@tare, thank you for implementing AutoGuideList!

However, I've encountered a problem with it when it is combined with plates.

import jax
import numpyro
import optax
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoGuideList, AutoNormal, AutoDelta
from numpyro.optim import optax_to_numpyro

def model():
    """Sample site "z" from a standard Normal, 10 times, inside the plate "my_plate" (dim=-1)."""
    with numpyro.plate("my_plate", 10, dim=-1):
        numpyro.sample("z", dist.Normal(0, 1))

# Bug reproduction: combine AutoNormal (every site except "z") with
# AutoDelta (only "z") in one AutoGuideList.
guide = AutoGuideList(model)
guide.append(AutoNormal(handlers.block(handlers.seed(model, rng_seed=0), hide=["z"])))
# NOTE: expose=["z"] blocks every other site, including the plate "my_plate",
# so AutoDelta's prototype trace has "z" in a plate frame it never recorded.
guide.append(AutoDelta(handlers.block(handlers.seed(model, rng_seed=1), expose=["z"])))
optim = optax_to_numpyro(optax.adam(0.01))
svi = SVI(model, guide, optim, loss=Trace_ELBO())
# Raises KeyError: 'my_plate' from AutoDelta._setup_prototype (see traceback below).
svi_state = svi.init(jax.random.PRNGKey(0))
svi_state, loss = svi.update(svi_state)

Running this code results in the following error (NumPyro installed from source as of today, commit eaa29a0):

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
[/PlatesProblemDemo.ipynb](/PlatesProblemDemo.ipynb) Cell 3 line 1
      [9](/PlatesProblemDemo.ipynb#X13sdnNjb2RlLXJlbW90ZQ%3D%3D?line=8) optim = optax_to_numpyro(optax.adam(0.01))
     [10](/PlatesProblemDemo.ipynb#X13sdnNjb2RlLXJlbW90ZQ%3D%3D?line=9) svi = SVI(model, guide, optim, loss=Trace_ELBO())
---> [11](/PlatesProblemDemo.ipynb#X13sdnNjb2RlLXJlbW90ZQ%3D%3D?line=10) svi_state = svi.init(jax.random.PRNGKey(0))
     [12](/PlatesProblemDemo.ipynb#X13sdnNjb2RlLXJlbW90ZQ%3D%3D?line=11) svi_state, loss = svi.update(svi_state)

File [~/numpyro/numpyro/infer/svi.py:184](~/numpyro/numpyro/infer/svi.py:184), in SVI.init(self, rng_key, init_params, *args, **kwargs)
    182 if init_params is not None:
    183     guide_init = substitute(guide_init, init_params)
--> 184 guide_trace = trace(guide_init).get_trace(*args, **kwargs, **self.static_kwargs)
    185 init_guide_params = {
    186     name: site["value"]
    187     for name, site in guide_trace.items()
    188     if site["type"] == "param"
    189 }
    190 if init_params is not None:

File [~/numpyro/numpyro/handlers.py:171](~/numpyro/numpyro/handlers.py:171), in trace.get_trace(self, *args, **kwargs)
    163 def get_trace(self, *args, **kwargs):
    164     """
    165     Run the wrapped callable and return the recorded trace.
    166 
   (...)
    169     :return: `OrderedDict` containing the execution trace.
    170     """
--> 171     self(*args, **kwargs)
    172     return self.trace

File [~/numpyro/numpyro/primitives.py:105](~/numpyro/numpyro/primitives.py:105), in Messenger.__call__(self, *args, **kwargs)
    103     return self
    104 with self:
--> 105     return self.fn(*args, **kwargs)

File [~/numpyro/numpyro/primitives.py:105](~/numpyro/numpyro/primitives.py:105), in Messenger.__call__(self, *args, **kwargs)
    103     return self
    104 with self:
--> 105     return self.fn(*args, **kwargs)

File [~/numpyro/numpyro/infer/autoguide.py:280](~/numpyro/numpyro/infer/autoguide.py:280), in AutoGuideList.__call__(self, *args, **kwargs)
    278 result = {}
    279 for part in self._guides:
--> 280     result.update(part(*args, **kwargs))
    281 return result

File [~/numpyro/numpyro/infer/autoguide.py:536](~/numpyro/numpyro/infer/autoguide.py:536), in AutoDelta.__call__(self, *args, **kwargs)
    533 def __call__(self, *args, **kwargs):
    534     if self.prototype_trace is None:
    535         # run model to inspect the model structure
--> 536         self._setup_prototype(*args, **kwargs)
    538     plates = self._create_plates(*args, **kwargs)
    539     result = {}

File [~/numpyro/numpyro/infer/autoguide.py:526](~/numpyro/numpyro/infer/autoguide.py:526), in AutoDelta._setup_prototype(self, *args, **kwargs)
    524 # If subsampling, repeat init_value to full size.
    525 for frame in site["cond_indep_stack"]:
--> 526     full_size = self._prototype_frame_full_sizes[frame.name]
    527     if full_size != frame.size:
    528         dim = frame.dim - event_dim

KeyError: 'my_plate'

How should I use AutoGuideList when I need to sample inside a plate?

tare commented 12 months ago

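I think the issue might be that `my_plate` isn't exposed. `handlers.block(..., expose=["z"])` hides every site not listed in `expose`, including the plate. As a result, AutoDelta never sees `my_plate` when it traces the model. Please try exposing the plate as well: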

import jax
import numpyro
import optax
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoGuideList, AutoNormal, AutoDelta
from numpyro.optim import optax_to_numpyro

def model():
    """Sample site "z" from a standard Normal, 10 times, inside the plate "my_plate" (dim=-1)."""
    with numpyro.plate("my_plate", 10, dim=-1):
        numpyro.sample("z", dist.Normal(0, 1))

# Working version: AutoNormal handles every site except "z"; AutoDelta handles "z".
guide = AutoGuideList(model)
guide.append(AutoNormal(handlers.block(handlers.seed(model, rng_seed=0), hide=["z"])))
# The plate "my_plate" must be exposed alongside "z". Otherwise block() hides
# the plate site, and AutoDelta cannot find the plate that "z" is sampled in.
guide.append(AutoDelta(handlers.block(handlers.seed(model, rng_seed=1), expose=["z", "my_plate"])))
optim = optax_to_numpyro(optax.adam(0.01))
svi = SVI(model, guide, optim, loss=Trace_ELBO())
svi_state = svi.init(jax.random.PRNGKey(0))
svi_state, loss = svi.update(svi_state)
dbobrovskiy commented 12 months ago

This does the job, thank you! I should have thought of it, but in Pyro this was handled automatically, so I got stuck.