Closed wd60622 closed 1 month ago
@drbenvincent might be interested in this example
import pymc as pm
from pymc import model_graph

with pm.Model() as model:
    c = pm.Normal("c")
    z = pm.Normal("z", mu=c)
    y = pm.Normal("y", mu=c + z)

# Draw basic random variables as circles labeled with the variable name.
node_formatters = {
    "Basic Random Variable": lambda var: {"shape": "circle", "label": var.name},
}
model_graph.model_to_graphviz(model, node_formatters=node_formatters)
Thank you for the review, @ricardoV94. I will add these changes and some tests.
Some unrelated tests seem to be failing randomly. The latest run passed.
Description
Allow users to override the default graphviz node styling by passing a mapping from node type to a function that creates the node kwargs.
The default behavior is unchanged, but the user can now override it per node type:
The user defines a function from the node variable to the kwargs passed to graphviz / networkx,
then passes a mapping from node type to that node formatter via `node_formatters`.
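
A rough sketch of the override mechanism, for illustration only. It is not the code in this PR: `FakeVar`, `DEFAULT_FORMATTERS`, `resolve_node_kwargs`, and the default shapes are hypothetical names and values, and the sketch assumes that user-supplied formatters simply take precedence over the defaults for the same node type.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class FakeVar:
    """Hypothetical stand-in for a model variable; only `name` is used."""

    name: str


# A node formatter maps a variable to the kwargs used to draw its node.
NodeFormatter = Callable[[FakeVar], dict]

# Hypothetical defaults keyed by node type (the values are illustrative only).
DEFAULT_FORMATTERS: dict[str, NodeFormatter] = {
    "Basic Random Variable": lambda var: {"shape": "ellipse", "label": var.name},
    "Deterministic": lambda var: {"shape": "box", "label": var.name},
}


def resolve_node_kwargs(
    node_type: str,
    var: FakeVar,
    node_formatters: Optional[dict[str, NodeFormatter]] = None,
) -> dict:
    """Return node kwargs, preferring a user formatter over the default."""
    # Assumption: user entries override defaults for the same node type.
    formatters = {**DEFAULT_FORMATTERS, **(node_formatters or {})}
    return formatters[node_type](var)


# Override only one node type; the other keeps its default formatter.
overrides = {
    "Basic Random Variable": lambda var: {"shape": "circle", "label": var.name},
}
rv_kwargs = resolve_node_kwargs("Basic Random Variable", FakeVar("c"), overrides)
det_kwargs = resolve_node_kwargs("Deterministic", FakeVar("d"), overrides)
```

Merging the two dicts means a user only has to supply formatters for the node types they want to restyle; every other node type falls back to its default.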
Example
Default:![default](https://github.com/pymc-devs/pymc/assets/57733339/0d79bb8c-30f5-4900-8e70-2053a182ca32)
Simple:![simple](https://github.com/pymc-devs/pymc/assets/57733339/79795349-848a-4b88-a94b-2d839e42b61f)
Fancy:
Related Issue
Checklist
Type of change
📚 Documentation preview 📚: https://pymc--7302.org.readthedocs.build/en/7302/