pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

Edge classification on directed graphs #2094

Open albertodallolio opened 3 years ago

albertodallolio commented 3 years ago

Hi, I am currently working on a project focused on edge generation/classification for directed graphs. The data I have right now consists of the nodes of each graph; I need to generate the edges between them and label those edges. Each node is represented by the class it belongs to, and each edge connecting these nodes has to be classified into one of 5 classes. I have three ideas in mind, but I am writing to ask for your suggestion:

The dataset I have consists of roughly 180 graphs with at most 24 nodes each. My questions are: Do you think a network similar to the one in the link_pred.py example would be suitable for my problem? In your opinion, which approach would work better given how little data I have?

rusty1s commented 3 years ago

This is hard for me to tell. All your ideas seem absolutely reasonable to me. Using additional information may either increase or decrease performance, depending on whether it helps the model generalize better or induces over-fitting.

In general, I would recommend starting with the simplest approach and using it as a baseline. In this case, that would be a classic GCN (or Node2Vec) baseline. It seems that all your ideas can be built upon our link_pred.py example. You may just want to swap out the model itself (and the loss function, in case you want to classify edges into certain classes).

albertodallolio commented 3 years ago

Thank you so much for your reply.

> It seems that all your ideas can be built upon our link_pred.py example.

That is what I thought too. I went through the code, but it is not quite clear to me how the data is passed through the network. In this line you pass the node features x and the indices of the positive edges to the model, right? Why does data.train_pos_edge_index have shape 2 x 8976? Could you please clarify that for me?

Thanks again.

rusty1s commented 3 years ago

data.train_pos_edge_index denotes the positive edges used for training, which is a subset of the original edge_index of shape [2, num_edges]. The remaining ones are used for validation and testing.
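Each column of an `edge_index` is one edge (row 0 holds source nodes, row 1 target nodes), so a shape of `[2, 8976]` simply means 8976 training edges. Assuming the example ran on Cora and used `train_test_split_edges` with its default 5% validation / 10% test ratios, that number works out exactly: Cora's 10556 directed edges are 5278 undirected ones, 263 + 527 of them are held out, and the remaining 4488 are stored in both directions. The toy sketch below only illustrates the column-wise split with plain `torch`; the graph and split sizes are hypothetical, and it is not PyG's own split utility (which also handles undirected duplicates and negatives).

```python
import torch

# A toy directed graph with 6 edges; edge_index always has shape [2, num_edges].
edge_index = torch.tensor([[0, 0, 1, 2, 3, 4],
                           [1, 2, 2, 3, 4, 0]])
num_edges = edge_index.size(1)

# Shuffle the edge columns, then carve out validation and test subsets.
torch.manual_seed(0)
perm = torch.randperm(num_edges)
n_val, n_test = 1, 1  # hypothetical split sizes
val_pos_edge_index = edge_index[:, perm[:n_val]]
test_pos_edge_index = edge_index[:, perm[n_val:n_val + n_test]]
train_pos_edge_index = edge_index[:, perm[n_val + n_test:]]

print(train_pos_edge_index.shape)  # torch.Size([2, 4])
```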

albertodallolio commented 3 years ago

Ok, thank you.

Just one last question: how would you adapt link_pred.py to perform multiclass classification of edges? I don't understand how the data needs to be passed to the network when the edges carry multiple different labels. I am new to GCN architectures, so thanks so much for the help.

rusty1s commented 3 years ago

That depends on whether you simply want to classify edges, or want to classify missing edges. In the former case, this is really simple to implement, e.g.:

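```python
import torch
import torch.nn.functional as F
from torch.nn import Linear, ReLU, Sequential
from torch_geometric.nn import GCNConv

# Assumed inputs: x [num_nodes, in_channels], edge_index [2, num_edges], y [num_edges]
conv1 = GCNConv(in_channels, 64)
conv2 = GCNConv(64, 64)
# MLP over the concatenated source/target embeddings (2 * 64 = 128 input features)
mlp = Sequential(Linear(128, 64), ReLU(), Linear(64, num_classes))

src, dst = edge_index
x = conv1(x, edge_index).relu_()
x = conv2(x, edge_index).relu_()
edge_attr = torch.cat([x[src], x[dst]], dim=-1)  # Generate edge embedding
out = mlp(edge_attr)  # [num_edges, num_classes] logits
loss = F.cross_entropy(out, y)
```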
albertodallolio commented 3 years ago

Thank you so much for your help, it's been very useful.

mahadafzal commented 3 years ago

@rusty1s How would you go about both detecting and labelling edges, i.e., correctly predicting which edges exist as in the link_pred example and then classifying them into their respective labels (denoting the relationship type between nodes)?

rusty1s commented 3 years ago

You can have an additional loss/output for positive links that tries to predict the ground-truth edge type:

```python
def decode(self, z, pos_edge_index, neg_edge_index):
    # Decode link probability as before
    link_prob = ...
    # Predict the edge type of the positive edges
    link_pred = self.mlp(torch.cat([z[pos_edge_index[0]], z[pos_edge_index[1]]], dim=-1))
    return link_prob, link_pred

def train():
    ...
    # Additional loss on top of the link-prediction loss
    loss += F.cross_entropy(link_pred, ground_truth_edge_type)
    ...
```
mahadafzal commented 3 years ago

@rusty1s Thanks for the help! That makes sense. I am trying to predict one of two edge types after decoding the link probability as before. Could you help me with how the MLP should be structured? Thanks!

rusty1s commented 3 years ago

The MLP just takes in the concatenated source and destination node embeddings and maps them to num_classes outputs, e.g.:

```python
from torch.nn import Linear, ReLU, Sequential

self.mlp = Sequential(Linear(2 * hidden_channels, hidden_channels), ReLU(), Linear(hidden_channels, num_classes))
```
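To make the shapes concrete, here is the same MLP applied to a few hypothetical edges. `hidden_channels = 64`, two edge types, and the random node embeddings are assumptions.

```python
import torch
from torch.nn import Linear, ReLU, Sequential

hidden_channels, num_classes = 64, 2  # assumed sizes: two edge types

mlp = Sequential(
    Linear(2 * hidden_channels, hidden_channels),  # input: [source || destination] embedding
    ReLU(),
    Linear(hidden_channels, num_classes),           # output: one logit per edge type
)

z = torch.randn(10, hidden_channels)                # hypothetical node embeddings for 10 nodes
edge_index = torch.tensor([[0, 3, 7], [1, 5, 2]])   # 3 directed edges to classify
logits = mlp(torch.cat([z[edge_index[0]], z[edge_index[1]]], dim=-1))
print(logits.shape)  # torch.Size([3, 2])
pred = logits.argmax(dim=-1)  # predicted edge type per edge
```

With exactly two edge types, two logits with `F.cross_entropy` work fine; a single output trained with `F.binary_cross_entropy_with_logits` would be an alternative design.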
coffeenmusic commented 2 years ago


I'm following your code example, except that I'm doing regression to predict the distance between nodes (MSE loss). My loss does not decrease if the torch.cat that builds edge_attr comes after the GCN layers (as in your example), but it does decrease if I concatenate before passing the features to the GCN. Do you have any insight here? I'm assuming I shouldn't concatenate the source and destination input features before passing them to the GCN.

rusty1s commented 2 years ago

Can you show me an example of what you mean?

coffeenmusic commented 2 years ago

I really appreciate you taking the time to respond. Here is the version in which my loss decreases; it does not decrease if I apply the edge-attribute concatenation after the GCN layers.

```python
conv1 = GCNConv(in_channels*2, 64)
conv2 = GCNConv(64, 64)
mlp = MLP(64, num_classes)

src, dst = edge_index
edge_attr = torch.cat([x[src], x[dst]], dim=-1)  # Generate edge embedding
x = conv1(edge_attr, edge_index).relu_()
x = conv2(x, edge_index).relu_()

out = mlp(x)
loss = F.mse_loss(out, y)
```
rusty1s commented 2 years ago

Mh, this looks a bit wrong to me. You shouldn't use edge_attr as a node feature matrix. What happens if you replace GCNConv with SAGEConv in the original example? Hopefully, this resolves your issues of the loss not decreasing.

coffeenmusic commented 2 years ago

Thank you so much, it looks like it is working now! If you don't mind me asking, what made you think switching to SAGEConv would help?

rusty1s commented 2 years ago

It was just a basic intuition that skip connections, or models that can preserve the central node's own information, might be necessary. Sadly, GCNConv is not capable of doing so.