pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

Pair of graphs (similarity score regression) #511

Closed (dchang56 closed this issue 1 year ago)

dchang56 commented 5 years ago

Hello,

I'm thinking about a regression task where you take in a pair of graphs and try to predict their similarity score based on the node features and the graph topology.

I'm wondering what the best way to approach this would be. Two main approaches I've thought of are:

  1. Create a virtual summary node for each graph and add an edge between the two summary nodes, so that the pair of graphs is treated as one graph.
  2. Keep the graphs separate: apply some graph network to each graph individually, and concatenate their respective outputs for a final regression layer, like so:

         out1 = model(edge_index_1, x_1)
         out2 = model(edge_index_2, x_2)
         pred = linear(torch.cat([out1, out2], dim=-1))  # linear is (2*nhid, 1)
         loss = mse_loss(pred, data.y)
         # ...and so on
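Approach 1 can be sketched in plain PyTorch as follows. This is a minimal illustration under stated assumptions, not PyG code: merge_with_summary_nodes is a hypothetical helper, the two summary nodes are assumed to start with zero features, and every edge is stored in both directions, as PyG's edge_index usually is.

```python
import torch


def merge_with_summary_nodes(x_a, edge_index_a, x_b, edge_index_b):
    """Hypothetical helper: join graphs a and b into one graph (approach 1).

    Each graph gets a virtual summary node connected to all of its nodes,
    and the two summary nodes are linked by a single bridge edge.
    Edges use COO format [2, E] and are added in both directions.
    Summary node features are initialized to zeros (an assumption).
    """
    n_a, n_b = x_a.size(0), x_b.size(0)
    s_a = n_a + n_b      # index of graph a's summary node
    s_b = n_a + n_b + 1  # index of graph b's summary node

    # Shift graph b's node indices so they come after graph a's nodes.
    edges = [edge_index_a, edge_index_b + n_a]

    def star(nodes, hub):
        # Undirected edges between every node in `nodes` and the `hub` node.
        hub_vec = torch.full_like(nodes, hub)
        return torch.cat([torch.stack([nodes, hub_vec]),
                          torch.stack([hub_vec, nodes])], dim=1)

    edges.append(star(torch.arange(n_a), s_a))
    edges.append(star(torch.arange(n_a, n_a + n_b), s_b))
    edges.append(torch.tensor([[s_a, s_b], [s_b, s_a]]))  # bridge between the graphs

    x_summary = torch.zeros(2, x_a.size(1), dtype=x_a.dtype)
    x = torch.cat([x_a, x_b, x_summary], dim=0)
    return x, torch.cat(edges, dim=1)
```

For a graph b that is only a placeholder node, edge_index_b can simply be an empty tensor of shape [2, 0]; the helper still connects that node to its summary node.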

I'd really appreciate any feedback/insight/suggestions on this problem!
- Which way do you think is better, and why?
- What are some of the details I should consider?

rusty1s commented 5 years ago

This is a really interesting application!

  1. I like this one, as it can send messages between the two graphs during neighborhood aggregation. A similar approach can be found here.
  2. This one should work too (but I guess it is less powerful). It is, e.g., already used here (in combination with a pair-wise node histogram).

Both approaches should be straightforward to implement in PyG, but it is a bit tricky to get it right in a mini-batch scenario :)

dchang56 commented 5 years ago

Thanks for the quick response! I just finished generating two different versions of the graph dataset:
- one where the pair of graphs in each example is connected by an edge between their respective summary nodes (treating a pair of graphs as one graph);
- one where graph a and graph b are kept separate (a DataLoader_a for the graph_a's and a DataLoader_b for the graph_b's). This seemed simpler than doing Data(edge_index_1=grapha.edge_index, x_1=grapha.x, edge_index_2=graphb.edge_index, x_2=graphb.x, y=label) and writing a custom collate function to batch the graphs simultaneously.
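For reference, the custom-collate route mentioned above could look roughly like this. It is a sketch, not PyG's implementation: collate_pairs is a hypothetical name, and each dataset item is assumed to be a plain tuple (x_a, edge_index_a, x_b, edge_index_b, y). Each side is batched the way PyG batches graphs: node features are concatenated, and edge indices are shifted by the number of nodes already in the batch.

```python
import torch


def collate_pairs(pairs):
    """Hypothetical collate_fn for (x_a, edge_index_a, x_b, edge_index_b, y) tuples."""
    out = {}
    for side, (xi, ei) in (("a", (0, 1)), ("b", (2, 3))):
        xs, edges, batch, offset = [], [], [], 0
        for i, pair in enumerate(pairs):
            x, edge_index = pair[xi], pair[ei]
            xs.append(x)
            # Shift indices so they keep pointing at this pair's own nodes.
            edges.append(edge_index + offset)
            # batch vector: maps every node to the index of its pair.
            batch.append(torch.full((x.size(0),), i, dtype=torch.long))
            offset += x.size(0)
        out[f"x_{side}"] = torch.cat(xs, dim=0)
        out[f"edge_index_{side}"] = torch.cat(edges, dim=1)
        out[f"batch_{side}"] = torch.cat(batch)
    out["y"] = torch.stack([torch.as_tensor(p[4], dtype=torch.float) for p in pairs])
    return out
```

Passing it as torch.utils.data.DataLoader(pairs, batch_size=32, shuffle=True, collate_fn=collate_pairs) yields both graphs of every pair from one loader, so shuffling cannot misalign them, which is a risk when two separate loaders are zipped together. PyG also documents a built-in route for this case: subclass Data and override __inc__ so that each pair's two edge indices are incremented by the sizes of their own node feature matrices (see "Pairs of Graphs" in its advanced mini-batching docs).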

I'm going to try implementing a model like the recent pooling examples (use GIN/GAT as conv, and TopK or SAGpool as pooling layers) and see what happens. I'm pretty new to graph networks, so this is going to be a learning experience!

When you say approach 2 is less powerful than 1, do you mean because it doesn't explicitly model interaction between the graphs? Intuitively it makes sense, but I'm not sure the comparison is clear-cut. In approach 1, we get to explicitly model information flow between graphs, and output one vector representation for the graph through pooling, which can be considered a representation of the relationship of the graph pair as a whole; in approach 2, we process the graphs individually, but get 2 vector representations which can then be further transformed for regression. I think it's also complicated by the fact that pooling is involved (i.e. what if we exclude summary nodes/connection between graphs during pooling?). There are also instances where we only have graph a, paired with sort of a placeholder node for graph b when graph b isn't available (due to lack of information in graph construction process). Does pooling generalize to cases where the number of nodes in a graph is 1?

Anyway, we basically want the model to learn that examples with two very different graphs receive a score close to 0, and examples with two very similar graphs receive a score close to 1. (fyi the node features are 500-d vectors kind of like concept embeddings, and the graphs are a semantic representation of the original documents from which they were constructed. The edges represent the existence of a semantic relation between concepts). And I think approach 1 induces a bias that helps the model learn that during graph convolution/pooling, while approach 2 outsources that to subsequent steps (i.e. pairwise regression of outputs from graph a and graph b). Does that make sense? I'd love to know what you think!

rusty1s commented 5 years ago

I just finished generating two different versions of the graph dataset

This makes a lot of sense. Looks good to me :)

I would start as simply as possible by implementing approach 2, and would only add pooling once your base model is working but may underfit. A simple GINConv or GATConv baseline with a global_{add,mean,max}_pool readout should work just fine. As you said, approach 1 gets much harder to implement if one wants to use pooling: one basically needs to remove the summary nodes before pooling and add them back afterwards, to prevent the summary nodes from getting clustered. Pooling should generally work in isolation for the two graphs of a pair, as long as there is no edge connecting them, and it should also work when |V|=1 (feel free to submit an issue if this is not the case :P).
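A baseline along these lines (one shared GIN-style encoder plus a global mean readout, i.e. approach 2) is sketched below in plain PyTorch so that it runs without PyG. It is a stand-in, not PyG's GINConv or global_mean_pool: the class and function names are hypothetical, and the two-layer MLP and hidden sizes are arbitrary choices. Inputs are assumed to be batched as in PyG (concatenated nodes, shifted edge indices, one batch vector per side).

```python
import torch
from torch import nn


class GINLayer(nn.Module):
    """GIN-style update: h_i' = MLP((1 + eps) * h_i + sum_{j in N(i)} h_j)."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.eps = nn.Parameter(torch.zeros(1))  # learnable epsilon, starts at 0
        self.mlp = nn.Sequential(nn.Linear(in_dim, out_dim), nn.ReLU(),
                                 nn.Linear(out_dim, out_dim))

    def forward(self, x, edge_index):
        src, dst = edge_index
        # Sum the features of each node's in-neighbours (messages flow src -> dst).
        agg = torch.zeros_like(x).index_add_(0, dst, x[src])
        return self.mlp((1 + self.eps) * x + agg)


def global_mean(x, batch, num_graphs):
    """Average the node embeddings of each graph in the batch."""
    sums = torch.zeros(num_graphs, x.size(1), dtype=x.dtype, device=x.device)
    sums = sums.index_add_(0, batch, x)
    counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1).to(x.dtype)
    return sums / counts.unsqueeze(-1)


class SiameseGIN(nn.Module):
    """One shared encoder applied to both graphs, followed by a linear regressor."""

    def __init__(self, in_dim, hidden):
        super().__init__()
        self.convs = nn.ModuleList([GINLayer(in_dim, hidden), GINLayer(hidden, hidden)])
        self.head = nn.Linear(2 * hidden, 1)

    def encode(self, x, edge_index, batch, num_graphs):
        for conv in self.convs:
            x = torch.relu(conv(x, edge_index))
        return global_mean(x, batch, num_graphs)

    def forward(self, a, b, num_graphs):
        # a and b are (x, edge_index, batch) triples for the two sides of each pair.
        h_a = self.encode(*a, num_graphs)
        h_b = self.encode(*b, num_graphs)
        return self.head(torch.cat([h_a, h_b], dim=-1)).squeeze(-1)
```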

I've never played with summary nodes myself, but the Graph Matching paper indicates that message passing between graphs influences the respective node embeddings in a positive way. Node embeddings are not built in isolation, but rather dependent on each other, and can, e.g., attend to specific nodes in the graph based on the other graph (and vice versa). The concept of using summary nodes for this kind of task instead of pair-wise node connections is interesting, as it pushes running time down from O(N^2) to O(N).
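The O(N^2) cost comes from cross-graph attention: every node in one graph is compared with every node in the other. Below is a sketch in the spirit of the Graph Matching Networks matching step, assuming dot-product similarity; the paper's exact similarity function and how the result feeds into the node update may differ, and cross_graph_attention is a hypothetical name.

```python
import torch


def cross_graph_attention(h1, h2):
    """Cross-graph matching vectors (sketch, dot-product similarity assumed).

    For node i in graph 1: mu_i = sum_j a_ij * (h1_i - h2_j) = h1_i - sum_j a_ij * h2_j,
    where a_ij is a softmax over all nodes j of graph 2, and symmetrically for graph 2.
    """
    scores = h1 @ h2.t()                    # [N1, N2] pairwise similarities: the O(N1 * N2) part
    a12 = torch.softmax(scores, dim=1)      # each graph-1 node attends over graph 2
    a21 = torch.softmax(scores.t(), dim=1)  # each graph-2 node attends over graph 1
    mu1 = h1 - a12 @ h2
    mu2 = h2 - a21 @ h1
    return mu1, mu2
```

For two graphs of 100 nodes each, the score matrix has 100 × 100 = 10,000 entries per layer, whereas the summary-node variant only adds 200 node-to-summary edges plus one bridge edge.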

dchang56 commented 5 years ago

Thanks! That all makes sense. I was able to do a quick implementation of approach 1 (without removing the summary nodes and adding them back for pooling), which gave terrible results: it's struggling to get above a 0.25 Pearson correlation in any setting so far. I haven't had a chance to tune hyperparameters or experiment with different architectures yet, so I'll have to look into it more. Tomorrow I'll give approach 2 a shot and see if treating the graphs separately makes more sense. If I find anything interesting or have questions, I'll post it here. I can confirm that pooling does work when |V|=1 (it seems like that would hold as long as ratio>0.5 and min_score<1, maybe?).

Somewhat of a trivial question: would normalizing the labels to be in range [0, 1] (from original score range between 0 and 5) make any difference?
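Rescaling the labels by a positive constant does not change the Pearson correlation used for evaluation, while MSE values shrink by the square of the factor (1/25 here). A quick check with made-up numbers (the scores and predictions below are hypothetical):

```python
def pearson(xs, ys):
    """Pearson correlation coefficient of two equal-length sequences."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    vx = sum((x - mx) ** 2 for x in xs)
    vy = sum((y - my) ** 2 for y in ys)
    return cov / (vx * vy) ** 0.5


scores = [0.0, 1.2, 2.5, 3.8, 5.0]  # hypothetical labels on the 0-5 scale
preds = [0.4, 1.0, 2.9, 3.5, 4.6]   # hypothetical model outputs on the same scale

scaled_scores = [s / 5 for s in scores]  # labels mapped to [0, 1]
scaled_preds = [p / 5 for p in preds]    # the same outputs on the [0, 1] scale

mse = sum((p - s) ** 2 for p, s in zip(preds, scores)) / len(scores)
mse_scaled = sum((p - s) ** 2 for p, s in zip(scaled_preds, scaled_scores)) / len(scores)

print(pearson(preds, scores), pearson(scaled_preds, scaled_scores))  # agree up to rounding
print(mse, mse_scaled)  # mse_scaled is mse / 25
```

So the main practical differences are the scale of the loss and gradients, and whether you want a bounded output head (for example a sigmoid, which needs [0, 1] targets); a plain linear output layer can in principle absorb the factor of 5.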

Also, the only graph regression example (QM9) doesn't seem to be working in its current form. I'll give it another try tomorrow. Maybe it's a version problem.

rusty1s commented 5 years ago

min_score always keeps at least one node per graph (the highest-scoring one, independent of its score), and ratio always rounds up, so |V|=1 should work in both cases for all value ranges.
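The two rules can be illustrated for a single graph. This is a sketch of the selection logic only, with a hypothetical topk_keep helper: it mirrors the behaviour described above (ceil for ratio, the best node always surviving for min_score), and the tol value is an assumption. In PyG, the scores are additionally normalized with a per-graph softmax when min_score is set.

```python
import math


def topk_keep(scores, ratio=None, min_score=None, tol=1e-7):
    """Node indices kept by a TopK/SAG-style pooling step for ONE graph (sketch).

    ratio:     keep ceil(ratio * N) highest-scoring nodes, so N = 1 always keeps 1 node.
    min_score: keep nodes scoring above min(min_score, max_score - tol),
               so the best node always survives, whatever its score.
    """
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    if min_score is not None:
        threshold = min(min_score, max(scores) - tol)
        return [i for i in order if scores[i] > threshold]
    k = math.ceil(ratio * len(scores))
    return order[:k]


print(topk_keep([0.3], ratio=0.01))            # [0]
print(topk_keep([0.2], min_score=0.9))         # [0]
print(topk_keep([0.1, 0.9, 0.5], ratio=0.5))   # [1, 2]  (ceil(1.5) = 2)
```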

What is not working for you in the QM9 example? Runs just fine for me.

dchang56 commented 5 years ago

I'll check it in a bit.

Another question: if I do approach 2 and decided to treat graph a and graph b separately, how should I structure the network? Should I initialize and apply Net1 and Net2 to graphs a and graphs b, then combine their final outputs for a classifier? Or initialize and apply only one Net to both input graphs? For the former, does it make sense to feed the dataset in twice, once with the graphs swapped (graph b to Net1 and graph a to Net2). I'd really appreciate your suggestions!

rusty1s commented 5 years ago

Usually, one uses only a single Net and applies it to both graphs separately (siamese style). This makes sense if we are also interested in the final node embeddings. IMO, the other case should work just as well, but it may make sense to initialize the networks the same:

self.gnn2.load_state_dict(self.gnn1.state_dict())

As you said, this approach is not really symmetric, and feeding the data in twice might be a good option to force this.
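Instead of, or in addition to, feeding each pair in twice, the combination step can be made order-invariant. A sketch (symmetric_pair_features is a hypothetical name, and sum / absolute difference / product is a common choice rather than something prescribed in this thread):

```python
import torch


def symmetric_pair_features(h_a, h_b):
    """Order-invariant combination of two graph embeddings of shape [B, D] -> [B, 3D].

    Every term is unchanged when h_a and h_b are swapped, unlike
    torch.cat([h_a, h_b, h_a - h_b, h_a * h_b], dim=-1).
    """
    return torch.cat([h_a + h_b, (h_a - h_b).abs(), h_a * h_b], dim=-1)
```

With a single shared encoder, this makes the prediction for (a, b) identical to the one for (b, a) by construction.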

dchang56 commented 5 years ago

Great! So it's something I'll just have to try out. Does this look right to you?

- Using only a single net:

      # __init__
      self.net = Net(args)
      # forward
      x1 = self.net(x1)
      x2 = self.net(x2)
      out = classifier(x1, x2)

- Using two:

      # __init__
      self.net1 = Net1()
      self.net2 = Net2()
      self.net2.load_state_dict(self.net1.state_dict())
      # forward
      x1 = self.net1(x1)
      x2 = self.net2(x2)
      out = classifier(x1, x2)
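The load_state_dict line copies values, not parameters, which a small check makes concrete. In this sketch make_encoder is a hypothetical stand-in for Net1/Net2; load_state_dict needs both modules to have the same architecture (same keys and shapes).

```python
import copy

import torch
from torch import nn


def make_encoder():
    # Hypothetical stand-in for Net1/Net2: any two modules with identical architecture.
    return nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))


net1 = make_encoder()
net2 = make_encoder()                    # different random initialization at this point
net2.load_state_dict(net1.state_dict())  # now holds an exact copy of net1's weights

x = torch.randn(5, 16)
same_at_init = torch.equal(net1(x), net2(x))
print(same_at_init)  # True, yet the two modules do not share parameter storage

net3 = copy.deepcopy(net1)  # an equivalent shortcut for making the same copy
```

The two copies start out identical but receive separate gradient updates, so they drift apart during training, whereas a single self.net accumulates the gradients from both graphs into one set of weights.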

rusty1s commented 5 years ago

That's correct :)

dchang56 commented 5 years ago

Hi, I was able to implement a rough model of the approach where we process graph a and graph b separately with the same network, and do subsequent regression.

Below is the code that describes the network architecture and the training procedure.

It's able to get to 90% training accuracy (here, "accuracy" means the Pearson correlation returned by test()), but the validation accuracy stays below 50%. I'll be experimenting with different layers and hyperparameters, but I would really appreciate it if you could take a look at the code below and tell me if there's any problem with it. Thank you so much for your help!

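import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.nn import GATConv, GCNConv, SAGPooling, global_add_pool

class GCNNet(nn.Module):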
    def __init__(self, num_features, n_hidden):
        super(GCNNet, self).__init__()
        self.conv1 = GATConv(num_features, n_hidden)
        self.conv2 = GATConv(n_hidden, n_hidden)
        self.pool = SAGPooling(n_hidden, min_score=0.01, GNN=GCNConv)
        self.conv3 = GCNConv(n_hidden, 32)
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(self.conv2(x, edge_index))
        x = F.dropout(x, p=0.5, training=self.training)
        x, edge_index, _, batch, perm, score = self.pool(x, edge_index, None, batch)
        x = F.relu(self.conv3(x, edge_index))
        x = global_add_pool(x, batch)
        return x

class PairClassifier(nn.Module):
    def __init__(self, dim_model, dim_hidden, num_classes):
        super(PairClassifier, self).__init__()
        self.net = Sequential(Linear(dim_model, dim_hidden),
                             ReLU(),
                             nn.Dropout(0.4),
                             Linear(dim_hidden, dim_hidden),
                             ReLU(),
                             nn.Dropout(0.4),
                             Linear(dim_hidden, num_classes))
    def forward(self, x):
        return self.net(x)

class Net(nn.Module):
    def __init__(self, num_features, n_hidden, num_classes):
        super(Net, self).__init__()
        self.gcnnet = GCNNet(num_features, n_hidden)
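        # GCNNet returns 32-d graph embeddings; the classifier sees 4 of them concatenated.
        self.classifier = PairClassifier(4 * 32, 64, num_classes)
    def forward(self, data_a, data_b):
        x1 = self.gcnnet(data_a)  # the same (siamese) encoder is applied to both graphs
        x2 = self.gcnnet(data_b)
        x = torch.cat([x1, x2, x1-x2, x1*x2], dim=-1)
        out = self.classifier(x).squeeze(-1)  # squeeze(-1) keeps shape [B] even when B == 1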
        return out

num_features = 500
n_hidden = 128
num_classes = 1

device = 'cuda'
model = Net(num_features=num_features, n_hidden=n_hidden, num_classes=num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)

# regression
from scipy.stats import pearsonr
def train(loader_a, loader_b):
    model.train()
    total_loss = 0
    # loader_a and loader_b must yield the graphs of each pair in the same order
    # (e.g. shuffle=False on both), otherwise zip() pairs up unrelated graphs.
    for data_a, data_b in zip(loader_a, loader_b):
        data_a, data_b = data_a.to(device), data_b.to(device)
        optimizer.zero_grad()
        out = model(data_a, data_b)
        loss = F.mse_loss(out, data_a.y)
        loss.backward()
        total_loss += data_a.num_graphs * loss.item()  # accumulate, don't overwrite
        optimizer.step()
    return total_loss / len(loader_a.dataset)

@torch.no_grad()
def test(loader_a, loader_b):
    model.eval()
    preds, targets = [], []
    for data_a, data_b in zip(loader_a, loader_b):
        data_a, data_b = data_a.to(device), data_b.to(device)
        preds.append(model(data_a, data_b).cpu())
        targets.append(data_a.y.cpu())
    # One Pearson r over the whole split instead of an average of per-batch values,
    # which is biased and fails on a final batch with a single example.
    return pearsonr(torch.cat(preds).numpy(), torch.cat(targets).numpy())[0]

for epoch in range(1, 201):
    loss = train(train_a_dataloader, train_b_dataloader)
    train_acc = test(train_a_dataloader, train_b_dataloader)
    test_acc = test(dev_a_dataloader, dev_b_dataloader)
    print('epoch {} \t loss {} \t train acc: {} \t val acc: {}'.format(epoch, loss, train_acc, test_acc))

Also, regarding the QM9 example, I think something has changed about the dataset itself when downloading/extracting from the source. This is the error I get:

Downloading http://www.roemisch-drei.de/qm9.tar.gz
Extracting ./raw/qm9.tar.gz
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-2-b1f4decfd163> in <module>
     40 
     41 # Normalize targets to mean = 0 and std = 1.
---> 42 mean = dataset.data.y[:, target].mean().item()
     43 std = dataset.data.y[:, target].std().item()
     44 dataset.data.y[:, target] = (dataset.data.y[:, target] - mean) / std

IndexError: too many indices for tensor of dimension 1

Perhaps it's downloading a version that has only one target instead of the original 12 or 13 target series, so data.y has no second dimension.

rusty1s commented 5 years ago

Looks good to me. You should consider replacing the global add with global mean or global max, add attention heads to the GATConv, and replace the mse with binary cross entropy.

dchang56 commented 5 years ago

What's the reason for replacing mse with cross entropy? It's a regression task, trying to predict a score on a continuous scale of 0-5.

rusty1s commented 5 years ago

Ah ok, sorry. You are correct.

dchang56 commented 5 years ago

What would be the reasoning behind replacing global add with global mean/max? I was under the impression that sum pooling works better than mean pooling based on the GIN paper.

rusty1s commented 5 years ago

Well, the GIN paper uses global mean pooling in most of their experiments too. I agree with you that global add pooling should be more expressive, but I have found that global mean and global max pooling work better in most cases. I do not have a rigorous proof for this, but I do believe that global add pooling can be numerically unstable and may heavily overfit your training data, since it cannot really generalize to unseen graph sizes.
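The graph-size argument can be seen on a toy example: two graphs whose nodes all have the same embedding, but with 2 and 6 nodes. The sketch below is a plain-PyTorch stand-in for PyG's global_add_pool, global_mean_pool and global_max_pool; global_pool is a hypothetical helper.

```python
import torch


def global_pool(x, batch, num_graphs, reduce):
    """Per-graph readout over node embeddings x [N, D]: 'add', 'mean' or 'max'."""
    if reduce == "max":
        return torch.stack([x[batch == g].max(dim=0).values for g in range(num_graphs)])
    out = torch.zeros(num_graphs, x.size(1), dtype=x.dtype).index_add_(0, batch, x)
    if reduce == "mean":
        counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1)
        out = out / counts.unsqueeze(-1).to(x.dtype)
    return out


# Toy batch: graph 0 has 2 nodes, graph 1 has 6 nodes, all with embedding [1.0].
x = torch.ones(8, 1)
batch = torch.tensor([0, 0, 1, 1, 1, 1, 1, 1])
for reduce in ("add", "mean", "max"):
    print(reduce, global_pool(x, batch, 2, reduce).flatten().tolist())
# add [2.0, 6.0]
# mean [1.0, 1.0]
# max [1.0, 1.0]
```

Add pooling grows with the node count, so a model trained on small graphs may see readout values on larger graphs that it never encountered during training, while mean and max stay on the same scale.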

alexcdot commented 4 years ago

Seems like QM9 only works when you have rdkit installed

https://github.com/rusty1s/pytorch_geometric/issues/844#issuecomment-561322235

rusty1s commented 4 years ago

I will look into it!

rusty1s commented 1 year ago

Closing this issue as QM9 also works without rdkit.