dmlc / dgl

Python package built to ease deep learning on graph, on top of existing DL frameworks.
http://dgl.ai
Apache License 2.0

heterograph.set_batch_num_edges could run automatically if batch_num_nodes is set #7498

Open anlutfi opened 3 days ago

anlutfi commented 3 days ago

Since there are no edges between the individual graphs in a graph batch (GB), batch_num_edges could be calculated automatically once batch_num_nodes is set for the GB, for example by calling set_batch_num_edges with no arguments, or with something like set_batch_num_edges(auto=True).

Motivation

I needed to make subgraphs of a GB while keeping the batch info consistent. I figured out that, once batch_num_nodes has been calculated for the subgraph, the edges of each individual graph are exactly the edges whose source and destination fall in the same group of batch_num_nodes.

Example: graph g is a new subgraph of a GB and has batch_num_nodes = [100, 100, 100]. To get the node ids of each individual graph in g, we take the cumulative sum (CS) of batch_num_nodes, which gives CS = [100, 200, 300]. Nodes with indices < 100 are in the first graph, 100 <= indices < 200 are in the second, and 200 <= indices < 300 are in the third.
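The boundary arithmetic in this example can be sketched in plain Python. This is only an illustration using the standard library; `graph_of_node` is a hypothetical helper name and not part of DGL.

```python
from bisect import bisect_right
from itertools import accumulate

batch_num_nodes = [100, 100, 100]

# The cumulative sum gives the exclusive upper bound of each graph's node id range.
cs = list(accumulate(batch_num_nodes))  # [100, 200, 300]


def graph_of_node(node_id, bounds):
    """Return the index of the graph whose id range contains node_id (hypothetical helper)."""
    if not 0 <= node_id < bounds[-1]:
        raise ValueError(f"node id {node_id} is outside the batch")
    # bisect_right counts how many upper bounds are <= node_id,
    # which is the index of the graph that owns the node.
    return bisect_right(bounds, node_id)


print(graph_of_node(99, cs), graph_of_node(100, cs), graph_of_node(299, cs))  # 0 1 2
```

Because the lookup counts upper bounds that are less than or equal to the node id, a graph with zero nodes is skipped naturally.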

With these indices in hand, and the certainty that there are no edges between nodes of different graphs in a batch, one can simply look at the source and destination of each edge to determine which graph in the batch it belongs to. So batch_num_edges comes for free.
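A minimal sketch of that derivation, again in plain Python rather than DGL or PyTorch: `count_edges_per_graph` is a hypothetical name, and edges are assumed to be given as two parallel lists of node ids.

```python
from bisect import bisect_right
from itertools import accumulate


def count_edges_per_graph(batch_num_nodes, src, dst):
    """Derive per-graph edge counts from per-graph node counts (illustrative, not DGL API)."""
    bounds = list(accumulate(batch_num_nodes))
    counts = [0] * len(batch_num_nodes)
    for u, v in zip(src, dst):
        # Locate the graph that owns each endpoint.
        gu = bisect_right(bounds, u)
        gv = bisect_right(bounds, v)
        if gu >= len(counts) or gv >= len(counts):
            raise ValueError(f"edge ({u}, {v}) has an endpoint outside the batch")
        # The derivation is only valid if no edge crosses graph boundaries.
        if gu != gv:
            raise ValueError(f"edge ({u}, {v}) connects graphs {gu} and {gv}")
        counts[gu] += 1
    return counts


# Two graphs with 3 and 2 nodes: ids 0..2 and 3..4.
print(count_edges_per_graph([3, 2], src=[0, 1, 3, 4], dst=[1, 2, 4, 3]))  # [2, 2]
```

This sketch raises on an edge that crosses graph boundaries instead of silently ignoring it, which is a design choice for the illustration only. On tensors, the same lookup could probably be done with torch.bucketize(ids, bounds, right=True) followed by torch.bincount, which avoids building an edges-by-graphs mask.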

I believe this feature is a good quality-of-life (QOL) improvement, since it removes a source of user error: calculating batch_num_edges by hand.

Code that I'm using

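import torch

# sg is a subgraph of the batch whose batch_num_nodes has already been set
bnn = sg.batch_num_nodes()

# last (e_tail) and first (e_head) node id of each individual graph
e_tail = torch.cumsum(bnn, dim=0) - 1
e_head = torch.cat([e_tail.new_zeros(1), e_tail[:-1] + 1])

# compare every edge endpoint with every graph's id range: shape (num_edges, num_graphs)
source, dest = sg.edges()
source = source.unsqueeze(1).tile((1, len(e_tail)))
dest = dest.unsqueeze(1).tile((1, len(e_tail)))
mask = (source >= e_head) & (source <= e_tail) & (dest >= e_head) & (dest <= e_tail)

# column-wise counts give the number of edges inside each graph
bne = torch.count_nonzero(mask, dim=0)
sg.set_batch_num_edges(bne)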