pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License
21.28k stars · 3.65k forks

Why the data.y (label) was used in unsupervised method? #2400

Closed wanglu2014 closed 2 years ago

wanglu2014 commented 3 years ago

🐛 Bug

How is `data.y` generated? Doesn't optimizing the model with `y` mean it is not self-supervised? https://github.com/rusty1s/pytorch_geometric/blob/master/examples/super_gat.py

rusty1s commented 3 years ago

SuperGAT is not an unsupervised method, but utilizes an auxiliary unsupervised loss for learning better attention coefficients (att_loss). This loss is directly computed inside SuperGATConv, see here.

wanglu2014 commented 3 years ago

> SuperGAT is not an unsupervised method, but utilizes an auxiliary unsupervised loss for learning better attention coefficients (att_loss). This loss is directly computed inside SuperGATConv, see here.

Thank you for your kind reply! Could you recommend an unsupervised GNN method that can generate edge weights (to rank the importance of a node's neighbors to that node)?

rusty1s commented 3 years ago

Please avoid creating multiple issues on the same topic. I'm already trying to answer all questions/issues as quickly as I can.

For unsupervised learning with attention, you can basically use any unsupervised method with a GAT encoder, e.g., Deep Graph Infomax (examples/infomax.py) or link prediction (examples/link_pred.py). If you only care about finding the importance of a node's neighbors for a given prediction, you might also be interested in GNNExplainer (examples/gnn_explainer.py), which tries to find a subset of meaningful edges that do not change the outcome of predictions.

wanglu2014 commented 3 years ago

> Please avoid creating multiple issues with the same topic. I'm already trying to answer all questions/issues as early as I can.
>
> For unsupervised learning with attention, you can basically use any unsupervised method with a GAT encoder, e.g., Deep Graph Infomax (examples/infomax.py) or link prediction (examples/link_pred.py). If you only care about finding the importance of a node's neighbors for a given prediction, you might be also interested in GNNExplainer (examples/gnn_explainer.py), which tries to find a subset of meaningful edges that do not change the outcome of predictions.

Thank you for the reminder; I will be careful not to post duplicates. How can I extract the weights from Deep Graph Infomax?

rusty1s commented 3 years ago

In case you have a GAT encoder, you can directly pass return_attention_weights=True to its forward call, see here.

wanglu2014 commented 3 years ago

> In case you have a GAT encoder, you can directly pass return_attention_weights=True to its forward call, see here.

Yes, GAT-related methods always return attention coefficients. However, DGI has no option to return weights, and its source code does not build on GAT. I know that DGI has four weight tensors (in the DGL implementation), such as encoder.conv.layers.0.weight, encoder.conv.layers.0._activation.weight, encoder.conv.layers.1.weight and discriminator.weight, and I need only the single most essential one. I might have misunderstood your idea; could you kindly suggest how to extract attention weights for a DGI model?

rusty1s commented 3 years ago

The idea is to utilize GATConv instead of GCNConv in the DGI encoder model, see here. Then, you can do:

def forward(self, x, edge_index):
    # GATConv only returns the (edge_index, alpha) tuple when
    # return_attention_weights=True is passed:
    x, att = self.conv(x, edge_index, return_attention_weights=True)
    x = self.prelu(x)
    return x, att