pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0
8.58k stars 987 forks source link

Rendering PyroModules can fail with local parameter mode enabled #3365

Closed — eb8680 closed this issue 6 months ago.

eb8680 commented 6 months ago

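When using `module_local_params=True`, calling `pyro.render_model` on a `PyroModule` with constrained parameters can fail with a `KeyError`, as shown by the test case in the stack trace below. The failure occurs in `get_model_relations`, which reads `site["kwargs"]["constraint"]` for every `param` site in the trace; here the param site's `kwargs` has no `"constraint"` entry.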

_________________________________________________________________________________________________ test_render_constrained_param[True] __________________________________________________________________________________________________

use_module_local_params = True

    @pytest.mark.parametrize("use_module_local_params", [True, False])
    def test_render_constrained_param(use_module_local_params):

        class Model(PyroModule):

            @PyroParam(constraint=constraints.positive)
            def x(self):
                return torch.tensor(1.234)

            @PyroParam(constraint=constraints.real)
            def y(self):
                return torch.tensor(0.456)

            def forward(self):
                return self.x + self.y

        with pyro.settings.context(module_local_params=use_module_local_params):
            model = Model()
>           pyro.render_model(model)

tests/nn/test_module.py:1068: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
pyro/infer/inspect.py:630: in render_model
    get_model_relations(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

model = Model(), model_args = (), model_kwargs = {}, include_deterministic = False

    def get_model_relations(
        model: Callable,
        model_args: Optional[tuple] = None,
        model_kwargs: Optional[dict] = None,
        include_deterministic: bool = False,
    ):
        """
        Infer relations of RVs and plates from given model and optionally data.
        See https://github.com/pyro-ppl/pyro/issues/949 for more details.

        This returns a dictionary with keys:

        -  "sample_sample" maps each downstream sample site to a list of the upstream
           sample sites on which it depends;
        -  "sample_dist" maps each sample site to the name of the distribution at
           that site;
        -  "plate_sample" maps each plate name to a list of the sample sites within
           that plate; and
        -  "observed" is a list of observed sample sites.

        For example for the model::

            def model(data):
                m = pyro.sample('m', dist.Normal(0, 1))
                sd = pyro.sample('sd', dist.LogNormal(m, 1))
                with pyro.plate('N', len(data)):
                    pyro.sample('obs', dist.Normal(m, sd), obs=data)

        the relation is::

            {'sample_sample': {'m': [], 'sd': ['m'], 'obs': ['m', 'sd']},
             'sample_dist': {'m': 'Normal', 'sd': 'LogNormal', 'obs': 'Normal'},
             'plate_sample': {'N': ['obs']},
             'observed': ['obs']}

        :param callable model: A model to inspect.
        :param model_args: Optional tuple of model args.
        :param model_kwargs: Optional dict of model kwargs.
        :param bool include_deterministic: Whether to include deterministic sites.
        :rtype: dict
        """
        if model_args is None:
            model_args = ()
        if model_kwargs is None:
            model_kwargs = {}
        assert isinstance(model_args, tuple)
        assert isinstance(model_kwargs, dict)

        with torch.random.fork_rng(), torch.no_grad(), pyro.validation_enabled(False):
            with TrackProvenance(include_deterministic=include_deterministic):
                trace = poutine.trace(model).get_trace(*model_args, **model_kwargs)

        sample_sample = {}
        sample_param = {}
        sample_dist = {}
        param_constraint = {}
        plate_sample = defaultdict(list)
        observed = []

        def _get_type_from_frozenname(frozen_name):
            return trace.nodes[frozen_name]["type"]

        for name, site in trace.nodes.items():
            if site["type"] == "param":
>               param_constraint[name] = str(site["kwargs"]["constraint"])
E               KeyError: 'constraint'

pyro/infer/inspect.py:316: KeyError