pzivich / zEpid

Epidemiology analysis package
http://zepid.readthedocs.org
MIT License

Saving DAGs programmatically #164

Open joannadiong opened 2 years ago

joannadiong commented 2 years ago

I had corresponded with @pzivich over email and am posting our communication here for the benefit of other users.

JD.

Is it possible to programmatically save figures of directed acyclic graphs (DAGs) using zEpid? E.g., using the M-bias DAG code in the docs, calling plt.savefig('dag.png') only saves a blank PNG. To save it to disk, I'd need to plot the figure, then manually point-and-click on the pop-up to save it.

PZ.

Unfortunately, saving the drawn DAGs isn't too easy. In the background, I use NetworkX to organize and plot the diagram. NetworkX uses matplotlib, but it doesn't return the matplotlib axes object. So while you can tweak parts of the graph in various ways, NetworkX doesn't allow you to directly access the drawn part of the image. Normally this isn't a problem, but when it gets wrapped up in a class object that returns the matplotlib axes (which is what DirectedAcyclicGraph.draw_dag(...) does), it can lead to the issues you noted.

Currently, the best workaround is to generate the image by hand. Below is some code that should do the trick and match what is output by DirectedAcyclicGraph:

import networkx as nx
import matplotlib.pyplot as plt
from zepid.causal.causalgraph import DirectedAcyclicGraph

dag = DirectedAcyclicGraph(exposure='X', outcome="Y")
dag.add_arrows((('X', 'Y'),
                ('U1', 'X'), ('U1', 'B'),
                ('U2', 'B'), ('U2', 'Y')
               ))

fig = plt.figure(figsize=(6, 5))
ax = plt.subplot(1, 1, 1)
positions = nx.spectral_layout(dag.dag)
nx.draw_networkx(dag.dag, positions, node_color="#d3d3d3", node_size=1000, edge_color='black',
                 linewidths=1.0, width=1.5, arrowsize=15, ax=ax, font_size=12)
plt.axis('off')
plt.savefig("filename.png", format='png', dpi=300)
plt.close()

Thanks Paul for the advice!

For the longer term, it seems useful to build this or something similar into zEpid graphics to programmatically save (complex) DAGs in Python for publication, possibly using position values from DAGs generated in DAGitty, which is handy for quickly graphing and analysing complex DAGs. Just a thought.

Cheers

joannadiong commented 2 years ago

Note to self: dowhy, another causal package, now supports DAGitty syntax to generate causal graphs (see docs). I haven't looked under the hood yet, but their graphs are implemented a bit differently, using the DOT graph format rather than NetworkX.

zEpid graphs are pleasingly aesthetic though :) It would be a nice feature to support DAGitty graphs in future. Not sure if I'm skilled enough to build this in, but I would like to look into it over time.

pzivich commented 2 years ago

Thanks for looking into this!

As for compatibility with DAGitty, I think the format is like below. If that is provided as a string, I think it should be possible to parse? There are some items in the format that aren't immediately clear to me though. For example, what does the position after A -> Z do differently than the declared positions for A and Z above? But I agree that this would be nice support to have.

dag {
A [pos="-2.200,-1.520"]
B [pos="1.400,-1.460"]
D [outcome,pos="1.400,1.621"]
E [exposure,pos="-2.200,1.597"]
Z [pos="-0.300,-0.082"]
A -> E
A -> Z [pos="-0.791,-1.045"]
B -> D
B -> Z [pos="0.680,-0.496"]
E -> D
}
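
Parsing a string in this format can be sketched in plain Python. This is a minimal sketch, not part of zEpid or DAGitty: the function name parse_dagitty, the regular expressions, and the flip_y option are all assumptions made for illustration, and only the subset of the syntax shown above (node lines, `->` edges, `pos`, `exposure`, `outcome`) is handled. Edge-level `pos` values are read past and ignored. The flip_y option negates y, which matches the sign of the DAGitty positions reported later in this thread.

```python
import re

# The DAGitty text quoted in this thread.
DAGITTY_TEXT = '''dag {
A [pos="-2.200,-1.520"]
B [pos="1.400,-1.460"]
D [outcome,pos="1.400,1.621"]
E [exposure,pos="-2.200,1.597"]
Z [pos="-0.300,-0.082"]
A -> E
A -> Z [pos="-0.791,-1.045"]
B -> D
B -> Z [pos="0.680,-0.496"]
E -> D
}'''

# Hypothetical patterns covering only the syntax seen in the example above.
EDGE_RE = re.compile(r'^(\w+)\s*->\s*(\w+)')
NODE_RE = re.compile(r'^(\w+)\s*(?:\[(.*)\])?\s*$')
POS_RE = re.compile(r'pos="(-?[\d.]+),(-?[\d.]+)"')
ROLE_RE = re.compile(r'\b(exposure|outcome)\b')


def parse_dagitty(text, flip_y=False):
    """Return (positions, edges, roles) parsed from a DAGitty-style string."""
    positions, edges, roles = {}, [], {}
    for raw_line in text.splitlines():
        line = raw_line.strip()
        # Skip blank lines and the opening/closing lines of the block.
        if not line or line == "}" or re.fullmatch(r'dag\s*\{', line):
            continue
        edge = EDGE_RE.match(line)
        if edge:
            # Any [pos=...] on an edge is ignored in this sketch.
            edges.append((edge.group(1), edge.group(2)))
            continue
        node = NODE_RE.match(line)
        if node is None:
            raise ValueError(f"Unrecognised line: {line!r}")
        name, attrs = node.group(1), node.group(2) or ""
        pos = POS_RE.search(attrs)
        if pos:
            x, y = float(pos.group(1)), float(pos.group(2))
            positions[name] = (x, -y if flip_y else y)
        role = ROLE_RE.search(attrs)
        if role:
            roles[role.group(1)] = name
    return positions, edges, roles


positions, edges, roles = parse_dagitty(DAGITTY_TEXT, flip_y=True)
```

The returned positions dictionary has one (x, y) tuple per node, the edges list could be passed to add_arrows, and roles records which node was marked as exposure or outcome.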

joannadiong commented 2 years ago

Cool find! I just stumbled across Mermaid, a JavaScript-based tool that renders Markdown-like text definitions into flowcharts and other diagrams. Looking at the flowchart features for nodes and edges, the syntax closely resembles what DAGitty or Python would use to generate DAGs. A future workflow could look something like: generate and refine a complex DAG with DAGitty -> copy-paste or programmatically extract the DAG's node and edge data to Mermaid (Markdown), and touch up -> export the Mermaid code to SVG/PNG for publication, or simply use the Markdown render for day-to-day things like slides, notes, etc.

Check out the Mermaid GitHub IO page.

Some advantages:

I will keep playing around and see how we could implement. See what you think!
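
The "extract nodes and edges to Mermaid" step of that workflow could look roughly like the sketch below. It is only an illustration: the function name edges_to_mermaid is hypothetical, node names are assumed to be simple alphanumeric identifiers, and the output uses Mermaid's basic `flowchart` syntax with `-->` arrows and no styling or positions.

```python
def edges_to_mermaid(edges, direction="TD"):
    """Build Mermaid flowchart text from (source, target) pairs.

    `direction` is a Mermaid orientation such as "TD" (top-down)
    or "LR" (left-right).
    """
    lines = [f"flowchart {direction}"]
    for source, target in edges:
        lines.append(f"    {source} --> {target}")
    return "\n".join(lines)


# Edges of the example DAG discussed in this thread.
example_edges = [("A", "E"), ("A", "Z"), ("B", "D"), ("B", "Z"), ("E", "D")]
mermaid_text = edges_to_mermaid(example_edges)
print(mermaid_text)
```

The printed text can be pasted into any Markdown renderer that supports Mermaid blocks.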

joannadiong commented 2 years ago

From the code above:

import networkx as nx

dag

fig = plt.figure(figsize=(6, 5))
ax = plt.subplot(1, 1, 1)
positions = nx.spectral_layout(dag)
nx.draw_networkx(dag, positions, node_color="#d3d3d3", node_size=1000, edge_color='black', linewidths=1.0, width=1.5, arrowsize=15, ax=ax, font_size=12)
plt.axis('off')
plt.savefig("filename.png", format='png', dpi=300)
plt.close()

executing the line positions = nx.spectral_layout(dag) (or nx.drawing.layout.spectral_layout(dag)) produced an error:

TypeError: 'DirectedAcyclicGraph' object is not iterable

Would there be a work-around for this? Thank you

pzivich commented 2 years ago

Yes! If dag is a DirectedAcyclicGraph, you can replace dag in the above code with dag.dag. Internally, the NetworkX object is stored as the DirectedAcyclicGraph.dag attribute. So, you can use dag.dag to directly access the NetworkX graph and pass that to NetworkX functions.

I will fix my example above (where I don't clarify this!)

joannadiong commented 2 years ago

Interesting, thanks!

Did as you advised, and I managed to save the DAG. This is a DAG generated by DAGitty and plotted using the output positions. Interestingly, the orientation was preserved when the NetworkX plot was shown in the console (panel A), but changed when saved to PNG (panel B):

[Screenshot from 2022-02-09 13-42-22]

(Curve balls. Always... :) ) Not quite sure what the next steps are but I'll look into it over time. Any pointers would be appreciated!

pzivich commented 2 years ago

Could it be related to how the positions are determined? The above code from me assigns positions via nx.spectral_layout(dag). Your panel A might be using different positions (like whatever NetworkX uses as the default)?

If you were to take DAGitty's positions (like A [pos="-2.200,-1.520"] B [pos="1.400,-1.460"]) and put those in a dictionary, like so

positions = {"A": [-2.200,-1.520], "B": [1.400,-1.460],  ...}

That should keep the same general look to the output as DAGitty. However, every node in the DAG would need to be assigned a 2D position in the dictionary (otherwise you will get an error).
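
The dictionary approach can be sketched end to end as follows. This is a minimal sketch, not zEpid's own API: it builds a plain networkx.DiGraph as a stand-in for dag.dag, uses the five-node example and the DAGitty position values posted in this thread, checks that every node has a position before drawing, and writes the PNG to the system temp directory. The Agg backend, the output file name, and the variable names are assumptions made for illustration.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no display is needed
import matplotlib.pyplot as plt
import networkx as nx

# Stand-in for dag.dag: the example DAG from this thread.
graph = nx.DiGraph()
graph.add_edges_from([("A", "E"), ("A", "Z"), ("B", "D"), ("B", "Z"), ("E", "D")])

# One 2D position per node, taken from the DAGitty values in this thread.
positions = {
    "A": [-2.2, 1.52],
    "B": [1.4, 1.46],
    "D": [1.4, -1.621],
    "E": [-2.2, -1.597],
    "Z": [-0.3, 0.082],
}

# Fail early with a readable message if any node lacks a position.
missing = set(graph.nodes) - set(positions)
if missing:
    raise ValueError(f"No position supplied for nodes: {sorted(missing)}")

fig, ax = plt.subplots(figsize=(6, 5))
nx.draw_networkx(graph, pos=positions, node_color="#d3d3d3", node_size=1000,
                 edge_color="black", linewidths=1.0, width=1.5, arrowsize=15,
                 ax=ax, font_size=12)
ax.set_axis_off()

# Hypothetical output location; any writable path works.
out_path = os.path.join(tempfile.gettempdir(), "dag_fixed_positions.png")
fig.savefig(out_path, format="png", dpi=300)
plt.close(fig)
```

Saving through the fig handle, rather than the implicit current figure, keeps the drawing and the saved file tied to the same figure object.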

Let me know if that helps

joannadiong commented 2 years ago

Thanks!

Yes, I had supplied the positions from DAGitty (as a dict of lists) via dag.draw_dag(positions=pos). But it turns out that nx.spectral_layout(dag.dag) computes its own positions rather than using these:

original positions from DAGitty:
{'A': [-2.2, 1.52],
 'B': [1.4, 1.46],
 'D': [1.4, -1.621],
 'E': [-2.2, -1.597],
 'Z': [-0.3, 0.082]}

generated positions from nx.spectral_layout:
{'E': array([-0.70324807, -1.        ]),
 'D': array([-0.89346405, -0.30901699]),
 'A': array([ 0.45883284, -0.30901699]),
 'Z': array([0.98682236, 0.80901699]),
 'B': array([0.15105692, 0.80901699])}

So as you suggested, I supplied the dict of arrays to nx.draw_networkx instead, and the saved PNG and console plotted DAGs are identical. Problem solved!

joannadiong commented 2 years ago

From a previous comment on DAGitty positions,

A -> Z [pos="-0.791,-1.045"]

Like what does the position after A -> Z do differently than the declared positions for A and Z above?

it seems the position given after the edge curves the arrow (something like curving bullets?). Not essential, but could be something of interest in future.

joannadiong commented 2 years ago

Question. I've been building the plotting features in an external private repo, and wondering how to proceed with merging the functionality into zEpid. I felt it belongs better there, but the code is not pretty enough for a PR. It does the job, but is not in a package format with classes, tests, etc, and has not yet been tested on a complex and messy DAG. Also, you might have different preferences for things on style, approach, or others.

What might be a reasonable way to proceed? I could tidy things up then add you to the external repo?

Cheers

pzivich commented 2 years ago

Ahh! AFAIK there wasn't an easy way to get curved arrows. However, this doesn't seem to be true anymore. But I don't know if that functionality can (easily) be applied to specific arrows only.
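
One possible way to curve only selected arrows is to draw the edges in two passes with nx.draw_networkx_edges, passing a different edgelist each time and a matplotlib connectionstyle string for the curved group. The sketch below is an assumption-laden illustration rather than a zEpid feature: the graph is a plain networkx.DiGraph standing in for dag.dag, the choice of which edges to curve follows the two edges that carry a pos value in the DAGitty example, and the curvature rad=0.3 is arbitrary and does not reproduce DAGitty's control points.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no display is needed
import matplotlib.pyplot as plt
import networkx as nx

graph = nx.DiGraph([("A", "E"), ("A", "Z"), ("B", "D"), ("B", "Z"), ("E", "D")])
positions = {"A": [-2.2, 1.52], "B": [1.4, 1.46], "D": [1.4, -1.621],
             "E": [-2.2, -1.597], "Z": [-0.3, 0.082]}

# Edges to curve (those with an edge-level pos in the DAGitty example).
curved = [("A", "Z"), ("B", "Z")]
straight = [edge for edge in graph.edges if edge not in curved]

fig, ax = plt.subplots(figsize=(6, 5))
nx.draw_networkx_nodes(graph, positions, node_color="#d3d3d3",
                       node_size=1000, ax=ax)
nx.draw_networkx_labels(graph, positions, font_size=12, ax=ax)

# First pass: straight arrows (default connection style).
nx.draw_networkx_edges(graph, positions, edgelist=straight, node_size=1000,
                       arrowsize=15, width=1.5, ax=ax)
# Second pass: curved arrows for the selected edges only.
nx.draw_networkx_edges(graph, positions, edgelist=curved, node_size=1000,
                       arrowsize=15, width=1.5,
                       connectionstyle="arc3,rad=0.3", ax=ax)
ax.set_axis_off()

# Hypothetical output location; any writable path works.
out_path = os.path.join(tempfile.gettempdir(), "dag_curved_edges.png")
fig.savefig(out_path, format="png", dpi=300)
plt.close(fig)
```

The sign of rad controls which side the arrow bends toward, so it may need adjusting per edge.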

Let's start with an external repo (and you can share with me). We can talk about merging or keeping separate. Thanks :)