xgi-org / xgi

CompleX Group Interactions (XGI) is a Python package for higher-order networks.
https://xgi.readthedocs.io

Returning node positions from `xgi.draw` #510

Closed: tlarock closed this issue 5 months ago

tlarock commented 7 months ago

I want to use the same node positions to draw a few different hypergraphs on the same set of nodes. My lazy approach was to draw the first hypergraph with `xgi.draw`, get the node positions from its return value, and use them in the subsequent `draw` calls. However, getting the positions was surprisingly difficult, even though `xgi.draw` computes them and returns something called `node_collection`. The problem is that this return object is a `matplotlib.collections.PathCollection`, so (as far as I can tell) it doesn't hold `pos` in the same form. One might be able to recover them with `PathCollection.get_offsets()` or something similar, but I'm not sure.
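For reference, the `get_offsets()` route does work on a plain matplotlib `PathCollection`. The sketch below uses `ax.scatter` as a stand-in for the node collection returned by `xgi.draw` and rebuilds a `{node: (x, y)}` dict from the offsets. The positions are made up, and it assumes the offsets come back in the same order the nodes were drawn in; that ordering is an assumption here, not something confirmed for `xgi.draw`.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical positions in {node: (x, y)} form
pos = {1: (0.0, 0.0), 2: (1.0, 0.5), 3: (-0.5, 1.0)}
nodes = list(pos)

fig, ax = plt.subplots()
# ax.scatter returns a matplotlib.collections.PathCollection
node_collection = ax.scatter(*np.array([pos[n] for n in nodes]).T)

# get_offsets() gives an (N, 2) array of marker centers in data coordinates
offsets = np.asarray(node_collection.get_offsets())

# Assumption: offsets are in the same order as the nodes that were drawn
recovered = {n: (float(x), float(y)) for n, (x, y) in zip(nodes, offsets)}
plt.close(fig)
```

This only works if the node order is known, which is one reason passing `pos` in explicitly is the more robust option.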

This is not urgent at all, because really I should compute the layout first and then reuse the same positions throughout, which is also the approach the example in the documentation points toward. I'm not trying to encourage bad design, but it might still be nice to return `pos` explicitly for convenience. I think the main question is whether adding another return value would break other code or otherwise be inconvenient.

May relate to #280 in the future.

maximelucas commented 7 months ago

Thanks Tim! Yeah, my first thought is exactly what you wrote: if you need to reuse the positions, it's best to pre-compute them outside of `draw` and then pass them in as the `pos` argument. The layout functions also accept a random `seed`, so you can even get the same positions across different scripts.

As for returning the positions, we can certainly think about it. I'm not sure, because we are already returning a lot (one axis and three collections). We could certainly do it without breaking existing code, though. Let's discuss with the others.
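One common way to add a return value without breaking existing callers is an opt-in keyword. The sketch below is hypothetical: `draw_sketch` and `return_pos` are not part of xgi, the layout is a seeded random placeholder, and the returned "artists" are placeholder strings standing in for the axis and the three collections. With the default `return_pos=False`, code that unpacks four values keeps working.

```python
import random


def draw_sketch(nodes, pos=None, seed=None, return_pos=False):
    """Hypothetical stand-in for a drawing function; not xgi's API."""
    if pos is None:
        # Placeholder layout: seeded random coordinates in the unit square
        rng = random.Random(seed)
        pos = {n: (rng.random(), rng.random()) for n in nodes}

    # Placeholders for the axis and the three collections a real draw call returns
    artists = ("ax", "node_collection", "dyad_collection", "edge_collection")

    if return_pos:
        # Opt-in: append the positions as an extra trailing value
        return (*artists, pos)
    # Default: same shape as before, so existing unpacking is unaffected
    return artists


# Existing call sites keep unpacking four values
ax, node_coll, dyad_coll, edge_coll = draw_sketch([1, 2, 3], seed=0)

# New call sites ask for the positions explicitly
*_, pos = draw_sketch([1, 2, 3], seed=0, return_pos=True)
```

Appending the new value at the end, and only on request, is what keeps the change backward compatible.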

nwlandry commented 6 months ago

I'm against returning positions; I think it would unnecessarily clutter the code. Maybe we can make a recipe for this? I'll note that the code accompanying my recent paper does exactly what is described here.

nwlandry commented 5 months ago

What about this recipe lifted from the example I mentioned? My worry is that it is too big for a recipe.


```python
import xgi
import matplotlib.pyplot as plt
from matplotlib import gridspec
import numpy as np
import itertools

# Toy hypergraph with edges of sizes 2 through 5
links = [[1, 2], [1, 3], [5, 6], [1, 7]]
triangles = [[3, 5, 7], [2, 7, 1], [6, 10, 15]]
squares = [[7, 8, 9, 10]]
pentagons = [[1, 11, 12, 13, 14]]
edges = links + triangles + squares + pentagons

H = xgi.Hypergraph(edges)

# One color per edge size (2, 3, 4, 5)
link_color = "#000000"
triangle_color = "#648FFF"
square_color = "#785EF0"
pentagon_color = "#DC267F"
colors = [link_color, triangle_color, square_color, pentagon_color]


def color_edges(H):
    # Face colors for the non-dyadic edges (order > 1), chosen by edge size
    return [colors[i - 2] for i in H.edges.filterby("order", 1, "gt").size.aslist()]


# Edge sizes to filter on: min size through max size (here 2, 3, 4, 5)
filtering_parameters = np.arange(
    H.edges.size.min(), H.edges.size.max() + 1, 1, dtype=int
)

# For each size k, keep only the edges whose size is =, >=, <=, or != k
uniform_filtering = [
    xgi.subhypergraph(H, edges=H.edges.filterby("size", k, "eq")).copy()
    for k in filtering_parameters
]
geq_filtering = [
    xgi.subhypergraph(H, edges=H.edges.filterby("size", k, "geq")).copy()
    for k in filtering_parameters
]
leq_filtering = [
    xgi.subhypergraph(H, edges=H.edges.filterby("size", k, "leq")).copy()
    for k in filtering_parameters
]
exclusion_filtering = [
    xgi.subhypergraph(H, edges=H.edges.filterby("size", k, "neq")).copy()
    for k in filtering_parameters
]
filterings = [uniform_filtering, geq_filtering, leq_filtering, exclusion_filtering]

# Compute the node positions once; every panel below reuses them
pos = xgi.pca_transform(xgi.pairwise_spring_layout(H, seed=3))

fig = plt.figure(layout="constrained", figsize=(8, 4))

# Left third: the full hypergraph; right two thirds: a 4x4 grid of filtered panels
gs_leftright = gridspec.GridSpec(1, 3, figure=fig, wspace=0.075)
gs_panels = gridspec.GridSpecFromSubplotSpec(4, 4, subplot_spec=gs_leftright[1:])

ax_left = fig.add_subplot(gs_leftright[0])
xgi.draw(
    H,
    pos=pos,
    ax=ax_left,
    edge_fc=color_edges(H),
    node_size=7,
    node_lw=0.5,
    dyad_lw=0.75,
    alpha=1,
)

labels = [r"$H_{(=, k)}$", r"$H_{(\geq, k)}$", r"$H_{(\leq, k)}$", r"$H_{(\neq, k)}$"]

# Row i: filtering type; column j: edge size k = filtering_parameters[j]
for i, j in itertools.product(range(4), repeat=2):
    ax = fig.add_subplot(gs_panels[i, j])
    ec = color_edges(filterings[i][j])
    xgi.draw(
        filterings[i][j],
        pos=pos,  # same positions as the full hypergraph on the left
        ax=ax,
        node_size=4,
        dyad_lw=0.75,
        node_lw=0.5,
        edge_fc=ec,
        alpha=1,
    )
    if i == 0:
        ax.set_title(rf"$k={filtering_parameters[j]}$")

    if j == 0:
        ax.text(-3.5, 0, labels[i], fontsize=16)
plt.show()
```
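To make the four rows of the grid concrete, here is a plain-Python sketch of which edges each filtering keeps for a given size `k`. It reuses the recipe's edge lists but replaces `H.edges.filterby("size", k, mode)` with a list comprehension over edge lengths; reading the modes `"eq"`, `"geq"`, `"leq"`, and `"neq"` as `==`, `>=`, `<=`, and `!=` on edge size is an assumption based on their names.

```python
import operator

# Edge lists from the recipe above
links = [[1, 2], [1, 3], [5, 6], [1, 7]]
triangles = [[3, 5, 7], [2, 7, 1], [6, 10, 15]]
squares = [[7, 8, 9, 10]]
pentagons = [[1, 11, 12, 13, 14]]
edges = links + triangles + squares + pentagons

# Assumed meaning of each filterby mode, applied to edge size
modes = {"eq": operator.eq, "geq": operator.ge, "leq": operator.le, "neq": operator.ne}


def filter_by_size(edges, k, mode):
    """Keep the edges whose size compares true against k under the given mode."""
    return [e for e in edges if modes[mode](len(e), k)]


# Number of edges kept in each row of the grid, for columns k = 2, 3, 4, 5
counts = {
    mode: [len(filter_by_size(edges, k, mode)) for k in range(2, 6)] for mode in modes
}
```

For example, the `=` row keeps 4, 3, 1, and 1 edges across the four columns, while the `>=` row shrinks from all 9 edges down to the single pentagon.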