dmlc / dgl

Python package built to ease deep learning on graph, on top of existing DL frameworks.
http://dgl.ai
Apache License 2.0

[RFC] [DataLoader] Support iterating over node pairs in DataLoader #4444

Open BarclayII opened 2 years ago

BarclayII commented 2 years ago

🚀 Feature

Support iterating over node pairs in DataLoader.

Motivation

Currently, there are two ways to evaluate link prediction on given positive and negative samples:

One is to add the positive and negative examples to the graph, iterate over these edges with as_edge_prediction_sampler while excluding them during sampling, and treat the evaluation as binary edge classification:

import dgl
import torch

g = ...
pos_src, pos_dst = ...   # validation positive edges
neg_src, neg_dst = ...   # validation negative edges
sampler = dgl.dataloading.NeighborSampler([5, 10, 15])

num_edges = g.num_edges()
# dgl.add_edges returns a new graph; DGLGraph.add_edges modifies g in place and returns None.
new_g = dgl.add_edges(g, torch.cat([pos_src, neg_src]), torch.cat([pos_dst, neg_dst]))
val_eid = torch.arange(num_edges, num_edges + pos_src.shape[0] + neg_src.shape[0])
# Set the edge labels: 1 for positive validation edges, 0 for all other edges.
new_g.edata['label'] = torch.zeros(new_g.num_edges())
new_g.edata['label'][num_edges:num_edges + pos_src.shape[0]] = 1
sampler = dgl.dataloading.as_edge_prediction_sampler(sampler, exclude=val_eid)
dataloader = dgl.dataloading.DataLoader(new_g, val_eid, sampler)
for input_nodes, pair_graph, blocks in dataloader:
    # model and compute_accuracy are user-defined.
    score = model(pair_graph, blocks, blocks[0].srcdata['x'])
    label = pair_graph.edata['label']
    acc = compute_accuracy(score, label)
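The edge ID and label bookkeeping in the snippet above can be checked with a small sketch that does not use DGL or torch. `build_val_edges` and the threshold-at-zero `compute_accuracy` are hypothetical helpers (the snippet leaves `compute_accuracy` undefined), and the sketch assumes `score` holds one logit per edge.

```python
def build_val_edges(num_edges, pos_pairs, neg_pairs):
    """Mimic the bookkeeping above: validation edges are appended after the
    existing edges, positives first, then negatives (hypothetical helper)."""
    pairs = list(pos_pairs) + list(neg_pairs)
    val_eid = list(range(num_edges, num_edges + len(pairs)))
    # Label 1 for the appended positive edges, 0 for every other edge.
    labels = [0] * (num_edges + len(pairs))
    for eid in range(num_edges, num_edges + len(pos_pairs)):
        labels[eid] = 1
    return pairs, val_eid, labels


def compute_accuracy(scores, labels, threshold=0.0):
    """Hypothetical helper: a logit above `threshold` counts as a positive prediction."""
    correct = sum(int((s > threshold) == bool(y)) for s, y in zip(scores, labels))
    return correct / len(labels)


pairs, val_eid, labels = build_val_edges(4, [(0, 1), (2, 3)], [(0, 3)])
print(val_eid)  # [4, 5, 6]
print(labels)   # [0, 0, 0, 0, 1, 1, 0]
```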

This is quite complicated, and it gets even more complicated when evaluating on heterogeneous graphs. It is also neither efficient nor scalable: the validation edges have to be excluded in every sampling operation, and edge exclusion takes the validation edge IDs as a tensor, which does not work for distributed training.

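The other is to compute the representations of all nodes first, and then compute the score of each node pair from the representations of its two incident nodes. Unfortunately, this is not possible for subgraph-representation-based link prediction methods such as SEAL, where the representation is computed from the subgraph extracted around each target node pair rather than from the nodes alone.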
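A minimal sketch of this second approach, assuming the node representations are already computed and the score is a dot product of the two endpoint representations (a common choice, but an assumption here; `score_pairs` is a hypothetical name):

```python
def score_pairs(node_repr, pairs):
    """Score each (u, v) pair from the precomputed representations of its endpoints."""
    # Each node representation is computed once and reused by every pair touching it.
    return [sum(a * b for a, b in zip(node_repr[u], node_repr[v])) for u, v in pairs]


node_repr = {0: [1.0, 0.0], 1: [0.5, 2.0], 2: [0.0, 1.0]}
print(score_pairs(node_repr, [(0, 1), (1, 2)]))  # [0.5, 2.0]
```

The reuse of per-node representations is what makes this cheap, and it is exactly what SEAL cannot do, since its representations depend on the target pair.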

Alternatives

If #4441 is implemented, one can create a single graph containing all the training edges, positive validation edges, and negative validation edges. During both training and validation, one then uses a sampler that only samples from the training edges, and treats validation as binary edge classification.

import dgl
import torch

g = ...
pos_src, pos_dst = ...   # validation positive edges
neg_src, neg_dst = ...   # validation negative edges

# This new graph construction code can be done during preprocessing, hence usable for
# distributed training.
num_edges = g.num_edges()
new_g = dgl.add_edges(g, torch.cat([pos_src, neg_src]), torch.cat([pos_dst, neg_dst]))
new_g.edata['train_mask'] = torch.zeros(new_g.num_edges(), dtype=torch.bool)
new_g.edata['train_mask'][:num_edges] = True
train_eid = torch.arange(0, num_edges)
val_eid = torch.arange(num_edges, num_edges + pos_src.shape[0] + neg_src.shape[0])
new_g.edata['label'] = torch.zeros(new_g.num_edges())
new_g.edata['label'][num_edges:num_edges + pos_src.shape[0]] = 1

neighbor_sampler = dgl.dataloading.NeighborSampler([5, 10, 15], include_mask='train_mask')
train_sampler = dgl.dataloading.as_edge_prediction_sampler(
    neighbor_sampler, exclude='reverse_id', reverse_eids=..., negative_sampler=...)
train_dataloader = dgl.dataloading.DataLoader(new_g, train_eid, train_sampler)
for input_nodes, pair_graph, neg_pair_graph, blocks in train_dataloader:
    ...
val_sampler = dgl.dataloading.as_edge_prediction_sampler(neighbor_sampler)
val_dataloader = dgl.dataloading.DataLoader(new_g, val_eid, val_sampler)
for input_nodes, pair_graph, blocks in val_dataloader:
    # Validate as in binary edge classification
    ...
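Since #4441 is only proposed, the sketch below is a guess at the intended semantics of `include_mask='train_mask'`: when sampling neighbors, only edges whose mask entry is True are candidates, so the appended validation edges never appear in the sampled blocks. It does not use DGL; `sample_neighbors_masked` and the `in_edges` layout (destination node mapped to a list of `(edge ID, source node)` pairs) are hypothetical.

```python
import random


def sample_neighbors_masked(in_edges, seeds, fanout, edge_mask, rng):
    """For each seed, sample up to `fanout` incoming edges among masked-in edges."""
    sampled = {}
    for v in seeds:
        # Only edges allowed by the mask are candidates; the rest are never sampled.
        candidates = [(eid, u) for eid, u in in_edges.get(v, []) if edge_mask[eid]]
        # Take every candidate if there are fewer than `fanout`, else sample without replacement.
        sampled[v] = rng.sample(candidates, min(fanout, len(candidates)))
    return sampled


# Edge IDs 0 and 1 are training edges; edge 2 is an appended validation edge.
in_edges = {0: [(0, 1), (1, 2), (2, 3)]}
train_mask = [True, True, False]
result = sample_neighbors_masked(in_edges, [0], 5, train_mask, random.Random(0))
print(sorted(eid for eid, _ in result[0]))  # [0, 1]
```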

However, this is inconvenient for validating and testing on new edges during deployment, where the node pairs to predict may change from time to time: one needs to create a new graph every time they change.

Pitch

The user experience will look like the following:

import dgl
import torch

g = ...
pos_src, pos_dst = ...   # validation positive edges
neg_src, neg_dst = ...   # validation negative edges

train_eid = torch.arange(0, g.num_edges())
neighbor_sampler = dgl.dataloading.NeighborSampler([5, 10, 15])
train_sampler = dgl.dataloading.as_edge_prediction_sampler(
    neighbor_sampler, exclude='reverse_id', reverse_eids=..., negative_sampler=...)
train_dataloader = dgl.dataloading.DataLoader(g, train_eid, train_sampler)

# New function: as_node_pair_prediction_sampler
val_sampler = dgl.dataloading.as_node_pair_prediction_sampler(
    neighbor_sampler, prefetch_label=['label'])
# Stack the validation node pairs into an Nx2 tensor of (source, destination) IDs.
val_src_dst = torch.stack([torch.cat([pos_src, neg_src]), torch.cat([pos_dst, neg_dst])], 1)
val_labels = torch.cat([torch.ones_like(pos_src), torch.zeros_like(neg_src)])
# The DataLoader needs to accept and iterate over an Nx2 tensor, or a dict mapping
# edge types to Nx2 tensors.
val_dataloader = dgl.dataloading.DataLoader(
    g,
    val_src_dst,
    val_sampler)
# Use attach_data because the storage is neither node storage nor edge storage.
# The key 'label' must match the one in prefetch_label above.
val_dataloader.attach_data('label', val_labels)
for input_nodes, pair_graph, blocks in val_dataloader:
    pass
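One way the DataLoader could iterate over an Nx2 set of node pairs, sketched without DGL or torch (plain lists stand in for tensors, and `iterate_node_pairs` is a hypothetical name). Each minibatch carries the pairs together with their positions in the full pair set, which is what labels attached through the proposed attach_data could be gathered with.

```python
import random


def iterate_node_pairs(src_dst, batch_size, shuffle=False, seed=None):
    """Yield (pairs, indices) minibatches from an N x 2 list of node pairs."""
    order = list(range(len(src_dst)))
    if shuffle:
        random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        indices = order[start:start + batch_size]
        # The indices are positions in the full pair set, used to look up labels.
        yield [src_dst[i] for i in indices], indices


val_src_dst = [(0, 1), (2, 3), (0, 3)]
val_labels = [1, 1, 0]
for pairs, indices in iterate_node_pairs(val_src_dst, batch_size=2):
    print(pairs, [val_labels[i] for i in indices])
# [(0, 1), (2, 3)] [1, 1]
# [(0, 3)] [0]
```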

We need a sampler wrapper function as_node_pair_prediction_sampler, similar to as_edge_prediction_sampler, with the following signature:

def as_node_pair_prediction_sampler(sampler, negative_sampler=None, prefetch_label=None):
    pass

I'm not sure whether we need edge exclusion here. At least for link prediction validation, edge exclusion seems unnecessary, because the graph does not contain the validation edges anyway.

The function as_node_pair_prediction_sampler will create a NodePairPredictionSampler object. Its sample method takes the graph together with a pair of (1) an Nx2 tensor of node pairs (or a dict of edge types and Nx2 tensors) and (2) the indices of these pairs within the entire node pair set. The indices are used for prefetching labels and are assigned to the pair graph as pair_graph.edata[dgl.EID].

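def sample(self, g, indices):
    src_dst, label_indices = indices
    # src_dst is an Nx2 tensor with source and target node IDs, or a dict of them for
    # heterogeneous graphs.
    # construct_pair_graph_from and assign_lazy_features are placeholders for helpers
    # to be implemented.
    pair_graph = construct_pair_graph_from(src_dst)
    # Record the positions of the pairs within the entire node pair set.
    pair_graph.edata[dgl.EID] = label_indices
    seed_nodes = pair_graph.ndata[dgl.NID]
    input_nodes, _, blocks = self.sampler.sample(g, seed_nodes)
    assign_lazy_features(pair_graph, label_indices, self.prefetch_label)
    return input_nodes, pair_graph, blocks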
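The pseudocode leaves construct_pair_graph_from undefined. Its core job is presumably to relabel the node IDs in the pairs into a compact range, so that the pair graph contains only the nodes appearing in the minibatch; those nodes' original IDs then serve as the seed nodes passed to self.sampler.sample. A DGL-free sketch of that relabeling, using a hypothetical compact_node_pairs helper over plain Python lists:

```python
def compact_node_pairs(src_dst):
    """Relabel node pairs into a compact local ID space.

    Returns local source IDs, local destination IDs, and the global ID of each
    local node (the role pair_graph.ndata[dgl.NID] plays in the pseudocode).
    """
    global_ids = []  # local ID -> global ID
    local_of = {}    # global ID -> local ID

    def to_local(n):
        # Assign local IDs in order of first appearance.
        if n not in local_of:
            local_of[n] = len(global_ids)
            global_ids.append(n)
        return local_of[n]

    local_src, local_dst = [], []
    for u, v in src_dst:
        local_src.append(to_local(u))
        local_dst.append(to_local(v))
    return local_src, local_dst, global_ids


src, dst, nid = compact_node_pairs([(10, 42), (42, 7)])
print(src, dst, nid)  # [0, 1] [1, 2] [10, 42, 7]
```

In DGL, a pair graph could then be built from the local IDs with dgl.graph((local_src, local_dst)) and given the global IDs as ndata[dgl.NID]; this is an assumption about how the helper might be written, not the proposal's implementation.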

The return value will be the same as as_edge_prediction_sampler.

The only issue I have with this UX is that I'm not sure whether a better UX exists for this.

Additional context

This requirement arose from @rudongyu's implementation of SEAL on OGB datasets.

EDIT (8/23): changed the list of issues and clarified the behavior of as_node_pair_prediction_sampler.

mufeili commented 2 years ago

Will we still need to make negative edges part of the graph?

mufeili commented 2 years ago

cc @rudongyu @Ereboas If you have not checked this issue yet, it would be great if you could take a look and give some feedback.

BarclayII commented 2 years ago

Will we still need to make negative edges part of the graph?

In this case, no.