Qiskit / rustworkx

A high performance Python graph library implemented in Rust.
https://www.rustworkx.org
Apache License 2.0

`mpl_draw()` does not work for multigraphs #774

Closed: DionTimmermann closed this issue 3 months ago.

DionTimmermann commented 1 year ago

Information

What is the current behavior?

When plotting a multigraph with parallel edges using rustworkx.visualization.mpl_draw(), two issues occur:

  1. The arrows for parallel edges are drawn on top of each other.
  2. Only one label is drawn for each pair of nodes, because the labels are currently stored in a dict keyed by node pair.
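The label loss in point 2 can be reproduced without any plotting. The minimal sketch below is plain Python. It assumes the edges arrive as (source, target, payload) triples, as returned by `PyDiGraph.weighted_edge_list()`, and uses list position as a stand-in for the rustworkx edge index (available, for example, from `edge_index_map()`). Keying labels by edge is one possible alternative, not necessarily what rustworkx adopted.

```python
# Stand-in for graph.weighted_edge_list() on the reproduction graph:
# two parallel edges 0 -> 1 with payloads 1 and 2.
edges = [(0, 1, 1), (0, 1, 2)]

# Keyed by node pair (the behavior described above): the second
# parallel edge silently overwrites the first edge's label.
labels_by_pair = {(u, v): str(w) for u, v, w in edges}

# Keyed by edge position instead: each parallel edge keeps its own label.
labels_by_edge = {i: str(w) for i, (u, v, w) in enumerate(edges)}

print(labels_by_pair)  # {(0, 1): '2'}
print(labels_by_edge)  # {0: '1', 1: '2'}
```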

What is the expected behavior?

I see three possible solutions:

Steps to reproduce the problem

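  1. Create a graph with two parallel edges and draw it:
    import rustworkx
    from rustworkx.visualization import mpl_draw

    graph = rustworkx.PyDiGraph()
    graph.add_node('A')
    graph.add_node('B')
    graph.add_edge(0, 1, 1)  # first edge 0 -> 1, payload 1
    graph.add_edge(0, 1, 2)  # parallel edge 0 -> 1, payload 2
    mpl_draw(graph, with_labels=True, labels=str, edge_labels=str, alpha=0.5)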
  2. Save the image as a vector graphic and inspect it.

The image contains two overlapping arrows but only the label for edge 2; the label for edge 1 is not drawn.
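One way to keep overlapping arrows apart is to give each parallel edge its own curvature. The sketch below is a hypothetical helper, not a rustworkx function. `parallel_edge_rads` assigns one matplotlib `arc3` `rad` value per edge. A lone edge stays straight, and edges sharing a direction fan out with a growing rad. Antiparallel edges can share a rad because `arc3` bends relative to the direction of travel. The base value of 0.25 is an arbitrary choice that happens to match the rad used later in this thread.

```python
from collections import Counter, defaultdict


def parallel_edge_rads(edge_list, base=0.25):
    """Return one arc3 ``rad`` per edge so that parallel edges do not overlap.

    edge_list holds (source, target) pairs in drawing order, for example
    list(graph.edge_list()) on a rustworkx graph (assumed input shape).
    """
    # Number of edges joining each unordered node pair, both directions combined.
    pair_totals = Counter(frozenset((u, v)) for u, v in edge_list)
    # Edges already assigned so far for each ordered (source, target) pair.
    seen = defaultdict(int)
    rads = []
    for u, v in edge_list:
        if pair_totals[frozenset((u, v))] == 1:
            rads.append(0.0)  # lone edge: keep it straight
            continue
        k = seen[(u, v)]
        seen[(u, v)] += 1
        # Same-direction edges fan out with growing rad. Reverse edges may
        # reuse a rad: arc3 bends relative to the direction of travel, so
        # u -> v and v -> u curve to opposite sides.
        rads.append(base * (k + 1))
    return rads


print(parallel_edge_rads([(0, 1), (0, 1)]))  # [0.25, 0.5]
print(parallel_edge_rads([(1, 2), (1, 0), (0, 1)]))  # [0.0, 0.25, 0.25]
```

Each value would become a per-edge `connectionstyle=f"arc3,rad={rad}"`. As far as its documentation describes it, `mpl_draw` takes a single `connectionstyle` string for all edges. Per-edge values would therefore need changes inside the drawing code itself.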

IvanIsCoding commented 1 year ago

As a temporary workaround, you might want to try rustworkx.visualization.graphviz_draw, which does support multigraphs.

But indeed, we ported the visualization code from NetworkX, so we inherited its problems with drawing multigraphs.
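With the `graphviz_draw` workaround, labels come from per-node and per-edge callbacks rather than from a node-pair dict. The minimal sketch below assumes the payloads are the node names and integer weights from the reproduction steps. `node_attr`, `edge_attr`, and `draw_multigraph` are hypothetical names. The keyword arguments `node_attr_fn`, `edge_attr_fn`, and `filename` follow the documented `graphviz_draw` signature; check them against the installed rustworkx version.

```python
def node_attr(data):
    """Graphviz attributes for one node; attribute values must be strings."""
    return {"label": str(data)}


def edge_attr(data):
    """Graphviz attributes for one edge.

    graphviz_draw calls this once per edge, so each parallel edge gets
    its own label instead of sharing a node-pair key.
    """
    return {"label": str(data)}


def draw_multigraph(graph, filename="multigraph.png"):
    """Render a rustworkx graph (including parallel edges) with Graphviz."""
    # Imported lazily so the helpers above work without rustworkx;
    # rendering assumes Graphviz's `dot` executable is on PATH.
    from rustworkx.visualization import graphviz_draw

    return graphviz_draw(
        graph,
        node_attr_fn=node_attr,
        edge_attr_fn=edge_attr,
        filename=filename,
    )
```

For example, running `draw_multigraph(graph)` on the two-edge graph from the reproduction steps should write `multigraph.png` with both 0 -> 1 edges and both labels, provided Graphviz and the optional dependencies of `graphviz_draw` are installed.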

maxwell04-wq commented 4 months ago

Since I worked on a rustworkx-related PR during last year's UnitaryHack, I'd like to work on this issue.

maxwell04-wq commented 4 months ago

This is the graph I obtained after a few edits to the matplotlib.py file:

[Image: test_fig_success_1]

The code used to produce it is:

import rustworkx
from rustworkx.visualization import mpl_draw
import matplotlib.pyplot as plt

graph = rustworkx.PyDiGraph()
graph.add_node('A')
graph.add_node('B')
graph.add_node('C')

graph.add_edge(1, 2, 4)  # B -> C
graph.add_edge(1, 0, 2)  # B -> A
graph.add_edge(0, 1, 3)  # A -> B, antiparallel to B -> A

fig = mpl_draw(graph, with_labels=True, labels=str, edge_labels=str, alpha=0.5)
plt.savefig('test_fig.png')

Before moving forward, I wish to clarify a few points:

  1. The sample code shared to reproduce the problem has two edges from 0 to 1, but plotting it shows only one arrow from 0 to 1. Is this the expected behavior?
  2. Should I enforce connectionstyle as either arc or arc3 for this plotting style? Also, is it okay to keep rad constant? I've used rad=0.25.
  3. It appears that one of the labels on the loop was disappearing because the other label was drawn over it. Offsetting the labels solved the problem, so there is no need to explicitly change kwds["edge_labels"].
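For points 2 and 3, the label position on a curved `arc3` edge can be computed exactly. Matplotlib's `ConnectionStyle.Arc3` builds a quadratic Bézier curve whose control point is the chord midpoint shifted by `rad * (dy, -dx)`. The minimal sketch below mirrors that construction. `arc3_label_position` is a hypothetical helper, not part of rustworkx or matplotlib. Matplotlib applies `arc3` in display coordinates, so the result matches data coordinates only when the axes have an equal aspect ratio.

```python
def arc3_label_position(p0, p1, rad, t=0.5):
    """Point at parameter t on matplotlib's arc3 curve from p0 to p1.

    Mirrors the control-point construction of
    matplotlib.patches.ConnectionStyle.Arc3: the control point sits at the
    chord midpoint, shifted by rad * (dy, -dx).
    """
    (x0, y0), (x1, y1) = p0, p1
    dx, dy = x1 - x0, y1 - y0
    cx = (x0 + x1) / 2 + rad * dy
    cy = (y0 + y1) / 2 - rad * dx
    # Quadratic Bezier: B(t) = (1-t)^2 P0 + 2(1-t)t C + t^2 P1
    a, b, c = (1 - t) ** 2, 2 * (1 - t) * t, t ** 2
    return (a * x0 + b * cx + c * x1, a * y0 + b * cy + c * y1)


# The same rad puts an edge and its reverse on opposite sides of the chord.
print(arc3_label_position((0, 0), (1, 0), 0.25))  # (0.5, -0.125)
print(arc3_label_position((1, 0), (0, 0), 0.25))  # (0.5, 0.125)
```

At `t=0.5` the point reduces to the chord midpoint shifted by rad/2 · (dy, -dx). This suggests why a constant rad works for antiparallel edges such as A -> B and B -> A. Parallel edges in the same direction would still need distinct rad values; otherwise their labels land on the same point.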

maxwell04-wq commented 4 months ago

Since UnitaryHack ends in less than a week, could someone please review my PR and let me know whether any further changes are required?