Junyoungpark opened this issue 3 years ago
For the case where the graph structure changes every time the node features change, another option is to construct a new graph each time. This is how we implement graph convolutions on point clouds with KNN graphs (e.g. in examples/pytorch/pointcloud/edgeconv), where we now have efficient KNN graph construction algorithms.
The given paper builds a new graph with edges connecting nodes within a given radius. Currently we don't have a function for that, though. If I were to implement it, I would probably compute the pairwise distances using torch.cdist, threshold them, and build a new sparse graph from the result. Do we have any algorithm for building such graphs more efficiently than that? Also asking @lygztq, since you implemented the efficient KNNs and probably know this better.
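A minimal sketch of that cdist-and-threshold idea (the coordinates and radius here are made-up toy values, just to show the shape of the computation):

```python
import torch

coords = torch.tensor([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0]])
radius = 1.0

dist = torch.cdist(coords, coords)        # (3, 3) pairwise Euclidean distances
within = dist < radius                    # boolean adjacency; includes self-loops
src, dst = within.nonzero(as_tuple=True)  # sparse edge list for graph construction
# points 0 and 1 end up connected; point 2 only gets its self-loop
```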
I think the implementation of constructing such a "radius graph" is similar to the current K-NN graph construction implementation. I can implement it in the next few weeks if you need it.
That would be a godsend. You don't have to do it yourself - just telling us which parts to change (at a high level) should be good enough.
That will be nice :)
IMO, for almost all implementations except the KD-Tree and NN-Descent ones, just replace the priority queue for the k nearest points with a simple array (whose size is max_num_neighbors) holding the points within the given radius. The KD-Tree library we use has a radius-search interface, and NN-Descent is a K-NN-specific algorithm.
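For reference, SciPy's KD-tree exposes a comparable radius-search interface; this is only an illustration of what such an interface looks like, not the nanoflann C++ code DGL actually wraps:

```python
import numpy as np
from scipy.spatial import cKDTree

points = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0]])
tree = cKDTree(points)

# For each query point, the indices of all points within radius 1.0
neighbors = tree.query_ball_point(points, r=1.0)
# neighbors[0] contains indices 0 and 1; neighbors[2] contains only 2
```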
> just replace the priority-queue for k-nearest points with a simple array (whose size is the max_num_neighbors) for points within the given radius.
The priority queue is for brute-force KNN, right? If so, I would rather stick with the naive solution, since that is vectorized and we need to compute every pairwise distance anyway (also friendlier to GPUs, I guess). Also, I cannot assume a maximum number of neighbors when building those graphs.
For the nanoflann KD-Tree, do I need to change the adapter? Or can I just call radiusSearch directly?
If you only implement the BLAS and KD-Tree versions, there is no need for max_num_neighbors, which is only necessary for the brute-force implementation.
For the nanoflann KD-Tree, no, you do not have to change the adapter; just calling radiusSearch will be fine.
Thanks, everyone, for sharing your ideas and having such an in-depth discussion on this issue.
@BarclayII, I tried another way of doing this. Roughly, the implementation first constructs a complete graph. It then computes the pairwise distances by calling the DGL graph's apply_edges method, builds a dropping mask based on the distances and the thresholding radius, and lastly drops the edges according to the dropping mask.
Using torch.cdist might be an alternative option. However, I'm not sure which one is more efficient: (1) finding the src, dst node indices via torch.where, or (2) doing the job on the DGL graph as explained. My major concerns are peak memory usage and total execution time. @BarclayII, do you have any idea?
In general, building the graph on the fly should be more efficient than masking the complete graph: graph construction should have similar speed and memory consumption to masking (all pairwise distances need to be computed either way), and message passing in the former case will touch far fewer edges than in the latter.
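To make that edge-count argument concrete (toy numbers; the point sizes and radius are arbitrary choices here):

```python
import torch

n_nodes = 100
coords = torch.randn(n_nodes, 2)
connections = torch.cdist(coords, coords) < 0.5

n_edges_masked = n_nodes * n_nodes      # messages computed on a complete graph
n_edges_built = int(connections.sum())  # messages computed on the constructed graph
# message passing on the built graph touches n_edges_built <= n_edges_masked edges
```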
Oh, I might need to clarify the difference between the implementations. The first implementation is indeed less efficient, as it passes messages on the complete graph anyway. The second option was to use DGL's apply_edges to drop some edges at the graph construction phase.
To build the graph following your suggestion, we need to find out whether each edge is valid, which would typically be done with torch.where or torch.nonzero. I've heard from somewhere that torch.where and torch.nonzero are extremely slow compared to their numpy counterparts.
As far as I know, the DGL message passing framework is highly optimized, so "use the DGL message passing framework to find which edges are not valid" is the core of the 2nd implementation.
If you want to construct a graph using apply_edges, then apply_edges can only compute the distance between the nodes, and you still need to use torch.where to find the non-zero entries, right?
Also, I read that torch.nonzero should be faster than numpy in general: https://github.com/pytorch/pytorch/issues/14848
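A quick way to check that claim on your own machine (the array size and sparsity below are arbitrary; timings will vary by hardware and library version, so no expected winner is stated here):

```python
import timeit

import numpy as np
import torch

mask_np = np.random.rand(1000, 1000) < 0.01
mask_pt = torch.from_numpy(mask_np)

# Time extracting the indices of the True entries with each library
t_np = timeit.timeit(lambda: np.nonzero(mask_np), number=100)
t_pt = timeit.timeit(lambda: torch.nonzero(mask_pt), number=100)
# both calls return the same indices (in row-major order); compare t_np and t_pt yourself
```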
@BarclayII, I meant something like this.
import dgl
import torch
from functools import partial

def calc_dist(edges, r):
    # Euclidean distance between the endpoints of every edge
    dist = (edges.src['coord'] - edges.dst['coord']).pow(2).sum(dim=-1).sqrt()
    mask = dist > r  # True for edges outside the radius, i.e. edges to drop
    return {'dist': dist, 'mask': mask}

n_nodes = 10
coords = torch.randn(n_nodes, 2)
src, dst = torch.meshgrid(torch.arange(n_nodes), torch.arange(n_nodes))
g = dgl.graph((src.reshape(-1), dst.reshape(-1)))  # a complete graph
g.ndata['coord'] = coords
g.apply_edges(partial(calc_dist, r=0.8))
drop_eid = torch.arange(g.number_of_edges())[g.edata['mask']]
g.remove_edges(drop_eid)
I see. I would rather do something like:
coords = torch.randn(n_nodes, 2)
connections = torch.cdist(coords, coords) < radius
src, dst = connections.nonzero(as_tuple=True)
graph = dgl.graph((src, dst), num_nodes=n_nodes)
This should be more efficient than the one above, since cdist should be faster than apply_edges due to dense computation, and nonzero() is also more efficient than remove_edges, because the latter involves hashtable lookups.
Thanks for helping me out! This discussion was a great chance to better understand DGL ops.
Since this thread is now considered a feature request, it might be even better to add an option ignore_self for ignoring the particle itself.
coords = torch.randn(n_nodes, 2)
connections = torch.cdist(coords, coords) < radius
if ignore_self:
    # clear the diagonal so each particle is not connected to itself
    connections.fill_diagonal_(False)
src, dst = connections.nonzero(as_tuple=True)
graph = dgl.graph((src, dst), num_nodes=n_nodes)
I believe such a feature can be done as above.
Hi, I'm a big fan of DGL and always use it for my GNN projects. I wonder what a DGL-friendly practice for handling dynamic graphs would be. It might be good to start with my own definition: by dynamic graph, I mean a graph whose edge connectivity can differ based on the node features.
Such graphs appear quite frequently when building models for solving PDEs or for physics simulators. One particular case is Learning to Simulate Complex Physics with Graph Networks [1], where graphs must be constructed by connecting particles (nodes) that are close to each other. As far as I know, it is not easy in DGL to implement edges that are added/deleted based on the node features. So, what would be the most DGL-like way to implement this idea? I guess we confront quite a similar issue when handling set data as well.
My hack for implementing such graph computation was to construct a complete graph and mask the messages depending on the node features, which is not computationally efficient.
[1] Learning to Simulate Complex Physics with Graph Networks - http://proceedings.mlr.press/v119/sanchez-gonzalez20a/sanchez-gonzalez20a.pdf