Closed: dbobrovskiy closed this issue 12 months ago.
@tare, thank you for implementing `AutoGuideList`!
However, I've run into a problem when using it together with plates.
import jax
import numpyro
import optax
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoGuideList, AutoNormal, AutoDelta
from numpyro.optim import optax_to_numpyro
def model():
    """Toy model: draw the latent site "z" from Normal(0, 1) inside a plate.

    The plate "my_plate" has size 10 on dim -1, so "z" has 10 independent
    standard-normal entries.
    """
    my_plate = numpyro.plate("my_plate", 10, dim=-1)
    with my_plate:
        numpyro.sample("z", dist.Normal(0, 1))
# Minimal reproduction: split the model's latent sites between two autoguides.
guide = AutoGuideList(model)
# The AutoNormal part sees a seeded copy of the model with "z" hidden.
guide.append(AutoNormal(handlers.block(handlers.seed(model, rng_seed=0), hide=["z"])))
# The AutoDelta part sees only "z". "my_plate" is not listed in `expose`, so
# presumably the plate site is blocked as well. The traceback below shows
# AutoDelta._setup_prototype then failing to find "my_plate" in
# _prototype_frame_full_sizes.
guide.append(AutoDelta(handlers.block(handlers.seed(model, rng_seed=1), expose=["z"])))
optim = optax_to_numpyro(optax.adam(0.01))
svi = SVI(model, guide, optim, loss=Trace_ELBO())
# Raises KeyError: 'my_plate' while tracing the guide (see traceback below).
svi_state = svi.init(jax.random.PRNGKey(0))
svi_state, loss = svi.update(svi_state)
Running this code results in the following error (numpyro installed from source as of today, commit eaa29a0):
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
[/PlatesProblemDemo.ipynb](/PlatesProblemDemo.ipynb) Cell 3 line 1
[9](/PlatesProblemDemo.ipynb#X13sdnNjb2RlLXJlbW90ZQ%3D%3D?line=8) optim = optax_to_numpyro(optax.adam(0.01))
[10](/PlatesProblemDemo.ipynb#X13sdnNjb2RlLXJlbW90ZQ%3D%3D?line=9) svi = SVI(model, guide, optim, loss=Trace_ELBO())
---> [11](/PlatesProblemDemo.ipynb#X13sdnNjb2RlLXJlbW90ZQ%3D%3D?line=10) svi_state = svi.init(jax.random.PRNGKey(0))
[12](/PlatesProblemDemo.ipynb#X13sdnNjb2RlLXJlbW90ZQ%3D%3D?line=11) svi_state, loss = svi.update(svi_state)
File [~/numpyro/numpyro/infer/svi.py:184](~/numpyro/numpyro/infer/svi.py:184), in SVI.init(self, rng_key, init_params, *args, **kwargs)
182 if init_params is not None:
183 guide_init = substitute(guide_init, init_params)
--> 184 guide_trace = trace(guide_init).get_trace(*args, **kwargs, **self.static_kwargs)
185 init_guide_params = {
186 name: site["value"]
187 for name, site in guide_trace.items()
188 if site["type"] == "param"
189 }
190 if init_params is not None:
File [~/numpyro/numpyro/handlers.py:171](~/numpyro/numpyro/handlers.py:171), in trace.get_trace(self, *args, **kwargs)
163 def get_trace(self, *args, **kwargs):
164 """
165 Run the wrapped callable and return the recorded trace.
166
(...)
169 :return: `OrderedDict` containing the execution trace.
170 """
--> 171 self(*args, **kwargs)
172 return self.trace
File [~/numpyro/numpyro/primitives.py:105](~/numpyro/numpyro/primitives.py:105), in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
File [~/numpyro/numpyro/primitives.py:105](~/numpyro/numpyro/primitives.py:105), in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
File [~/numpyro/numpyro/infer/autoguide.py:280](~/numpyro/numpyro/infer/autoguide.py:280), in AutoGuideList.__call__(self, *args, **kwargs)
278 result = {}
279 for part in self._guides:
--> 280 result.update(part(*args, **kwargs))
281 return result
File [~/numpyro/numpyro/infer/autoguide.py:536](~/numpyro/numpyro/infer/autoguide.py:536), in AutoDelta.__call__(self, *args, **kwargs)
533 def __call__(self, *args, **kwargs):
534 if self.prototype_trace is None:
535 # run model to inspect the model structure
--> 536 self._setup_prototype(*args, **kwargs)
538 plates = self._create_plates(*args, **kwargs)
539 result = {}
File [~/numpyro/numpyro/infer/autoguide.py:526](~/numpyro/numpyro/infer/autoguide.py:526), in AutoDelta._setup_prototype(self, *args, **kwargs)
524 # If subsampling, repeat init_value to full size.
525 for frame in site["cond_indep_stack"]:
--> 526 full_size = self._prototype_frame_full_sizes[frame.name]
527 if full_size != frame.size:
528 dim = frame.dim - event_dim
KeyError: 'my_plate'
How should I use `AutoGuideList` when I need to sample inside a plate?
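I think the issue is that `my_plate` isn't exposed to the `AutoDelta` sub-guide. Because of that, it never sees the plate site, and the lookup `_prototype_frame_full_sizes["my_plate"]` fails with the `KeyError` above.
Please try exposing the plate as well: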
import jax
import numpyro
import optax
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoGuideList, AutoNormal, AutoDelta
from numpyro.optim import optax_to_numpyro
def model():
    """Toy model: draw the latent site "z" from Normal(0, 1) inside a plate.

    The plate "my_plate" has size 10 on dim -1, so "z" has 10 independent
    standard-normal entries. A sub-guide that handles "z" also needs to see
    the "my_plate" site.
    """
    my_plate = numpyro.plate("my_plate", 10, dim=-1)
    with my_plate:
        numpyro.sample("z", dist.Normal(0, 1))
guide = AutoGuideList(model)

# Part 1: a mean-field normal guide over the model with "z" hidden.
normal_view = handlers.block(handlers.seed(model, rng_seed=0), hide=["z"])
guide.append(AutoNormal(normal_view))

# Part 2: a point-estimate guide for "z". The enclosing plate is exposed
# alongside it so that this sub-guide can see the plate "z" is sampled in.
delta_view = handlers.block(
    handlers.seed(model, rng_seed=1),
    expose=["z", "my_plate"],
)
guide.append(AutoDelta(delta_view))

optim = optax_to_numpyro(optax.adam(0.01))
svi = SVI(model, guide, optim, loss=Trace_ELBO())

# Initialise the SVI state, then take a single optimisation step.
init_key = jax.random.PRNGKey(0)
svi_state = svi.init(init_key)
svi_state, loss = svi.update(svi_state)
That does the job, thank you! I should have thought of it myself, but in Pyro this is handled automatically, so I got stuck.
It would be great if `AutoGuideList` (https://docs.pyro.ai/en/stable/infer.autoguide.html#autoguidelist) were added to NumPyro.
I've seen this discussed on the forum (https://forum.pyro.ai/t/using-a-combination-of-autoguides-for-a-single-model/5328), and I would really appreciate this feature in NumPyro so that more of my models can be translated from Pyro easily. Until it is implemented, it would be helpful to have an example in the tutorials showing how different AutoGuides can currently be combined for different variables.