Zartris closed this issue 2 years ago.
Thanks for reporting. We currently do not have an elegant way to check for reverse edges in `HeteroData`, so we cannot check for `is_undirected` at the moment. This is an open TODO documented here. Will try to look into this. For now, please don't make use of `is_undirected` on `HeteroData` objects :)
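Something like the following in `HeteroData` might do the trick?

```python
from typing import Tuple

import torch
from torch_geometric.data.storage import EdgeStorage
# Aliased so the utility is not shadowed by the method of the same name.
from torch_geometric.utils import is_undirected as edge_index_is_undirected


# Both methods are meant to live inside the HeteroData class.
@staticmethod
def _rev_key(key: Tuple[str, str, str]) -> Tuple[str, str, str]:
    # Map (src, rel, dst) to (dst, rel', src), toggling the "rev_" prefix on rel.
    edge_type = key[1]
    if edge_type.startswith("rev_"):
        edge_type = edge_type[4:]
    else:
        edge_type = "rev_" + edge_type
    return key[-1], edge_type, key[0]


def is_undirected(self) -> bool:
    def hetero_is_undirected(key: Tuple[str, str, str], store: EdgeStorage) -> bool:
        if store.is_undirected():
            return True
        elif store.is_bipartite():
            # Look up the matching reverse edge type and check both directions together.
            reverse_store = self._edge_store_dict.get(self._rev_key(key), None)
            if reverse_store is not None:
                return edge_index_is_undirected(
                    torch.cat([store.edge_index, reverse_store.edge_index], dim=1))
        return False

    # The graph counts as undirected only if every edge type passes the check.
    return all(hetero_is_undirected(key, store)
               for key, store in self._edge_store_dict.items())
```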
I tried out the above, and it does seem to do the job, but of course it is a bit inefficient (multiple sorts).
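A minimal pure-Python sketch of the same `rev_*` pairing idea, assuming each edge type maps to a plain list of `(src, dst)` index pairs (a hypothetical stand-in for `edge_index`, not the PyG API). Instead of concatenating and sorting, it compares each forward edge set against the flipped reverse edge set with hash sets; `rev_key` and `hetero_is_undirected` are illustrative names only.

```python
from typing import Dict, List, Tuple

EdgeType = Tuple[str, str, str]
# Hypothetical stand-in for the edge stores: edge type -> list of (src, dst) index pairs.
EdgeDict = Dict[EdgeType, List[Tuple[int, int]]]


def rev_key(key: EdgeType) -> EdgeType:
    """Map (src, rel, dst) to (dst, rel', src), toggling the "rev_" prefix on rel."""
    src, rel, dst = key
    rel = rel[len("rev_"):] if rel.startswith("rev_") else "rev_" + rel
    return dst, rel, src


def hetero_is_undirected(edges: EdgeDict) -> bool:
    for key, pairs in edges.items():
        src_type, _, dst_type = key
        forward = set(pairs)
        if src_type == dst_type:
            # Same node type on both ends: the edge set must be symmetric by itself.
            if any((j, i) not in forward for i, j in forward):
                return False
        else:
            # Bipartite: the reverse edge type, flipped, must contain exactly the same pairs.
            reverse = edges.get(rev_key(key))
            if reverse is None or forward != {(j, i) for i, j in reverse}:
                return False
    return True


edges: EdgeDict = {
    ("author", "writes", "paper"): [(0, 0), (0, 1), (1, 1)],
    ("paper", "rev_writes", "author"): [(0, 0), (1, 0), (1, 1)],
    ("paper", "cites", "paper"): [(0, 1), (1, 0)],
}
print(hetero_is_undirected(edges))  # True
```

Because forward pairs are compared with flipped reverse pairs rather than merged into one index, indices of different node types never mix: `writes = [(0, 1), (1, 0)]` with `rev_writes = [(1, 0)]` is reported as not undirected, a case a concatenated check could miss.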
This currently requires that reverse edge types are marked via `rev_*`, which is a somewhat arbitrary constraint. I think a better approach might be to convert the graph into a homogeneous one first and then re-use the `to_undirected` functionality. Let me see if I can come up with a quick fix.
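A rough pure-Python sketch of the homogeneous-conversion idea, assuming per-type node counts and per-edge-type lists of `(src, dst)` pairs (hypothetical stand-ins, not the PyG data structures). Each node type's local indices are shifted by a running offset so all types share one index space, all edge types are merged, and a single symmetry check runs on the result. No `rev_*` naming convention is needed, so `written_by` pairs with `writes` automatically. In PyG itself this would presumably build on `HeteroData.to_homogeneous()` and `torch_geometric.utils.is_undirected`; the helper names below are illustrative only.

```python
from typing import Dict, List, Tuple

EdgeType = Tuple[str, str, str]


def to_homogeneous_edges(
    num_nodes: Dict[str, int],
    edges: Dict[EdgeType, List[Tuple[int, int]]],
) -> List[Tuple[int, int]]:
    """Offset each node type's local indices into one shared index space and merge all edge types."""
    offsets: Dict[str, int] = {}
    total = 0
    for node_type, count in num_nodes.items():
        offsets[node_type] = total
        total += count
    merged: List[Tuple[int, int]] = []
    for (src_type, _, dst_type), pairs in edges.items():
        merged.extend((i + offsets[src_type], j + offsets[dst_type]) for i, j in pairs)
    return merged


def is_undirected_homogeneous(edge_list: List[Tuple[int, int]]) -> bool:
    """A homogeneous edge list is undirected if every (i, j) has a matching (j, i)."""
    edge_set = set(edge_list)
    return all((j, i) in edge_set for i, j in edge_set)


num_nodes = {"author": 2, "paper": 2}
hetero_edges = {
    ("author", "writes", "paper"): [(0, 0), (1, 1)],
    ("paper", "written_by", "author"): [(0, 0), (1, 1)],
}
merged = to_homogeneous_edges(num_nodes, hetero_edges)
print(is_undirected_homogeneous(merged))  # True
```

One trade-off: since edge types are merged, a reverse edge can be satisfied by any relation between the same two nodes, not only its intended counterpart.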
True 👍
🐛 Describe the bug
When running `T.ToUndirected()(data)` on a heterogeneous graph, `data.is_undirected()` still returns `False`.
Found it while going through the example in the heterogeneous graph learning section of the docs.
Environment
Installation method (`conda`, `pip`, source): pip