pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

NeighborLoader with torch_geometric.data.Dataset #3804

Open MATTEMAINA opened 2 years ago

MATTEMAINA commented 2 years ago

🚀 Feature

Extend NeighborLoader to accept a Dataset object, not just a single Data object. The goal is to handle an entire dataset with one NeighborLoader call instead of handling each data object of the dataset separately.

Motivation

I work with temporal graphs and rely heavily on NeighborLoader to split my graph into training and validation sets. However, it can only process one Data object at a time rather than a Dataset, and my data consists of snapshots of an evolving graph.

rusty1s commented 2 years ago

I'm happy to learn more about your use-case. The current strategy for this is to convert your dataset into a single data object and pass it to NeighborLoader, i.e.

from torch_geometric.data import Batch
from torch_geometric.loader import NeighborLoader

data = Batch.from_data_list(dataset)
loader = NeighborLoader(data, ...)
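
To make the merge step concrete, here is a minimal plain-Python sketch of what collating several graphs into one disjoint graph amounts to: edge endpoints are shifted by the number of nodes seen so far, and a batch vector records which graph each node came from. The dict layout and the `merge_graphs` name are hypothetical stand-ins for illustration, not PyG's implementation.

```python
def merge_graphs(graphs):
    """Merge graphs into one disjoint graph.

    Each graph is a dict with 'num_nodes' (int) and 'edges' (a list of
    (src, dst) pairs using local node ids). This layout is a hypothetical
    stand-in for a PyG Data object.
    """
    merged_edges = []
    batch = []  # batch[i] = index of the graph that node i came from
    offset = 0  # number of nodes contributed by earlier graphs
    for graph_idx, graph in enumerate(graphs):
        for src, dst in graph["edges"]:
            # Shift local ids so they stay unique in the merged graph.
            merged_edges.append((src + offset, dst + offset))
        batch.extend([graph_idx] * graph["num_nodes"])
        offset += graph["num_nodes"]
    return {"num_nodes": offset, "edges": merged_edges, "batch": batch}


snapshots = [
    {"num_nodes": 3, "edges": [(0, 1), (1, 2)]},
    {"num_nodes": 2, "edges": [(0, 1)]},
]
merged = merge_graphs(snapshots)
print(merged["edges"])  # [(0, 1), (1, 2), (3, 4)]
print(merged["batch"])  # [0, 0, 0, 1, 1]
```

No edges connect the two snapshots after merging, so each one remains its own connected region of the combined graph.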

MATTEMAINA commented 2 years ago

I tried this approach, but I noticed that I can no longer use the input_nodes parameter, which I use to define the set of nodes for training and validation. How should I modify input_nodes to get the same result?

rusty1s commented 2 years ago

That depends. If the training and validation nodes are described for each data object via masks, then these masks can be used on the Batch object as well:

data = Batch.from_data_list(dataset)
train_loader = NeighborLoader(data, input_nodes=data.train_mask, ...)
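
The reason per-graph masks carry over can be sketched in plain Python: node-level attributes are concatenated in graph order, so a mask entry lands at the same global position as its node. The helper names below are hypothetical, and plain lists stand in for boolean tensors.

```python
def concat_masks(masks):
    """Concatenate per-graph boolean masks in graph order."""
    merged = []
    for mask in masks:
        merged.extend(mask)
    return merged


def mask_to_seed_nodes(mask):
    """Return the global node ids selected by a boolean mask."""
    return [node for node, selected in enumerate(mask) if selected]


# One training mask per snapshot (3 nodes and 2 nodes respectively).
train_masks = [[True, False, True], [False, True]]
merged_mask = concat_masks(train_masks)
seeds = mask_to_seed_nodes(merged_mask)
print(seeds)  # [0, 2, 4]
```

Node 1 of the second snapshot becomes global node 4 because the first snapshot contributes three nodes ahead of it.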

If you have designated train_dataset and val_dataset, you do not need to make use of input_nodes at all:

train_loader = NeighborLoader(Batch.from_data_list(train_dataset), ...)
val_loader = NeighborLoader(Batch.from_data_list(val_dataset), ...)

MATTEMAINA commented 2 years ago

Perfect. If I understand correctly, with this approach NeighborLoader returns batches of single Data objects, as if I were using DataLoader(dataset)?

rusty1s commented 2 years ago

It will return a subset of your graph data. As such, the returned data object is of type Data rather than of type Batch. The only real difference between the two is the missing batch vector.
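
As a rough illustration of what "a subset of your graph data" means on a merged graph, the following plain-Python sketch samples one hop of in-neighbors around a seed node and relabels the result. It is a simplified stand-in, not PyG's sampler; `sample_subgraph` and its arguments are hypothetical, and seeds are assumed to be unique.

```python
import random


def sample_subgraph(edges, seeds, num_neighbors, rng):
    """Sample up to `num_neighbors` in-neighbors per seed (one hop).

    Returns the sampled global node ids (seeds first) and the sampled
    edges relabeled to positions in that list. Assumes unique seeds.
    """
    in_neighbors = {}
    for src, dst in edges:
        in_neighbors.setdefault(dst, []).append(src)

    nodes = list(seeds)
    position = {node: i for i, node in enumerate(nodes)}
    sub_edges = []
    for seed in seeds:
        candidates = in_neighbors.get(seed, [])
        k = min(num_neighbors, len(candidates))
        for src in rng.sample(candidates, k):
            if src not in position:
                position[src] = len(nodes)
                nodes.append(src)
            sub_edges.append((position[src], position[seed]))
    return nodes, sub_edges


# Two disjoint snapshots merged into one graph: nodes 0-2 and nodes 3-4.
edges = [(0, 1), (1, 2), (2, 1), (3, 4)]
batch = [0, 0, 0, 1, 1]  # snapshot index of each node

nodes, sub_edges = sample_subgraph(
    edges, seeds=[1], num_neighbors=2, rng=random.Random(0)
)
# Every sampled node belongs to the seed's own snapshot.
print(all(batch[n] == batch[1] for n in nodes))  # True
```

Because merged snapshots share no edges, sampling around a seed never reaches into another snapshot, so each sampled subgraph stays within a single snapshot.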