pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

Minibatching on large graph #2114

Open naveenkumarmarri opened 3 years ago

naveenkumarmarri commented 3 years ago

❓ Questions & Help

I'm trying to perform link prediction on a large undirected graph. During the minibatching phase, I pass a single Data object containing all edges (1,293,978,946) and all nodes (8,595,301). When I wrap the dataset in a DataLoader, each minibatch iteration returns the entire graph instead of a batch of size batch_size.

import numpy as np
import torch
from torch_geometric.data import Data, DataLoader, InMemoryDataset

class CustomDataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None):
        super(CustomDataset, self).__init__(root, transform, pre_transform)
        edges = torch.from_numpy(np.load('graph_undirected_edges.npy'))
        node_embeddings = torch.from_numpy(np.load('node_embeddings.npy'))
        self.data, self.slices = self.collate([
            Data(num_nodes=node_embeddings.shape[0], x=node_embeddings,
                 edge_index=edges)])

dataset = CustomDataset(root=None)
loader = DataLoader(dataset, batch_size=32)
for batch in loader:
    print(batch.edge_index.shape, batch.batch.shape, batch.x.shape)
    break

The above loop returns

torch.Size([2, 1293978946]) torch.Size([8595301]) torch.Size([8595301, 512])

Am I missing something?

wsad1 commented 3 years ago

If you check len(dataset), it will be 1. The DataLoader has only one graph to batch, so it always returns the whole graph. I am guessing you want to batch nodes, not graphs, so check out torch_geometric.data.NeighborSampler or torch_geometric.data.RandomNodeSampler. Note: I am new to the library, so this might not be the best answer.

rusty1s commented 3 years ago

@wsad1 is right: the DataLoader is only useful when you want to batch multiple graphs together, not for splitting up or sampling a single large graph. To train on a large graph, you have to make use of sampling, e.g., via NeighborSampler. An example of link prediction on large graphs can be found here.
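The idea behind NeighborSampler can be sketched in plain PyTorch (the real class lives in torch_geometric; everything below is illustrative, with hypothetical sizes): for each mini-batch of target nodes, keep only a fixed number of sampled incoming edges per target instead of the full graph.

```python
# Plain-PyTorch sketch of neighbor sampling on a hypothetical toy graph.
import torch

num_nodes, num_neighbors, batch_size = 100, 10, 32
edge_index = torch.randint(0, num_nodes, (2, 500))  # toy edge list

# Pick a mini-batch of target nodes.
targets = torch.randperm(num_nodes)[:batch_size]

# Keep edges whose destination is one of the targets...
mask = (edge_index[1].unsqueeze(1) == targets.unsqueeze(0)).any(dim=1)
sub_edges = edge_index[:, mask]

# ...and subsample at most `num_neighbors` of them per target.
keep = []
for t in targets.tolist():
    idx = (sub_edges[1] == t).nonzero(as_tuple=True)[0]
    perm = idx[torch.randperm(idx.numel())][:num_neighbors]
    keep.append(perm)
sampled = sub_edges[:, torch.cat(keep)]
print(sampled.shape)  # far fewer edges than the full graph
```

NeighborSampler repeats this recursively per hop (its sizes argument gives the fan-out per layer), so each training step touches only a small computation graph rather than all 1.3B edges.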