pymc-devs / pymc

Bayesian Modeling and Probabilistic Programming in Python
https://docs.pymc.io/

Allow customizing style of model_graph nodes #7302

Closed. wd60622 closed this 1 month ago.

wd60622 commented 1 month ago

Description

Allow users to override the default graphviz node styling by passing a mapping from node type to a function that creates the node kwargs.

The default behavior is unchanged, but users can now override the styling per node type.

First, define a function from a node variable to the kwargs passed to graphviz / networkx:

from typing import Any
from pytensor.tensor.variable import TensorVariable
def node_formatter(var: TensorVariable) -> dict[str, Any]:
    return {"label": var.name}

Then pass a mapping from node type to node_formatter:

node_mapping = {"Data": node_formatter}
pm.model_to_graphviz(model, node_formatters=node_mapping)
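How the user mapping is combined with the built-in styles is not spelled out here. One plausible approach, sketched below in plain Python so it runs without PyMC, is to merge the user's mapping over per-type defaults so that only the listed node types change. DEFAULT_FORMATTERS, resolve_formatters, node_kwargs, the default styles, and the SimpleNamespace stand-in are all hypothetical and are not PyMC's internals.

```python
from types import SimpleNamespace
from typing import Any, Callable, Dict, Optional

# A formatter maps a graph node's variable to graphviz node kwargs.
NodeFormatter = Callable[[Any], Dict[str, Any]]

# Hypothetical per-type defaults, for illustration only (not PyMC's actual styles).
DEFAULT_FORMATTERS: Dict[str, NodeFormatter] = {
    "Free Random Variable": lambda var: {"shape": "ellipse", "label": var.name},
    "Observed Random Variable": lambda var: {
        "shape": "ellipse",
        "style": "filled",
        "label": var.name,
    },
    "Data": lambda var: {"shape": "box", "style": "rounded, filled", "label": var.name},
}


def resolve_formatters(
    overrides: Optional[Dict[str, NodeFormatter]] = None,
) -> Dict[str, NodeFormatter]:
    # User entries replace the default for the same node type;
    # node types the user does not mention keep their default formatter.
    return {**DEFAULT_FORMATTERS, **(overrides or {})}


def node_kwargs(
    node_type: str, var: Any, formatters: Dict[str, NodeFormatter]
) -> Dict[str, Any]:
    # Look up the formatter registered for this node type and apply it to the variable.
    return formatters[node_type](var)


# Stand-in for a model variable: only its `name` attribute is read.
x = SimpleNamespace(name="x")
formatters = resolve_formatters({"Data": lambda var: {"label": var.name}})
print(node_kwargs("Data", x, formatters))  # {'label': 'x'}
print(node_kwargs("Free Random Variable", x, formatters))  # {'shape': 'ellipse', 'label': 'x'}
```

Merging with {**defaults, **overrides} keeps the default behavior identical when no mapping is passed, which matches the stated goal of leaving the defaults unchanged.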

Example

import pymc as pm

with pm.Model() as model:
    a = pm.Normal("a")
    b = pm.Normal("b", mu=a)
    c = pm.Normal("c", mu=a)
    d = pm.Normal("d", mu=b + c, observed=0)

# Default styling, unchanged by this PR
default = pm.model_to_graphviz(model)

# Override the style of free random variables only
simple_random_variable_formatter = {
    "Free Random Variable": lambda var: {"shape": "circle", "label": var.name},
}
pm.model_to_graphviz(model, node_formatters=simple_random_variable_formatter)

# Override both free and observed random variables
fancy_formatter = {
    "Free Random Variable": lambda var: {"shape": "polygon", "sides": "7", "label": var.name, "style": "dashed"},
    "Observed Random Variable": lambda var: {"shape": "circle", "label": var.name, "style": "solid"},
}
pm.model_to_graphviz(model, node_formatters=fancy_formatter)

Default: [rendered graph image not preserved]

Simple: [rendered graph image not preserved]

Fancy: [rendered graph image not preserved]
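Whatever a formatter returns ends up as Graphviz node attributes. To make that concrete, the hypothetical helper to_dot_node below writes the "fancy" kwargs for node a as a DOT node statement by hand. It only illustrates the resulting DOT text; PyMC builds its graphs through the graphviz Python package rather than a helper like this.

```python
from typing import Any, Dict


def quote(text: Any) -> str:
    # Convert to str and wrap in double quotes, escaping embedded double quotes.
    return '"' + str(text).replace('"', '\\"') + '"'


def to_dot_node(node_id: str, attrs: Dict[str, Any]) -> str:
    # One DOT node statement, e.g. "a" [shape="circle" label="a"].
    # DOT allows attribute pairs in a list to be separated by whitespace alone.
    body = " ".join(f"{key}={quote(value)}" for key, value in attrs.items())
    return f"{quote(node_id)} [{body}]"


# The kwargs the "fancy" formatter returns for the free random variable a.
fancy_a = {"shape": "polygon", "sides": "7", "label": "a", "style": "dashed"}
print(to_dot_node("a", fancy_a))
# "a" [shape="polygon" sides="7" label="a" style="dashed"]
```

Because this helper quotes every value as a string, "sides": "7" and "sides": 7 would produce the same DOT text here.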

📚 Documentation preview 📚: https://pymc--7302.org.readthedocs.build/en/7302/

wd60622 commented 1 month ago

@drbenvincent, you might be interested in this example:

import pymc as pm

with pm.Model() as model:
    c = pm.Normal("c")
    z = pm.Normal("z", mu=c)
    y = pm.Normal("y", mu=c + z)

node_formatters = {
    "Basic Random Variable": lambda var: {"shape": "circle", "label": var.name},
}
pm.model_to_graphviz(model, node_formatters=node_formatters)

[rendered graph image not preserved]
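Because a formatter receives the variable itself, it can also style individual nodes, for instance by name. A minimal sketch, assuming one wants z drawn with a dashed outline while every node stays a circle; make_highlighting_formatter is a hypothetical helper, and the check uses SimpleNamespace stand-ins rather than real model variables. The resulting node_formatters dict has the same shape as the one above and could be passed to pm.model_to_graphviz in the same way.

```python
from types import SimpleNamespace
from typing import Any, Callable, Dict, Set


def make_highlighting_formatter(highlighted: Set[str]) -> Callable[[Any], Dict[str, Any]]:
    # Build a formatter that draws every node as a circle and gives a dashed
    # outline to the nodes whose name is in `highlighted`.
    def formatter(var: Any) -> Dict[str, Any]:
        kwargs: Dict[str, Any] = {"shape": "circle", "label": var.name}
        if var.name in highlighted:
            kwargs["style"] = "dashed"
        return kwargs

    return formatter


# Same node-type key as in the example above; only "z" gets a dashed outline.
node_formatters = {"Basic Random Variable": make_highlighting_formatter({"z"})}

# Check the formatter on stand-in objects that only carry a `name`.
fmt = node_formatters["Basic Random Variable"]
print(fmt(SimpleNamespace(name="z")))  # {'shape': 'circle', 'label': 'z', 'style': 'dashed'}
print(fmt(SimpleNamespace(name="c")))  # {'shape': 'circle', 'label': 'c'}
```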

wd60622 commented 1 month ago

Thank you for the review, @ricardoV94. I will make these changes and add some tests.

wd60622 commented 1 month ago

Some unrelated tests seem to be failing intermittently. The latest run passed.