dmlc / dgl

Python package built to ease deep learning on graphs, on top of existing DL frameworks.
http://dgl.ai
Apache License 2.0

Dynamic (?) graph supporting ?? #3270

Open Junyoungpark opened 3 years ago

Junyoungpark commented 3 years ago

Hi, I'm a big fan of DGL and always use it for my GNN projects.

I wonder what would be good DGL practice for handling dynamic (?) graphs. It might be good to start with my own definition: by a dynamic graph, I mean a graph whose edge connectivity can change based on the node features.

Such graphs come up quite frequently when building models for solving PDEs or learning physics simulators. One particular case is Learning to Simulate Complex Physics with Graph Networks [1], where the graph must be constructed so that two particles (nodes) are connected whenever they are close to each other.

As far as I can tell, it is not easy in DGL to implement edges that are added or deleted based on the node features. So what would be the most DGL-idiomatic way to implement this kind of scenario? I guess we face quite a similar issue when handling set data as well.

My hack for implementing such graph computation was to construct a complete graph and mask the messages depending on the node features, which is not computationally efficient.
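
A minimal self-contained sketch of what I mean by this masking hack (the coordinates, features, and radius are just made-up placeholders):

import dgl
import dgl.function as fn
import torch

n_nodes, radius = 10, 0.8
g = dgl.graph((torch.arange(n_nodes).repeat_interleave(n_nodes),
               torch.arange(n_nodes).repeat(n_nodes)))  # complete graph
g.ndata['coord'] = torch.randn(n_nodes, 2)
g.ndata['h'] = torch.randn(n_nodes, 16)

def masked_message(edges):
    # zero out messages between nodes farther apart than the radius
    dist = (edges.src['coord'] - edges.dst['coord']).norm(dim=-1, keepdim=True)
    return {'m': edges.src['h'] * (dist <= radius).float()}

g.update_all(masked_message, fn.sum('m', 'h_new'))  # still touches all n^2 edges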

[1] Learning to Simulate Complex Physics with Graph Networks - http://proceedings.mlr.press/v119/sanchez-gonzalez20a/sanchez-gonzalez20a.pdf

BarclayII commented 3 years ago

If the graph structure changes every time the node features change, another option is to construct a new graph each time. This is how we implement graph convolutions on point clouds with KNN graphs (e.g. in examples/pytorch/pointcloud/edgeconv), for which we now have efficient KNN graph construction algorithms.
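
For reference, the KNN path looks roughly like this (a minimal sketch using dgl.knn_graph; the sizes are arbitrary):

import dgl
import torch

points = torch.randn(1000, 3)  # a point cloud of 1000 points in 3D
g = dgl.knn_graph(points, 16)  # connect each point to its 16 nearest neighbors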

The given paper builds a new graph with edges connecting nodes within a given radius. Currently we don't have a function for that, though. If I were to implement it, I would probably compute the pairwise distances with torch.cdist, threshold them, and build a new sparse graph from the result. Do we have any algorithm for building such graphs more efficiently than that? Also asking @lygztq, since you implemented the efficient KNNs and probably know this better.

lygztq commented 3 years ago

I think constructing such a "radius graph" would be similar to the current k-NN graph construction. I can implement it in the next few weeks if you need it.

BarclayII commented 3 years ago

> I think constructing such a "radius graph" would be similar to the current k-NN graph construction. I can implement it in the next few weeks if you need it.

That would be a godsend. You don't have to do it yourself; just telling us (at a high level) which part to change would be good enough.

lygztq commented 3 years ago

> That would be a godsend. You don't have to do it yourself; just telling us (at a high level) which part to change would be good enough.

That would be nice :)

IMO, for almost all implementations except the KD-Tree and NN-Descent ones, you can just replace the priority queue for the k nearest points with a simple array (whose size is max_num_neighbors) holding the points within the given radius. The KD-Tree library we use has a radius-search interface, and NN-Descent is a k-NN-specific algorithm.
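
To illustrate the difference at a high level (a Python sketch of the two query loops, not our actual C++ kernels):

import heapq

def knn_bruteforce(points, query, k):
    # k-NN keeps a bounded max-heap of the k closest points seen so far
    heap = []  # stores (-distance, index) so the farthest kept point is on top
    for i, p in enumerate(points):
        d = sum((a - b) ** 2 for a, b in zip(p, query))
        if len(heap) < k:
            heapq.heappush(heap, (-d, i))
        elif d < -heap[0][0]:
            heapq.heapreplace(heap, (-d, i))
    return [i for _, i in heap]

def radius_bruteforce(points, query, r, max_num_neighbors):
    # radius search only needs a plain array; stop once it is full
    out = []
    for i, p in enumerate(points):
        d = sum((a - b) ** 2 for a, b in zip(p, query))
        if d <= r * r:
            out.append(i)
            if len(out) == max_num_neighbors:
                break
    return out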

BarclayII commented 3 years ago

> just replace the priority queue for the k nearest points with a simple array (whose size is max_num_neighbors) holding the points within the given radius.

The priority queue is for brute-force KNN, right? If so, I would rather stick with the naive solution, since it is vectorized and we need to compute every pairwise distance anyway (it is also friendlier to GPUs, I guess). Also, I cannot assume a maximum number of neighbors when building those graphs.

For the nanoflann KDTree, do I need to change the adapter, or can I just call radiusSearch directly?

lygztq commented 3 years ago

If you only want to implement the BLAS and KD-Tree versions, there is no need for max_num_neighbors; it is only necessary for the brute-force implementation.

For the nanoflann KD-Tree, no, you do not have to change the adapter; just calling radiusSearch will be fine.

Junyoungpark commented 3 years ago

Thanks, everyone, for sharing your ideas and opening quite an intensive discussion on this issue.

@BarclayII, I tried another way of doing this. Roughly, the implementation constructs a complete graph first; then it computes the pairwise distances by calling the DGL graph's apply_edges method, builds a drop mask from the distances and the threshold radius, and lastly drops the edges selected by that mask.

Using torch.cdist might be an alternative. However, I'm not sure which is more efficient: (1) finding the src/dst node indices via torch.where, or (2) doing the job on the DGL graph as explained above. My main concerns are peak memory usage and overall execution time. @BarclayII, any idea?

BarclayII commented 3 years ago

> @BarclayII, I tried another way of doing this. Roughly, the implementation constructs a complete graph first; then it computes the pairwise distances by calling the DGL graph's apply_edges method, builds a drop mask from the distances and the threshold radius, and lastly drops the edges selected by that mask.
>
> Using torch.cdist might be an alternative. However, I'm not sure which is more efficient: (1) finding the src/dst node indices via torch.where, or (2) doing the job on the DGL graph as explained above. My main concerns are peak memory usage and overall execution time. @BarclayII, any idea?

In general, building the graph on the fly should be more efficient than masking: graph construction has similar speed and memory consumption to masking (all the pairwise distances have to be computed either way), while message passing in the former case runs over far fewer edges than in the latter.
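
A rough way to compare the two approaches yourself (a hypothetical micro-benchmark sketch; sizes and timings are illustrative only):

import time
import torch
import dgl
import dgl.function as fn

n_nodes, radius = 2000, 0.1
coords = torch.randn(n_nodes, 2)
feat = torch.randn(n_nodes, 16)

# (a) build the sparse radius graph on the fly, then message-pass on it
t0 = time.time()
conn = torch.cdist(coords, coords) < radius
src, dst = conn.nonzero(as_tuple=True)
g = dgl.graph((src, dst), num_nodes=n_nodes)
g.ndata['h'] = feat
g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h_new'))
print('on-the-fly:', time.time() - t0, 'sec over', g.number_of_edges(), 'edges')

# (b) mask the messages on a complete graph: touches all n^2 edges
t0 = time.time()
u = torch.arange(n_nodes).repeat_interleave(n_nodes)
v = torch.arange(n_nodes).repeat(n_nodes)
gc = dgl.graph((u, v))
gc.ndata['h'] = feat
gc.ndata['coord'] = coords

def masked_msg(edges):
    dist = (edges.src['coord'] - edges.dst['coord']).norm(dim=-1, keepdim=True)
    return {'m': edges.src['h'] * (dist < radius).float()}

gc.update_all(masked_msg, fn.sum('m', 'h_new'))
print('masked complete graph:', time.time() - t0, 'sec over', gc.number_of_edges(), 'edges')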

Junyoungpark commented 3 years ago

Oh, I might need to clarify the difference between the two implementations. The first one is indeed less efficient, since it passes messages over the complete graph anyway. The second option uses DGL's apply_edges to drop some edges during the graph construction phase.

To build the graph following your suggestion, one needs to find out which edges are valid, and that is typically done with torch.where or torch.nonzero. I've heard somewhere that torch.where and torch.nonzero are extremely slow compared to their numpy counterparts.

As far as I know, DGL's message passing framework is highly optimized, so "what about using the DGL message passing framework to find which edges are not valid?" is the core of the second implementation.

BarclayII commented 3 years ago

> The first one is indeed less efficient, since it passes messages over the complete graph anyway. The second option uses DGL's apply_edges to drop some edges during the graph construction phase.
>
> To build the graph following your suggestion, one needs to find out which edges are valid, and that is typically done with torch.where or torch.nonzero. I've heard somewhere that torch.where and torch.nonzero are extremely slow compared to their numpy counterparts.

If you want to construct a graph using apply_edges, then apply_edges can only compute the distances between the nodes; you would still need torch.where to find the non-zero entries, right?

Also, I read that torch.nonzero should generally be faster than numpy's: https://github.com/pytorch/pytorch/issues/14848

Junyoungpark commented 3 years ago

@BarclayII, I meant something like this.

Construct the graph:

import dgl
import torch
from functools import partial

def calc_dist(edges, r):
    dist = (edges.src['coord'] - edges.dst['coord']).pow(2).sum(dim=-1).sqrt()
    mask = dist > r  # drop mask: True for edges longer than the radius
    return {'dist': dist, 'mask': mask}

n_nodes = 10
coords = torch.randn(n_nodes, 2)
src = torch.arange(n_nodes).repeat_interleave(n_nodes)  # [0, 0, ..., 9, 9]
dst = torch.arange(n_nodes).repeat(n_nodes)             # [0, 1, ..., 8, 9]
g = dgl.graph((src, dst))  # a complete graph (self-loops included)
g.ndata['coord'] = coords
g.apply_edges(partial(calc_dist, r=0.8))
drop_eid = torch.arange(g.number_of_edges())[g.edata['mask']]
g.remove_edges(drop_eid)

BarclayII commented 3 years ago

I see. I would rather do something like:

import dgl
import torch

n_nodes, radius = 10, 0.8
coords = torch.randn(n_nodes, 2)
connections = torch.cdist(coords, coords) < radius  # boolean adjacency within the radius
src, dst = connections.nonzero(as_tuple=True)
graph = dgl.graph((src, dst), num_nodes=n_nodes)

This should be more efficient than the one above: cdist should beat apply_edges thanks to dense computation, and nonzero() is also more efficient than remove_edges, since the latter involves hashtable lookups.

Junyoungpark commented 3 years ago

Thanks for helping me out! This discussion was a great chance to understand DGL ops better.

Since this thread is now considered a feature request, it might be even better to add an ignore_self option for ignoring the particle itself.

coords = torch.randn(n_nodes, 2)
connections = torch.cdist(coords, coords) < radius
if ignore_self:
    connections.fill_diagonal_(False)  # a node is always within the radius of itself
src, dst = connections.nonzero(as_tuple=True)
graph = dgl.graph((src, dst), num_nodes=n_nodes)

I believe such a feature could be implemented as above.