When using module_local_params=True, calling pyro.render_model on a PyroModule with constrained parameters can fail with a KeyError, as shown by the test case in this stack trace:
_________________________________________________________________________________________________ test_render_constrained_param[True] __________________________________________________________________________________________________
use_module_local_params = True
@pytest.mark.parametrize("use_module_local_params", [True, False])
def test_render_constrained_param(use_module_local_params):
class Model(PyroModule):
@PyroParam(constraint=constraints.positive)
def x(self):
return torch.tensor(1.234)
@PyroParam(constraint=constraints.real)
def y(self):
return torch.tensor(0.456)
def forward(self):
return self.x + self.y
with pyro.settings.context(module_local_params=use_module_local_params):
model = Model()
> pyro.render_model(model)
tests/nn/test_module.py:1068:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
pyro/infer/inspect.py:630: in render_model
get_model_relations(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
model = Model(), model_args = (), model_kwargs = {}, include_deterministic = False
def get_model_relations(
model: Callable,
model_args: Optional[tuple] = None,
model_kwargs: Optional[dict] = None,
include_deterministic: bool = False,
):
"""
Infer relations of RVs and plates from given model and optionally data.
See https://github.com/pyro-ppl/pyro/issues/949 for more details.
This returns a dictionary with keys:
- "sample_sample" map each downstream sample site to a list of the upstream
sample sites on which it depend;
- "sample_dist" maps each sample site to the name of the distribution at
that site;
- "plate_sample" maps each plate name to a list of the sample sites within
that plate; and
- "observe" is a list of observed sample sites.
For example for the model::
def model(data):
m = pyro.sample('m', dist.Normal(0, 1))
sd = pyro.sample('sd', dist.LogNormal(m, 1))
with pyro.plate('N', len(data)):
pyro.sample('obs', dist.Normal(m, sd), obs=data)
the relation is::
{'sample_sample': {'m': [], 'sd': ['m'], 'obs': ['m', 'sd']},
'sample_dist': {'m': 'Normal', 'sd': 'LogNormal', 'obs': 'Normal'},
'plate_sample': {'N': ['obs']},
'observed': ['obs']}
:param callable model: A model to inspect.
:param model_args: Optional tuple of model args.
:param model_kwargs: Optional dict of model kwargs.
:param bool include_deterministic: Whether to include deterministic sites.
:rtype: dict
"""
if model_args is None:
model_args = ()
if model_kwargs is None:
model_kwargs = {}
assert isinstance(model_args, tuple)
assert isinstance(model_kwargs, dict)
with torch.random.fork_rng(), torch.no_grad(), pyro.validation_enabled(False):
with TrackProvenance(include_deterministic=include_deterministic):
trace = poutine.trace(model).get_trace(*model_args, **model_kwargs)
sample_sample = {}
sample_param = {}
sample_dist = {}
param_constraint = {}
plate_sample = defaultdict(list)
observed = []
def _get_type_from_frozenname(frozen_name):
return trace.nodes[frozen_name]["type"]
for name, site in trace.nodes.items():
if site["type"] == "param":
> param_constraint[name] = str(site["kwargs"]["constraint"])
E KeyError: 'constraint'
pyro/infer/inspect.py:316: KeyError
When using `module_local_params=True`, calling `pyro.render_model` on a `PyroModule` with constrained parameters can fail with a `KeyError`, as shown by the test case in this stack trace: