pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

[Feature Request] 1) More details in `render_model` (esp. for `pyro.params`); 2) Returning dictionary from `render_model` #3023

Open · nipunbatra opened this issue 2 years ago

nipunbatra commented 2 years ago

Hi! From the MLE-MAP tutorial, we have the following models.

MLE model

import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints

def model_mle(data):
    f = pyro.param("latent_fairness", torch.tensor(0.5),
                   constraint=constraints.unit_interval)
    with pyro.plate("data", data.size(0)):
        pyro.sample("obs", dist.Bernoulli(f), obs=data)

If we render this, we get something like the following image

[Screenshot: render_model output for the MLE model]

MAP model

def original_model(data):
    f = pyro.sample("latent_fairness", dist.Beta(10.0, 10.0))
    with pyro.plate("data", data.size(0)):
        pyro.sample("obs", dist.Bernoulli(f), obs=data)

If we render this, we get something like the following image

[Screenshot: render_model output for the MAP model]

Coin-toss graphical model figures from the popular Mathematics for Machine Learning (MML) book

This is from Figure 8.10

[Screenshot: Figure 8.10 from the MML book, showing the coin-toss model in several notations]

We'd expect our MLE model render to look like Figure 8.10(b) and our MAP model render to look like Figure 8.10(c).

So, when latent_fairness is a pyro.param, it should perhaps just appear as a plain latent_fairness node, and under the MAP model it should be shown as parameterised by the Beta distribution.

From the Pyro render of the MLE model, it is not easy to see how the observations are related to latent_fairness.

Feature Requests

So, I have two questions/requests

  1. Would it make sense to have pyro.param sites also show up in renders? The difference in the renders between a pyro.sample node and a pyro.param node would be the associated distribution (and thus hyperparameters) on the pyro.sample node.
  2. Would it be possible for render_model to additionally return a dictionary? One could then use that dictionary to create their own graphical models, for example with tikz-bayesnet. The code below reproduces Figure 8.10 from the MML book shown above.

[Image: tikz-bayesnet rendering of Figure 8.10]

<details>
<summary>code</summary>

```latex
\documentclass[a4paper]{article}
\usepackage{caption}
\usepackage{subcaption}
\usepackage{tikz}
\usetikzlibrary{bayesnet}
\usepackage{booktabs}
\setlength{\tabcolsep}{12pt}

\begin{document}
\begin{figure}[ht]
\begin{center}
\begin{tabular}{@{}cccc@{}}
\toprule
$x_N$ explicit & Plate Notation & Hyperparameters on $\mu$ & Factor\\
\midrule
& & & \\
\begin{tikzpicture}
  \node[obs] (x1) {$x_1$};
  \node[const, right=0.5cm of x1] (dots) {$\cdots$};
  \node[obs, right=0.5cm of dots] (xn) {$x_N$};
  \node[latent, above=of dots] (mu) {$\mathbf{\mu}$};
  \edge {mu} {x1,dots,xn} ; %
\end{tikzpicture} &
\begin{tikzpicture}
  \node[obs] (xn) {$x_n$};
  \node[latent, above=of xn] (mu) {$\mathbf{\mu}$};
  \plate{}{(xn)}{$n = 1, \cdots, N$};
  \edge {mu} {xn} ; %
\end{tikzpicture} &
\begin{tikzpicture}
  \node[obs] (xn) {$x_n$};
  \node[latent, above=of xn] (mu) {$\mathbf{\mu}$};
  \node[const, right=0.5cm of mu] (beta) {$\mathbf{\beta}$};
  \node[const, left=0.5cm of mu] (alpha) {$\mathbf{\alpha}$};
  \plate{}{(xn)}{$n = 1, \cdots, N$};
  \edge {mu} {xn} ; %
  \edge {alpha,beta} {mu} ; %
\end{tikzpicture} &
\begin{tikzpicture}
  \node[obs] (xn) {$x_n$};
  \node[latent, above=of xn] (mu) {$\mathbf{\mu}$};
  \factor[above=of xn] {y-f} {left:${Ber}$} {} {} ; %
  \node[const, above=1 of mu, xshift=0.5cm] (beta) {$\mathbf{\beta}$};
  \node[const, above=1 of mu, xshift=-0.5cm] (alpha) {$\mathbf{\alpha}$};
  \factor[above=of mu] {mu-f} {left:${Beta}$} {} {} ; %
  \plate{}{(xn)}{$n = 1, \cdots, N$};
  \edge {mu} {xn} ; %
  \edge {alpha,beta} {mu-f} ; %
  \edge {mu-f}{mu} ; %
\end{tikzpicture}
\end{tabular}
\end{center}
\caption{Graphical models for a repeated Bernoulli experiment.}
\end{figure}
\end{document}
```

</details>
fritzo commented 2 years ago

Hi @nipunbatra, regarding

  1. I like the idea of adding a render_params argument to render_model and possibly printing their constraints if render_distributions == True. @fehiepsi might also like this feature in NumPyro.
  2. While render_model() already has a well-defined return value that should be preserved for backwards compatibility, you could instead get the structural information from the internals of render_model(), namely get_model_relations() and generate_graph_specification(): https://github.com/pyro-ppl/pyro/blob/f4fafc5c7fa0dc5a377ceb06ec59a234bf3ac465/pyro/infer/inspect.py#L494-L495 I'd also recommend checking out the other helper functions in pyro.infer.inspect, including get_dependencies().
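Concretely, a minimal sketch of that route (these are internal helpers, so their names and signatures may change between releases):

```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer.inspect import get_model_relations, generate_graph_specification

def original_model(data):
    f = pyro.sample("latent_fairness", dist.Beta(10.0, 10.0))
    with pyro.plate("data", data.size(0)):
        pyro.sample("obs", dist.Bernoulli(f), obs=data)

data = torch.ones(5)

# Structural info: sample/plate relationships, distributions, observed sites, ...
relations = get_model_relations(original_model, model_args=(data,))

# The graph spec that render_model() draws from: node_data, edge_list, plate_groups, ...
graph_spec = generate_graph_specification(relations)
print(relations)
print(graph_spec)
```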
fehiepsi commented 2 years ago

+1 for having an optional render_params.

nipunbatra commented 2 years ago

I think another thing to potentially consider while addressing this issue is LaTeX support in renders. I believe Graphviz doesn't support it, but Daft-PGM does.

An older issue, https://github.com/pyro-ppl/pyro/issues/2980, mentioned the possibility of using Graphviz for layout and then plotting with Daft-PGM. That would increase overhead, so perhaps it could be left as an example for advanced users who run something like:

relations = get_model_relations(model, model_args, model_kwargs)
graph_spec = generate_graph_specification(relations)
# ... then build a Daft PGM from graph_spec
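For illustration, here is a rough, hand-positioned Daft-PGM version of the coin model (the node names and coordinates are made up for this sketch, and a recent daft release is assumed); a real integration would instead derive the nodes, edges, and plates from graph_spec:

```python
import daft

# Hand-built Daft figure for the MAP coin model; positions are chosen by hand.
pgm = daft.PGM()
pgm.add_node(daft.Node("alpha", r"$\alpha$", 0.5, 2.0, fixed=True))
pgm.add_node(daft.Node("beta", r"$\beta$", 1.5, 2.0, fixed=True))
pgm.add_node(daft.Node("f", r"$f$", 1.0, 1.0))
pgm.add_node(daft.Node("obs", r"$x_n$", 1.0, 0.0, observed=True))
pgm.add_edge("alpha", "f")
pgm.add_edge("beta", "f")
pgm.add_edge("f", "obs")
pgm.add_plate(daft.Plate([0.4, -0.6, 1.2, 1.2], label=r"$n = 1, \ldots, N$"))
pgm.render()
pgm.savefig("coin_map.png", dpi=150)
```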
karm-patel commented 2 years ago

Hi @fritzo and team, my name is Karm. I am working with Prof. @nipunbatra and my lab colleague @patel-zeel. I made the desired changes in the code and tested them on the following examples.

MLE Model 1

def model_mle_1(data):
    mu = pyro.param('mu', torch.tensor(0.), constraint=constraints.unit_interval)
    sd = pyro.param('sd', torch.tensor(1.), constraint=constraints.greater_than_eq(0))
    with pyro.plate('plate_data', len(data)):
        pyro.sample('obs', dist.Normal(mu, sd), obs=data)

data = torch.tensor([1., 2., 3.])
get_model_relations(model_mle_1, model_args=(data,))

{'sample_sample': {'obs': []},
 'sample_param': {'obs': ['sd', 'mu']},
 'sample_dist': {'obs': 'Normal'},
 'param_constraint': {'mu': Interval(lower_bound=0.0, upper_bound=1.0),
                      'sd': GreaterThanEq(lower_bound=0)},
 'plate_sample': {'plate_data': ['obs']},
 'observed': ['obs']}

render_model(model_mle_1, model_args=(data,), render_distributions=True)

[Image: render of model_mle_1 with render_distributions=True]

render_model(model_mle_1, model_args=(data,), render_distributions=True, render_params=True)

[Image: render of model_mle_1 with render_distributions=True and render_params=True]

MAP Model 1

def model_map_1(data):
    k1 = pyro.param('k1', torch.tensor(1.))
    mu = pyro.sample('mu', dist.Normal(0, k1))
    sd = pyro.sample('sd', dist.LogNormal(mu, k1))
    with pyro.plate('plate_data', len(data)):
        pyro.sample('obs', dist.Normal(mu, sd), obs=data)

data = torch.tensor([1., 2., 3.])
get_model_relations(model_map_1, model_args=(data,))

{'sample_sample': {'mu': [], 'sd': ['mu'], 'obs': ['sd', 'mu']},
 'sample_param': {'mu': ['k1'], 'sd': ['k1'], 'obs': []},
 'sample_dist': {'mu': 'Normal', 'sd': 'LogNormal', 'obs': 'Normal'},
 'param_constraint': {'k1': Real()},
 'plate_sample': {'plate_data': ['obs']},
 'observed': ['obs']}

render_model(model_map_1, model_args=(data,), render_distributions=True)

[Image: render of model_map_1 with render_distributions=True]

render_model(model_map_1, model_args=(data,), render_distributions=True, render_params=True)

[Image: render of model_map_1 with render_distributions=True and render_params=True]

MAP Model 2

def model_map_2(data):
    t = pyro.param('t', torch.tensor(1.), constraints.integer)
    a = pyro.sample('a', dist.Bernoulli(t))
    b = pyro.param('b', torch.tensor(2.))
    with pyro.plate('plate_data', len(data)):
        pyro.sample('obs', dist.Beta(a, b), obs=data)

data = torch.tensor([1., 2., 3.])
get_model_relations(model_map_2, model_args=(data,))

{'sample_sample': {'mu': [], 'sd': ['mu'], 'obs': ['sd', 'mu']}, 'sample_param': {'mu': ['k1'], 'sd': ['k1'], 'obs': []}, 'sample_dist': {'mu': 'Normal', 'sd': 'LogNormal', 'obs': 'Normal'}, 'param_constraint': {'k1': Real()}, 'plate_sample': {'plate_data': ['obs']}, 'observed': ['obs']}

render_model(model_map_2, model_args=(data,), render_distributions=True)

[Image: render of model_map_2 with render_distributions=True]

render_model(model_map_2, model_args=(data,), render_distributions=True, render_params=True)

[Image: render of model_map_2 with render_distributions=True and render_params=True]

Changes made in code

Broadly, I made the following changes in pyro/infer/inspect.py.

  1. I added a key named sample_param to the dictionary returned by get_model_relations(), mapping each sample site to the params it depends on. Looking at the output of trace.nodes inside get_model_relations(), I found that there is no provenance tracking for params, and without provenance tracking we cannot recover the dependent params. The class TrackProvenance(Messenger) already has a method _pyro_post_sample() that assigns provenance to each sample, so I added a similar method for params, _pyro_post_param(), in the same class. This method is called while getting the trace, trace = poutine.trace(model).get_trace(*model_args, **model_kwargs).

    def _pyro_post_param(self, msg):
        if msg["type"] == "param":
            provenance = frozenset({msg["name"]})  # track only direct dependencies
            value = detach_provenance(msg["value"])
            msg["value"] = ProvenanceTensor(value, provenance)

    Then, to populate sample_param, I followed the same procedure that is used to populate sample_sample (a rough sketch of the idea is included after this list).

  2. I added another key named param_constraint to store the constraints of the params. This is needed by generate_graph_specification().

  3. I added an argument render_params: bool = False to both render_model() and generate_graph_specification(). This argument makes showing params in the graph optional.

  4. In generate_graph_specification(), the node_data dictionary looks like this for a sample variable:

    node_data[rv] = {
        "is_observed": ...,
        "distribution": ...,
    }

    I added an additional key, constraint, to node_data for params only. Note that the following changes apply only when render_params=True:

    node_data[param] = {
        "is_observed": False,
        "distribution": None,
        "constraint": constraint,
    }

    Further, edge_list and plate_groups are also updated to include the param data.

  5. In the render_graph() method, I kept the shape of param nodes plain and added code to show the constraints of the params.
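Rough sketch referenced in step 1: with the _pyro_post_param hook in place, the params a sample depends on can be read off the provenance of its log-density, mirroring how sample_sample is built. This is only an illustration of the idea, not the exact code of the change; it assumes the internal helpers TrackProvenance (patched as above), get_provenance, and site_is_subsample keep their current locations.

```python
from collections import defaultdict

import pyro.poutine as poutine
from pyro.infer.inspect import TrackProvenance  # internal; assumed patched with _pyro_post_param
from pyro.ops.provenance import get_provenance
from pyro.poutine.util import site_is_subsample


def sample_and_param_relations(model, *model_args):
    """Recover sample->sample and sample->param dependencies from a provenance trace."""
    with TrackProvenance():
        trace = poutine.trace(model).get_trace(*model_args)

    sample_sample, sample_param = defaultdict(list), defaultdict(list)
    for name, site in trace.nodes.items():
        if site["type"] != "sample" or site_is_subsample(site):
            continue
        # Provenance of the site's log-density names every upstream site that influenced it.
        upstream_names = get_provenance(site["fn"].log_prob(site["value"]))
        for upstream in upstream_names:
            if upstream == name:
                continue
            if trace.nodes[upstream]["type"] == "sample":
                sample_sample[name].append(upstream)
            elif trace.nodes[upstream]["type"] == "param":  # the new sample_param entries
                sample_param[name].append(upstream)
    return dict(sample_sample), dict(sample_param)
```

For model_mle_1 above this should give sample_param == {'obs': ['sd', 'mu']} (up to ordering), matching the get_model_relations() output shown earlier.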

@fritzo, please give your feedback on this. Can I make a PR if the dictionary and graph meet your expectations?

fritzo commented 2 years ago

@karm216 this looks great, we'd love a PR contributing this feature! Note there are rigorous tests for pyro.infer.inspect, so we'll need to (1) update a bunch of those tests and (2) add some of your examples as new tests.