pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

ToUndirected and is_undirected do not match on heterogeneous graphs #4596

Closed: Zartris closed this issue 2 years ago

Zartris commented 2 years ago

🐛 Describe the bug

When running T.ToUndirected()(data) on a heterogeneous graph, data.is_undirected() still returns False afterwards.

    from torch_geometric.datasets import OGB_MAG
    import torch_geometric.transforms as T

    dataset = OGB_MAG(root='data', preprocess='metapath2vec')
    data = dataset[0]
    print(data.is_undirected())  # before ToUndirected
    data = T.ToUndirected()(data)
    print(data.is_undirected())  # reported: still False after ToUndirected

I found this while working through the example in the heterogeneous graph learning section of the docs.
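To make the report concrete without downloading OGB_MAG, here is a rough pure-Python sketch of what T.ToUndirected() does to a heterogeneous graph: edge types that connect a node type to itself are symmetrized in place, while bipartite edge types such as ('author', 'writes', 'paper') get a new reverse type ('paper', 'rev_writes', 'author') holding the flipped edges. The dict-of-edge-lists layout and the to_undirected_hetero helper are hypothetical stand-ins for HeteroData, not PyG code.

```python
from typing import Dict, List, Tuple

EdgeType = Tuple[str, str, str]
EdgeList = List[Tuple[int, int]]


def to_undirected_hetero(edges: Dict[EdgeType, EdgeList]) -> Dict[EdgeType, EdgeList]:
    """Hypothetical pure-Python mimic of T.ToUndirected on a heterogeneous graph."""
    out: Dict[EdgeType, EdgeList] = {}
    for (src, rel, dst), pairs in edges.items():
        if src == dst:
            # Same node type: add (j, i) for every (i, j); result is sorted and deduplicated.
            out[(src, rel, dst)] = sorted(set(pairs) | {(j, i) for i, j in pairs})
        else:
            # Bipartite: keep the original edges and add a flipped "rev_" edge type.
            out[(src, rel, dst)] = list(pairs)
            out[(dst, "rev_" + rel, src)] = [(j, i) for i, j in pairs]
    return out


graph = {
    ("paper", "cites", "paper"): [(0, 1), (1, 2)],
    ("author", "writes", "paper"): [(0, 0), (1, 2)],
}
undirected = to_undirected_hetero(graph)
print(sorted(undirected))
```

After the transform every bipartite edge type has a partner, but nothing in the data itself records which type is the partner of which, apart from the naming convention.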

Environment

rusty1s commented 2 years ago

Thanks for reporting. We currently do not have an elegant way to check for reverse edges in HeteroData, so is_undirected cannot give a correct answer for HeteroData at the moment. This is an open TODO documented here. Will try to look into this. For now, please don't use is_undirected on HeteroData objects :)

Padarn commented 2 years ago

Something like

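    from typing import Tuple

    import torch
    from torch_geometric.data.storage import EdgeStorage
    from torch_geometric.utils import is_undirected

    # Proposed additions to the HeteroData class body:

    @staticmethod
    def _rev_key(key: Tuple[str, str, str]) -> Tuple[str, str, str]:
        # Map (src, rel, dst) to (dst, rev_rel, src), or strip an existing "rev_" prefix.
        edge_type = key[1]
        if edge_type.startswith("rev_"):
            edge_type = edge_type[4:]
        else:
            edge_type = "rev_" + edge_type

        return key[-1], edge_type, key[0]

    def is_undirected(self) -> bool:
        def hetero_is_undirected(key: Tuple[str, str, str], store: EdgeStorage) -> bool:
            # First try the per-store check (covers edge types within one node type).
            if store.is_undirected():
                return True
            elif store.is_bipartite():
                # Bipartite edge types need a matching reverse edge type.
                reverse_store = self._edge_store_dict.get(self._rev_key(key), None)
                if reverse_store is not None:
                    # Both edge_index tensors are concatenated into one shared index space.
                    return is_undirected(
                        torch.cat([store.edge_index, reverse_store.edge_index], dim=1))
            return False

        return all(hetero_is_undirected(key, store)
                   for key, store in self._edge_store_dict.items())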

in HeteroData might do the trick?
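The _rev_key helper above can be exercised on its own. This standalone rev_key is a hypothetical module-level copy (not part of PyG) that shows the naming convention the proposal relies on: it toggles the rev_ prefix and swaps the source and destination node types, so applying it twice returns the original key.

```python
from typing import Tuple

EdgeType = Tuple[str, str, str]


def rev_key(key: EdgeType) -> EdgeType:
    """Map (src, rel, dst) to its assumed reverse (dst, rev_rel, src)."""
    src, rel, dst = key
    # Strip the "rev_" prefix if present, otherwise add it.
    if rel.startswith("rev_"):
        rel = rel[len("rev_"):]
    else:
        rel = "rev_" + rel
    return dst, rel, src


print(rev_key(("author", "writes", "paper")))
print(rev_key(("paper", "rev_writes", "author")))
```

Because the pairing is purely name-based, a relation whose real name already starts with rev_ would be matched with the wrong partner.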

Padarn commented 2 years ago

I tried out the above, and it does seem to do the job, but of course it is a bit inefficient (multiple sorts).
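One possible way around the sorting is to compare each edge type with its reverse as hash sets. The sketch below is a variant of the proposal on plain Python dicts (hetero_is_undirected and the dict layout are hypothetical): it still pairs edge types through the rev_ naming convention, but compares the forward edges against the flipped reverse edges directly instead of concatenating both edge_index tensors, so the two node types keep separate index spaces.

```python
from typing import Dict, List, Tuple

EdgeType = Tuple[str, str, str]
EdgeList = List[Tuple[int, int]]


def rev_key(key: EdgeType) -> EdgeType:
    # Toggle the "rev_" prefix and swap the node types.
    src, rel, dst = key
    rel = rel[len("rev_"):] if rel.startswith("rev_") else "rev_" + rel
    return dst, rel, src


def hetero_is_undirected(edges: Dict[EdgeType, EdgeList]) -> bool:
    """Hash-based check: expected O(E) work, no sorting (hypothetical helper)."""
    for key, pairs in edges.items():
        src, _, dst = key
        forward = set(pairs)
        if src == dst and forward == {(j, i) for i, j in forward}:
            continue  # symmetric within a single node type
        reverse = edges.get(rev_key(key))
        # The reverse type must exist and hold exactly the flipped edges.
        if reverse is None or forward != {(j, i) for i, j in reverse}:
            return False
    return True


undirected = {
    ("author", "writes", "paper"): [(0, 0), (1, 2)],
    ("paper", "rev_writes", "author"): [(0, 0), (2, 1)],
    ("paper", "cites", "paper"): [(0, 1), (1, 0)],
}
print(hetero_is_undirected(undirected))
```

The separate index spaces matter in edge cases: with author-to-paper edges (0, 1) and (1, 0) and a reverse type containing only (1, 0), concatenation yields a symmetric edge set even though the reverse of (1, 0) (paper 0 to author 1) is missing, whereas the set comparison reports False.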

rusty1s commented 2 years ago

This currently requires that reverse edge types are marked via rev_*, which is a somewhat arbitrary constraint. I think a better approach might be to convert the graph into a homogeneous one first and then re-use the to_undirected functionality. Let me see if I can come up with a quick fix.
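A rough sketch of the homogeneous route on plain Python edge lists: give each node type a contiguous block of global ids, merge all edge types into one list, and test whether the merged edge set equals its own flip. In PyG the corresponding pieces would presumably be HeteroData.to_homogeneous() and torch_geometric.utils.is_undirected; the helpers here (hetero_to_homogeneous_edges, is_undirected_homogeneous) are hypothetical. The reverse relation can have any name (written_by below), so no rev_ convention is needed.

```python
from typing import Dict, List, Tuple

EdgeType = Tuple[str, str, str]
EdgeList = List[Tuple[int, int]]


def hetero_to_homogeneous_edges(
    num_nodes: Dict[str, int], edges: Dict[EdgeType, EdgeList]
) -> EdgeList:
    """Relabel every node into one global index space and merge all edge types."""
    # Give each node type a contiguous block of global ids.
    offsets: Dict[str, int] = {}
    total = 0
    for node_type, count in num_nodes.items():
        offsets[node_type] = total
        total += count
    merged: EdgeList = []
    for (src, _, dst), pairs in edges.items():
        merged.extend((offsets[src] + i, offsets[dst] + j) for i, j in pairs)
    return merged


def is_undirected_homogeneous(edge_list: EdgeList) -> bool:
    """An edge set is undirected if it equals its own flip."""
    edge_set = set(edge_list)
    return edge_set == {(j, i) for i, j in edge_set}


num_nodes = {"author": 2, "paper": 3}
graph = {
    ("author", "writes", "paper"): [(0, 0), (1, 2)],
    ("paper", "written_by", "author"): [(0, 0), (2, 1)],
}
print(is_undirected_homogeneous(hetero_to_homogeneous_edges(num_nodes, graph)))
```

One trade-off: merging discards relation types, so a paper-to-author edge from one relation could stand in for the reverse of an author-to-paper edge from an unrelated relation.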

Padarn commented 2 years ago

True 👍