pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

Induced Edge Subgraph Support for Data and HeteroData. #6117

Open kkranen opened 1 year ago

kkranen commented 1 year ago

🚀 The feature, motivation and pitch

I'm working on use cases where I would like my train/valid/test dataloaders to have a different set of available edges based on what has already been seen (i.e., train edges are visible in the validation set, and both train and validation edges are visible in the test set).

In order to do this, I need to have support for inducing edge subgraphs, similar to DGL's edge subgraph: https://docs.dgl.ai/en/0.9.x/generated/dgl.edge_subgraph.html#dgl.edge_subgraph
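A minimal sketch of the split described above, using plain PyTorch tensors rather than any PyG API. The helper name `progressive_edge_masks`, the 80/10/10 ratios, and the choice that training message-passes only over train edges are hypothetical assumptions for illustration. Each returned mask marks the edges visible at that stage; an edge subgraph operation would then apply the mask to `edge_index` and to every edge feature.

```python
import torch


def progressive_edge_masks(num_edges, val_ratio=0.1, test_ratio=0.1, seed=0):
    """Hypothetical helper: randomly assign each edge to train/val/test and
    return, per stage, a boolean mask of the edges available for message passing."""
    gen = torch.Generator().manual_seed(seed)
    perm = torch.randperm(num_edges, generator=gen)
    num_val = int(num_edges * val_ratio)
    num_test = int(num_edges * test_ratio)
    num_train = num_edges - num_val - num_test

    # Split id per edge: 0 = train, 1 = val, 2 = test.
    split = torch.empty(num_edges, dtype=torch.long)
    split[perm[:num_train]] = 0
    split[perm[num_train:num_train + num_val]] = 1
    split[perm[num_train + num_val:]] = 2

    return {
        'train': split == 0,  # assumption: training uses only train edges
        'val': split == 0,    # validation sees the train edges
        'test': split <= 1,   # test sees train + validation edges
    }


# Toy ring graph with 10 directed edges.
edge_index = torch.stack([torch.arange(10), (torch.arange(10) + 1) % 10])
masks = progressive_edge_masks(edge_index.size(1))
val_edge_index = edge_index[:, masks['val']]    # 8 edges visible
test_edge_index = edge_index[:, masks['test']]  # 9 edges visible
```

Because the test mask is a superset of the validation mask, every edge visible during validation stays visible at test time, which is the property the use case above relies on.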

Alternatives

No response

Additional context

No response

rusty1s commented 1 year ago

Thanks for the issue @kkranen. I think this should be straightforward to get in, since all we need to do is mask/index select edge_index and any edge features, e.g.:

edge_index = edge_index[:, mask]
edge_attr = edge_attr[mask]

We can either integrate this as a utility function or in Data/HeteroData directly. What do you prefer?
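Both forms in the snippet above can be checked with plain PyTorch. A boolean mask and the equivalent integer indices select the same edges, as long as `edge_index` (shape `[2, num_edges]`) is indexed along its last dimension and `edge_attr` (shape `[num_edges, num_features]`) along its first. The tensors below are made-up toy values.

```python
import torch

# Toy graph: 4 edges, each with a 3-dimensional edge feature.
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])
edge_attr = torch.arange(12, dtype=torch.float).view(4, 3)

# Keep edges 0 and 2, expressed once as a boolean mask and once as indices.
mask = torch.tensor([True, False, True, False])
idx = mask.nonzero().view(-1)  # integer positions of the True entries

# Mask-based selection: columns of edge_index, rows of edge_attr.
sub_index_mask = edge_index[:, mask]
sub_attr_mask = edge_attr[mask]

# Index-based selection along the same dimensions.
sub_index_idx = edge_index.index_select(1, idx)
sub_attr_idx = edge_attr.index_select(0, idx)
# Either way, the kept edges are (0 -> 1) and (2 -> 3).
```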

kkranen commented 1 year ago

I think it's likely better as a member function of Data/HeteroData. On a related note, it may also be good to slice any other user-provided edge features besides edge_attr (like categorical data). I know this is somewhat risky, since other attributes are not guaranteed to be Tensors, but my current workaround looks like this:

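import copy
from torch import Tensor

def edge_subgraph_pyg(graph, edge_ids):
    """
    graph: Target HeteroData graph to create an edge subgraph of
    edge_ids: Dict mapping each edge type to the edge ids (indices or
              boolean mask) to keep for that type
    """
    output = copy.deepcopy(graph)
    for edge_type, idx in edge_ids.items():
        store = output[edge_type]
        # Decide which tensors are edge-level *before* slicing anything,
        # since num_edges changes once edge_index has been sliced.
        names = [name for name in store.keys()
                 if isinstance(store[name], Tensor) and store.is_edge_attr(name)]
        for name in names:
            if name == 'edge_index':
                store[name] = store[name][:, idx]  # [2, E]: select columns
            else:
                store[name] = store[name][idx]  # [E, ...]: select rows
    return output
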
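
One caveat with indexing every tensor as `value[..., idx]`: the ellipsis makes `idx` apply to the last dimension. That is the edge dimension for `edge_index` (`[2, num_edges]`) and for 1-D per-edge tensors, but it is the feature dimension for a 2-D `edge_attr` (`[num_edges, num_features]`); with a boolean mask of length `num_edges` it raises an indexing error instead. A quick shape check on random toy tensors:

```python
import torch

num_edges, num_features = 5, 3
edge_index = torch.randint(0, 10, (2, num_edges))
edge_attr = torch.randn(num_edges, num_features)
idx = torch.tensor([0, 1])  # keep the first two edges

# Last-dimension indexing is correct for edge_index ...
print(edge_index[..., idx].shape)  # torch.Size([2, 2])
# ... but for edge_attr it selects feature columns, not edges.
print(edge_attr[..., idx].shape)  # torch.Size([5, 2])
# Indexing edge_attr along its first dimension keeps the right rows.
print(edge_attr[idx].shape)  # torch.Size([2, 3])
```
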
rusty1s commented 1 year ago

We usually do this in PyG via the data.is_edge_attr function, so it would become something like:


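import copy
import torch

def edge_subgraph(self, mask):
    if mask.dtype == torch.bool:  # boolean mask -> integer edge indices
        mask = mask.nonzero().view(-1)
    out = copy.copy(self)  # shallow copy; sliced tensors are reassigned below
    for key, value in self:
        if isinstance(value, torch.Tensor) and self.is_edge_attr(key):
            dim = self.__cat_dim__(key, value)  # -1 for edge_index, 0 for edge_attr
            out[key] = value.index_select(dim, mask)
    return out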
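
The same idea extends to the heterogeneous case discussed earlier. The sketch below uses plain dictionaries instead of HeteroData so it runs with PyTorch alone: `stores` maps each edge type to a dict of its tensors, and `subsets` maps edge types to a boolean mask or index tensor. The rule that keys containing "index" are sliced along the last dimension and everything else along the first is an assumption modeled on PyG's default concatenation convention. The function name `edge_subgraph_dict` is hypothetical, and unlike `is_edge_attr` it does not check shapes; it simply treats every tensor in a store as edge-level.

```python
import torch


def edge_subgraph_dict(stores, subsets):
    """Hypothetical helper: return new per-edge-type dicts restricted to `subsets`.

    stores:  {edge_type: {name: tensor}} with every tensor edge-level
    subsets: {edge_type: bool mask or long index tensor}
    Edge types missing from `subsets` are copied unchanged.
    """
    out = {}
    for edge_type, store in stores.items():
        subset = subsets.get(edge_type)
        if subset is None:
            out[edge_type] = dict(store)
            continue
        if subset.dtype == torch.bool:
            subset = subset.nonzero().view(-1)  # boolean mask -> edge indices
        new_store = {}
        for name, value in store.items():
            # Assumed convention: "*index*" tensors are [2, E], others [E, ...].
            dim = value.dim() - 1 if 'index' in name else 0
            new_store[name] = value.index_select(dim, subset)
        out[edge_type] = new_store
    return out


stores = {
    ('user', 'rates', 'movie'): {
        'edge_index': torch.tensor([[0, 0, 1], [0, 1, 1]]),
        'edge_attr': torch.tensor([[5.0], [3.0], [4.0]]),
    },
    ('movie', 'rev_rates', 'user'): {
        'edge_index': torch.tensor([[0, 1, 1], [0, 0, 1]]),
    },
}
# Keep the first and third 'rates' edges; leave 'rev_rates' untouched.
sub = edge_subgraph_dict(
    stores, {('user', 'rates', 'movie'): torch.tensor([True, False, True])})
```

Returning fresh dicts keeps the input untouched, which mirrors the copy-then-reassign approach in the snippets above.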