pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

Add NeighborSampler support for more Conv layer #430

Open w21180239 opened 5 years ago

w21180239 commented 5 years ago

🚀 Feature

I noticed that currently only SAGEConv and PointConv support NeighborSampler in PyG. Could you add NeighborSampler support for more Conv layers, such as GATConv?

Motivation

In practice, many real-world datasets are large graphs that have to be trained with mini-batches.

rusty1s commented 5 years ago

Good request! GATConv and some others like NNConv or EdgeConv should be easily doable. However, there are a few operators which cannot incorporate bipartite graphs, e.g., GCNConv.

w21180239 commented 5 years ago

I wonder how long it will take to add NeighborSampler support for more Conv layers.

rusty1s commented 5 years ago

Well, as long as it takes :P Feel free to submit a PR to help out :) There are still a number of uncertainties that I need to look at, e.g. GATConv is not well-defined for bipartite graphs due to the attention scores being computed based on x_i and x_j, but target node features x_i are not necessarily defined for bipartite graphs.
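
To make the bipartite issue concrete, here is a minimal pure-Python sketch (not PyG code). It assumes a sampled block is given as source-node features plus (source, target) index pairs. The names bipartite_mean_aggregate and attention_logit are hypothetical, and the attention score is a simplified stand-in that omits the linear transform and LeakyReLU used by GAT. The point is only that a mean aggregator needs source features alone, whereas an attention score also needs x_i.

```python
from typing import List, Optional, Sequence, Tuple

Vector = List[float]


def bipartite_mean_aggregate(
    x_src: Sequence[Vector],
    edges: Sequence[Tuple[int, int]],
    num_dst: int,
) -> List[Vector]:
    """Average source features over the incoming edges of each target node.

    `edges` holds (source_index, target_index) pairs. Source indices refer
    to rows of `x_src`; target indices range over 0..num_dst-1.
    """
    dim = len(x_src[0])
    sums = [[0.0] * dim for _ in range(num_dst)]
    counts = [0] * num_dst
    for src, dst in edges:
        for k in range(dim):
            sums[dst][k] += x_src[src][k]
        counts[dst] += 1
    # Targets without incoming edges keep a zero vector.
    return [
        [value / counts[i] for value in sums[i]] if counts[i] else sums[i]
        for i in range(num_dst)
    ]


def attention_logit(
    x_i: Optional[Vector], x_j: Vector, a_i: Vector, a_j: Vector
) -> float:
    """Simplified attention score a_i . x_i + a_j . x_j (assumption: no
    linear transform, no LeakyReLU). It cannot be computed without x_i."""
    if x_i is None:
        raise ValueError("target features x_i are required for attention")
    return sum(p * q for p, q in zip(a_i, x_i)) + sum(
        p * q for p, q in zip(a_j, x_j)
    )


# Three source nodes feed two target nodes; no target features are needed.
x_src = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
edges = [(0, 0), (1, 0), (2, 1)]
out = bipartite_mean_aggregate(x_src, edges, num_dst=2)
print(out)  # [[2.0, 3.0], [5.0, 6.0]]
```

Calling attention_logit(None, ...) raises a ValueError in this sketch, which mirrors why an attention-based layer needs some definition of the target features before it can run on a bipartite block.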

rusty1s commented 5 years ago

GATConv now supported, see here.

MrLinNing commented 5 years ago

@rusty1s Hi, I want to reproduce the GraphSAGE paper using the NeighborSampler class. However, I get an error when I run your NeighborSampler demo script:

python reddit.py

Output:

Traceback (most recent call last):
  File "reddit.py", line 90, in <module>
    loss = train()
  File "reddit.py", line 71, in train
    out = model(data.x.to(device), data_flow.to(device))
  File "/home/paper/GNN/env/lib/python3.5/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "reddit.py", line 32, in forward
    x = F.relu(self.conv1((x, None), data.edge_index, size=data.size))
  File "/home/paper/GNN/env/lib/python3.5/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/paper/GNN/env/lib/python3.5/site-packages/torch_geometric/nn/conv/sage_conv.py", line 62, in forward
    x = x.unsqueeze(-1) if x.dim() == 1 else x
AttributeError: 'tuple' object has no attribute 'dim'

My PyTorch version is 1.10, and I can run benchmark/kernel/main.py without problems. Can you help me? Thank you!

rusty1s commented 5 years ago

Sure, either update to torch-geometric master by installing from source, or replace the conv calls in reddit.py with

...
x = F.relu(self.conv1(x, data.edge_index, size=data.size))
x = F.dropout(x, p=0.5, training=self.training)
data = data_flow[1]
x = self.conv2(x, data.edge_index, size=data.size)
...

Sorry for the inconvenience!

MrLinNing commented 5 years ago

Great!!! @rusty1s Thank you very much! Actually, I noticed that the GraphSAGE model in benchmark/kernel/main.py does not use the NeighborSampler. How can I add the NeighborSampler to benchmark/kernel/main.py?

rusty1s commented 5 years ago

The kernel benchmark scripts process a whole graph at once, so we do not have a NeighborSampler there. If you want to use the NeighborSampler in a mini-batch scenario, you need to do something like:

for data in batch_loader:
    sampler = NeighborSampler(data, ...)
    for data_flow in sampler():
        model(data.x, data_flow)

MrLinNing commented 5 years ago

Sorry @rusty1s, I tried your approach to train on the MUTAG dataset, which has neither data.train_mask nor data.test_mask. So I split it like this:

image

image

I then ran into this problem:

image

Can you help me?

rusty1s commented 5 years ago

Please note that the kernel benchmark script performs graph classification in a mini-batch graph scenario, while the NeighborSampler performs mini-batching of nodes in a single graph. Now you have both worlds, mini-batching of graphs and mini-batching of nodes. This is far from trivial:

  1. You still need to train against the target of your graph: data.y, not data.y[data_flow.n_id].
  2. Your network architecture needs to be modified to work on DataFlow objects, see reddit.py for a simple example. Afterwards, the batched nodes need to be pooled together.

MrLinNing commented 5 years ago

Oh, it's my fault. I just want to know how to use reddit.py to train on multi-graph datasets like MUTAG and PROTEINS. As you know, TUDataset has no train_mask or test_mask like the Reddit dataset. So how can I train on them using your reddit.py? Thank you!!!

rusty1s commented 5 years ago

Ok. So the following code is untested, but it should convey the general idea:

class SAGENet(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(SAGENet, self).__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels, normalize=False)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels, normalize=False)
        self.lin1 = Linear(hidden_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, out_channels)

    def forward(self, x, data_flow, batch):
        data = data_flow[0]
        x = x[data.n_id]
        x = F.relu(self.conv1(x, data.edge_index, size=data.size))
        data = data_flow[1]
        x = F.relu(self.conv2(x, data.edge_index, size=data.size))
        x = global_mean_pool(x, batch[data_flow.n_id])
        x = F.relu(self.lin1(x))
        x = self.lin2(x)
        return F.log_softmax(x, dim=1)

for data in loader:
    sampler = NeighborSampler(...)
    for data_flow in sampler():
        optimizer.zero_grad()
        out = model(data.x.to(device), data_flow.to(device), data.batch.to(device))
        loss = F.nll_loss(out, data.y.to(device))
        loss.backward()
        optimizer.step()

Note that you only use a subset of the nodes in the graph to perform graph classification of the whole graph.
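
As a rough illustration of that last point, the pure-Python sketch below (not PyG code) mean-pools scalar node values per graph, once over all nodes and once over a sampled subset. The helper name mean_pool_by_graph and the toy data are assumptions made for this example. Passing the number of graphs explicitly keeps one output entry per graph even if a graph contributes no sampled node.

```python
from typing import List, Sequence


def mean_pool_by_graph(
    values: Sequence[float], graph_ids: Sequence[int], num_graphs: int
) -> List[float]:
    """Average node values per graph; graphs without nodes yield 0.0."""
    sums = [0.0] * num_graphs
    counts = [0] * num_graphs
    for value, graph in zip(values, graph_ids):
        sums[graph] += value
        counts[graph] += 1
    return [s / c if c else 0.0 for s, c in zip(sums, counts)]


# Toy mini-batch: six nodes spread over two graphs.
node_values = [1.0, 2.0, 3.0, 10.0, 20.0, 30.0]
batch = [0, 0, 0, 1, 1, 1]

# Pooling over every node of each graph.
full = mean_pool_by_graph(node_values, batch, num_graphs=2)

# Pooling over a sampled subset of node ids only.
sampled_ids = [0, 1, 5]
subset = mean_pool_by_graph(
    [node_values[i] for i in sampled_ids],
    [batch[i] for i in sampled_ids],
    num_graphs=2,
)

print(full)    # [2.0, 20.0]
print(subset)  # [1.5, 30.0]
```

The two results differ, so the graph-level prediction depends on which nodes happen to be sampled in a given step.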

MrLinNing commented 5 years ago

That does not seem right ....

image

image image .... image

rusty1s commented 5 years ago

Yes, you are right. You need to upgrade to PyG master by installing from source to fix this.

I tested the following code and it now works as expected:

import torch
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.data import DataLoader, NeighborSampler
from torch_geometric.datasets import TUDataset
from torch_geometric.nn import global_mean_pool, SAGEConv

dataset = TUDataset('/tmp/MUTAG', name='MUTAG').shuffle()
loader = DataLoader(dataset, batch_size=128)

class SAGENet(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(SAGENet, self).__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels, normalize=False)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels,
                              normalize=False)
        self.lin1 = Linear(hidden_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, out_channels)

    def forward(self, x, data_flow, batch):
        data = data_flow[0]
        x = x[data.n_id]
        x = F.relu(self.conv1(x, data.edge_index, size=data.size))
        data = data_flow[1]
        x = F.relu(self.conv2(x, data.edge_index, size=data.size))
        x = global_mean_pool(x, batch[data_flow.n_id], size=batch.max() + 1)
        x = F.relu(self.lin1(x))
        x = self.lin2(x)
        return F.log_softmax(x, dim=1)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SAGENet(dataset.num_features, 64, dataset.num_classes).to(device)

for data in loader:
    sampler = NeighborSampler(data, size=[25, 10], num_hops=2,
                              batch_size=data.num_graphs * 10, shuffle=True,
                              add_self_loops=True)

    for data_flow in sampler():
        out = model(
            data.x.to(device), data_flow.to(device), data.batch.to(device))
        loss = F.nll_loss(out, data.y.to(device))

MrLinNing commented 5 years ago

How do I install it from source?

Actually, I already tried the following commands before running the Python script:

pip uninstall torch-geometric
pip install torch-geometric

rusty1s commented 5 years ago

Uninstall via

pip uninstall torch-geometric

clone the repo and run

python setup.py install

MrLinNing commented 5 years ago

Hi @rusty1s,
The number of edges in your Citeseer class, 4732, does not equal the number of edges in the GCN paper, 9104/2 = 4552. Is there any preprocessing step in your Citeseer class?

image

image

rusty1s commented 5 years ago

Hi, this is highly related to #343.
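
The thread does not spell out the cause here. Purely as a general illustration (an assumption, not a statement about Citeseer or #343), the pure-Python sketch below shows how halving a directed edge list can give a different number than counting unique undirected edges once duplicates and self-loops are dropped. The helper name count_undirected_edges and the toy edge list are hypothetical.

```python
from typing import Iterable, Tuple


def count_undirected_edges(pairs: Iterable[Tuple[int, int]]) -> int:
    """Count unique undirected edges, ignoring self-loops and duplicates."""
    unique = {tuple(sorted((u, v))) for u, v in pairs if u != v}
    return len(unique)


# Toy directed edge list with one duplicate (0, 1) and one self-loop (2, 2).
directed = [(0, 1), (1, 0), (1, 2), (2, 1), (2, 2), (0, 1)]

naive = len(directed) // 2                  # halve the raw pair count
cleaned = count_undirected_edges(directed)  # dedupe and drop self-loops

print(naive)    # 3
print(cleaned)  # 2
```

Running a check like this on the raw and the processed edge lists is one way to see whether preprocessing explains a mismatch in reported edge counts.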

johnny12150 commented 4 years ago

So for now, it is still not possible to use GATConv with NeighborSampler? Or is there any workaround?

Edit: I tried it and got AssertionError: Static graphs not supported in `GATConv`

rusty1s commented 4 years ago

That error is not related to the NeighborSampler. In fact, with PyG 1.6.0, most GNNs can now be used with the NeighborSampler API. For GATConv, you can find an example here.

johnny12150 commented 4 years ago

Oh, I fixed the problem. It was caused by passing an input x with the wrong shape into GATConv.

rusty1s commented 4 years ago

Perfect :)