pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

Edge indices when batching heterogeneous graphs #7138

Open: martinbchnr opened this issue 1 year ago

martinbchnr commented 1 year ago

šŸ› Wrong edge indices when batching heterogeneous graphs?

Hi everybody, thanks for the amazing library so far!

I am currently facing an issue regarding the generated edge_index when batching HeteroData objects either via a loader.DataLoader or via Batch.from_data_list():

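import torch
from torch_geometric.data import HeteroData
from torch_geometric.loader import DataLoader

# Note: edge_index is stored as an attribute of the 'author' node type here
a = HeteroData(author={'x': torch.randn((2,10)), 'edge_index': torch.tensor([[0,1], [1,0]])}, book={'x': torch.randn((6,10))})
b = HeteroData(author={'x': torch.randn((3,10)), 'edge_index': torch.tensor([[0,1,2], [1,0,1]])}, book={'x': torch.randn((6,10))})

loader = DataLoader([a,b], batch_size=2)
batch = next(iter(loader))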

Next, I want to access the combined edge indices after batching, which I expected to be renumbered, as with homogeneous graphs:

batch['author'].batch
> tensor([0, 0, 1, 1, 1])  # makes sense

batch['author'].edge_index
> tensor([[0, 1, 0, 1, 2],
          [1, 0, 1, 0, 1]])  # does not make sense

I might be wrong, but shouldn't the edge indices be reindexed in accordance with the batch vector? I expected something like this instead (similar to the standard batching procedure for homogeneous graphs):

batch['author'].edge_index
> tensor([[0, 1, 2, 3, 4],
          [1, 0, 3, 2, 3]])
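A simplified model of the difference between the two tensors, assuming (as far as I understand PyG's collation) that an attribute stored on a node type receives an increment of 0, while index attributes of edge types are shifted by the node counts of the preceding graphs. The helper collate_edge_index is hypothetical and written in plain PyTorch; it is not PyG's own code:

```python
import torch


def collate_edge_index(edge_indices, increments):
    """Concatenate per-graph edge indices along dim=1, adding the running sum
    of the previous graphs' increments (simplified model, not PyG's code)."""
    out, offset = [], 0
    for ei, inc in zip(edge_indices, increments):
        out.append(ei + offset)  # shift this graph's node IDs
        offset += inc            # accumulate for the next graph
    return torch.cat(out, dim=1)


ei_a = torch.tensor([[0, 1], [1, 0]])        # graph a: 2 author nodes
ei_b = torch.tensor([[0, 1, 2], [1, 0, 1]])  # graph b: 3 author nodes

# Increment 0 per graph: indices are concatenated unchanged (observed output).
observed = collate_edge_index([ei_a, ei_b], increments=[0, 0])
# Increment = number of author nodes per graph (expected output).
expected = collate_edge_index([ei_a, ei_b], increments=[2, 3])

print(observed)
print(expected)
```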

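Imagine performing a coalesce on the batched edge_index: since coalesce does not take the batch vector into account, identical (source, target) pairs coming from different graphs, such as (0, 1) from both a and b, would be merged into one edge, which changes the graph connectivity.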
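To illustrate the coalesce concern, the sketch below uses torch.unique(..., dim=1) as a stand-in for the deduplication step of coalesce (assuming no edge attributes are involved). On the unshifted tensor, edges from different graphs collapse; on the shifted one, every edge survives:

```python
import torch

# Batched edge_index as observed above: node IDs were NOT shifted per graph.
unshifted = torch.tensor([[0, 1, 0, 1, 2],
                          [1, 0, 1, 0, 1]])
# The same edges with graph b shifted by 2 (graph a has 2 author nodes).
shifted = torch.tensor([[0, 1, 2, 3, 4],
                        [1, 0, 3, 2, 3]])

# torch.unique over columns drops duplicate (src, dst) pairs, which mimics
# the duplicate-removal part of a coalesce.
dedup_unshifted = torch.unique(unshifted, dim=1)
dedup_shifted = torch.unique(shifted, dim=1)

print(dedup_unshifted.size(1))  # 3: edges from graphs a and b were merged
print(dedup_shifted.size(1))    # 5: all edges survive
```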

Thank you in advance!


EdisonLeeeee commented 1 year ago

Hi @martinbchnr, sorry for the late response.

I believe there may be some misunderstanding regarding the usage of HeteroData. In this case, edge_index should be stored on an edge type, i.e. a relation between two node types, rather than as an attribute of a single node type. A corrected example would therefore be:


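import torch
from torch_geometric.data import HeteroData
from torch_geometric.loader import DataLoader

# edge_index is stored on the ('author', 'book') edge type instead of a node type
a = HeteroData(author={'x': torch.randn((2,10))}, book={'x': torch.randn((6,10))})
a['author', 'book'].edge_index = torch.tensor([[0,1], [1,0]])
b = HeteroData(author={'x': torch.randn((3,10))}, book={'x': torch.randn((6,10))})
b['author', 'book'].edge_index = torch.tensor([[0,1,2], [1,0,1]])

loader = DataLoader([a,b], batch_size=2)
batch = next(iter(loader))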

>>> batch['author', 'book'].edge_index
tensor([[0, 1, 2, 3, 4],
        [1, 0, 7, 6, 7]])
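The destination row is shifted by 6 because graph a has 6 book nodes, while the source row is shifted by the 2 author nodes of graph a. A plain-PyTorch sketch of that per-row offset, assuming each edge type's index is incremented by (number of source nodes, number of destination nodes) of the preceding graphs; collate_bipartite is a hypothetical helper, not part of PyG:

```python
import torch


def collate_bipartite(edge_indices, num_src, num_dst):
    """Shift source IDs by the preceding graphs' source-node counts and
    destination IDs by their destination-node counts, then concatenate."""
    out = []
    src_off, dst_off = 0, 0
    for ei, n_src, n_dst in zip(edge_indices, num_src, num_dst):
        # A (2, 1) offset broadcasts over all edges of this graph.
        out.append(ei + torch.tensor([[src_off], [dst_off]]))
        src_off += n_src
        dst_off += n_dst
    return torch.cat(out, dim=1)


edge_index = collate_bipartite(
    [torch.tensor([[0, 1], [1, 0]]), torch.tensor([[0, 1, 2], [1, 0, 1]])],
    num_src=[2, 3],  # author nodes per graph
    num_dst=[6, 6],  # book nodes per graph
)
print(edge_index)  # tensor([[0, 1, 2, 3, 4], [1, 0, 7, 6, 7]])
```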

martinbchnr commented 1 year ago

Thanks for your answer @EdisonLeeeee,

I think we got another misunderstanding here :laughing:

I deliberately added an edge_index that belongs only to the node type author (i.e. author-to-author edges). In the meantime, I have changed my code to the following, which works well:

import torch
from torch_geometric.data import HeteroData

# 'author__to__author' is parsed as the edge type ('author', 'to', 'author')
a = HeteroData(author={'x': torch.randn((2,10))},
               author__to__author={'edge_index': torch.tensor([[0,1], [1,0]])},
               book={'x': torch.randn((6,10))})
b = HeteroData(author={'x': torch.randn((3,10))},
               author__to__author={'edge_index': torch.tensor([[0,1,2], [1,0,1]])},
               book={'x': torch.randn((6,10))})
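For reference, as far as I know HeteroData interprets keyword names containing a double underscore, such as author__to__author, as (source, relation, destination) edge types. A toy parser illustrating that naming convention; parse_store_key is hypothetical and not PyG's implementation:

```python
from typing import Tuple, Union


def parse_store_key(key: str) -> Tuple[str, Union[str, Tuple[str, ...]]]:
    """Hypothetical helper: classify a keyword name as a node or edge type.

    Names containing '__' are split into a (source, relation, destination)
    tuple; all other names are treated as node types.
    """
    if "__" in key:
        parts = tuple(key.split("__"))
        if len(parts) != 3:
            raise ValueError(f"expected 'src__rel__dst', got {key!r}")
        return "edge", parts
    return "node", key


print(parse_store_key("author__to__author"))  # ('edge', ('author', 'to', 'author'))
print(parse_store_key("book"))                # ('node', 'book')
```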

In general, I think it would be a useful feature for a future release if

data = HeteroData()
data['author'].edge_index = ...

would be directly mapped to the equivalent of

data = HeteroData()
data['author', 'author'].edge_index = ...

I guess this should work, since edge_index is a protected keyword.
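A minimal sketch of the proposed mapping, using a toy dict-based container rather than PyG. The names route_attribute, ToyHeteroStore and DEFAULT_REL are hypothetical; the relation name 'to' is assumed to match what, as far as I know, PyG fills in for a two-element key such as ('author', 'book'):

```python
from typing import Any, Dict, Tuple, Union

StoreKey = Union[str, Tuple[str, str, str]]
DEFAULT_REL = "to"  # assumed default relation name


def route_attribute(node_type: str, attr: str) -> StoreKey:
    """Hypothetical rule: a node-level 'edge_index' is redirected to the edge
    type (node_type, DEFAULT_REL, node_type); other attributes stay put."""
    if attr == "edge_index":
        return (node_type, DEFAULT_REL, node_type)
    return node_type


class ToyHeteroStore:
    """Toy container (not PyG) mapping store keys to attribute dicts."""

    def __init__(self) -> None:
        self.stores: Dict[StoreKey, Dict[str, Any]] = {}

    def set(self, node_type: str, attr: str, value: Any) -> None:
        key = route_attribute(node_type, attr)
        self.stores.setdefault(key, {})[attr] = value


data = ToyHeteroStore()
data.set("author", "x", [[0.1], [0.2]])
data.set("author", "edge_index", [[0, 1], [1, 0]])
print(list(data.stores))  # ['author', ('author', 'to', 'author')]
```

Keeping x on the node store while moving edge_index to an edge store means that, under the increment rule sketched earlier in this thread, batching would shift the author-to-author indices automatically.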

EdisonLeeeee commented 1 year ago

Sounds good! Will look into it. Thank you!

cc: @rusty1s