Graph-COM / GSAT

[ICML 2022] Graph Stochastic Attention (GSAT) for interpretable and generalizable graph learning.
https://arxiv.org/abs/2201.12987
MIT License

directed edge weights on undirected graphs #5

Closed: simoons95 closed this issue 2 years ago

simoons95 commented 2 years ago

Hello,

First of all, thank you for your paper and your code; they are a pleasure to work with.

However, I have a question about the following line: https://github.com/Graph-COM/GSAT/blob/ea900dbe8d27fa64c30f2fe46ab8b5a68ef719ca/src/run_gsat.py#L79 When I run it, this line does not seem to do anything more than edge_att = (att + att) / 2. As a result, edge weights differ depending on the direction of the edge (the weight from 0 to 1 != the weight from 1 to 0). Have I missed anything?
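To see why such a line can reduce to (att + att) / 2, here is a framework-free sketch in plain Python (edge lists instead of tensors; the helper naive_symmetrize is hypothetical). It assumes the line averages att with the values obtained by transposing edge_index without coalescing, as in the transpose(..., coalesced=False) call quoted later in this thread. Transposing only swaps the source and destination rows, so each value stays in its original slot and ends up averaged with itself.

```python
def naive_symmetrize(edge_index, att):
    """Hypothetical illustration: average att with the transposed values
    without realigning them to the original edge order."""
    src, dst = edge_index
    # Transposing swaps the two rows, but the values keep their positions.
    trans_index = [list(dst), list(src)]
    trans_val = list(att)
    # trans_val[i] is still att[i] (now labelled with the reversed edge i),
    # so averaging slot by slot gives (att + att) / 2, i.e. att itself.
    return trans_index, [(a + t) / 2 for a, t in zip(att, trans_val)]


edge_index = [[0, 1], [1, 0]]  # edges 0 -> 1 and 1 -> 0
att = [0.2, 0.8]
_, edge_att = naive_symmetrize(edge_index, att)
print(edge_att)  # [0.2, 0.8]: still direction-dependent
```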

siqim commented 2 years ago

Thanks a lot for spotting and reporting this issue. After a quick check of the code, I found that this is indeed a bug introduced by PR #3. It is caused by an edge_index that is not properly sorted, and I will fix it soon. Thank you very much!
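One order-independent way to symmetrize edge attention is to look up the reverse edge by its (src, dst) pair instead of relying on positions. The plain-Python sketch below is an illustration, not the repository's fix; the helper name symmetrize_edge_att is hypothetical. It assumes every edge (u, v) has its reverse (v, u) present, as in an undirected graph stored with both directions, and that no edge appears twice.

```python
def symmetrize_edge_att(edge_index, att):
    """Average each edge weight with the weight of its reverse edge.

    Hypothetical helper: assumes every (u, v) has a matching (v, u)
    and that no edge appears twice.
    """
    src, dst = edge_index
    # Map each directed edge to its attention value; order no longer matters.
    weight = {(u, v): a for u, v, a in zip(src, dst, att)}
    # Look up the reverse edge explicitly instead of relying on positions.
    return [(weight[(u, v)] + weight[(v, u)]) / 2 for u, v in zip(src, dst)]


# Unsorted edge order: the source row is [0, 1, 0, 2].
edge_index = [[0, 1, 0, 2], [1, 0, 2, 0]]
att = [0.1, 0.3, 0.5, 0.9]
edge_att = symmetrize_edge_att(edge_index, att)  # approximately [0.2, 0.2, 0.7, 0.7]
```

A Python dict is fine for illustration; on large graphs held as GPU tensors, a vectorized alignment avoids the per-edge Python loop.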

simoons95 commented 2 years ago

Can input indices really be unsorted (the reason for PR #3), given that they come from a dataloader?

siqim commented 2 years ago

If I remember correctly, I opened PR #3 because I found that data.edge_index is not sorted for the mutag dataset even though it comes from a dataloader. For example, it gives something like [[0, 1, 0, 2, 0], [1, 0, 2, 0, 3]], whereas the source row would be [0, 0, 0, 1, 2] if it were sorted.
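For reference, sorting that example edge_index row-major (by source, then by destination) gives the source row [0, 0, 0, 1, 2] mentioned above. A plain-Python sketch follows; the helper sort_edges is hypothetical, while torch_geometric.utils.sort_edge_index applies the same row-major ordering to tensors by default.

```python
def sort_edges(edge_index):
    """Sort a 2 x E edge list by (src, dst), i.e., row-major order."""
    src, dst = edge_index
    # Tuples compare lexicographically: by src first, then by dst.
    order = sorted(range(len(src)), key=lambda i: (src[i], dst[i]))
    return [[src[i] for i in order], [dst[i] for i in order]]


print(sort_edges([[0, 1, 0, 2, 0], [1, 0, 2, 0, 3]]))
# [[0, 0, 0, 1, 2], [1, 2, 3, 0, 0]]
```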

siqim commented 2 years ago

I just fixed this issue with PR #7, and the code should now work properly. Thanks again for spotting this issue, and feel free to let us know if you run into any other problems!

simoons95 commented 2 years ago

I see that the fix moves tensors to the CPU but never moves them back to the GPU, which may lead to bugs. Maybe the following function could help:

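```python
from torch_geometric.utils import sort_edge_index

def reorder_like(from_edge_index, to_edge_index, values):
    # Sort from_edge_index row-major and permute values along with it.
    from_edge_index, values = sort_edge_index(from_edge_index, values)
    # Encode each (src, dst) pair of to_edge_index as a single sortable integer.
    ranking_score = to_edge_index[0] * (to_edge_index.max() + 1) + to_edge_index[1]
    # argsort().argsort() gives each edge's position in sorted order.
    ranking = ranking_score.argsort().argsort()
    if not (from_edge_index[:, ranking] == to_edge_index).all():
        raise ValueError("Edges in from_edge_index and to_edge_index are different, impossible to match both.")
    # Values now follow the edge order of to_edge_index.
    return values[ranking]
```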
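The least obvious step in reorder_like is argsort().argsort(): applying argsort twice yields, for each element, its rank in sorted order. The plain-Python mirror below is a sketch, not the code from the thread; the helpers argsort and reorder_like_py are hypothetical, and they use (src, dst) tuples so that tuple comparison plays the role of the integer ranking_score. It shows how the sorted values are pulled back into the edge order of to_edge_index.

```python
def argsort(xs):
    """Indices that would sort xs (assumes the items are distinct)."""
    return sorted(range(len(xs)), key=lambda i: xs[i])


def reorder_like_py(from_edges, to_edges, values):
    """Plain-Python mirror of reorder_like on lists of (src, dst) tuples."""
    # Sort from_edges row-major and permute values the same way.
    order = argsort(from_edges)
    sorted_from = [from_edges[i] for i in order]
    sorted_vals = [values[i] for i in order]
    # rank[i] is the position that to_edges[i] takes in sorted order.
    rank = argsort(argsort(to_edges))
    if [sorted_from[r] for r in rank] != list(to_edges):
        raise ValueError("from_edges and to_edges contain different edges")
    return [sorted_vals[r] for r in rank]


print(argsort(argsort([30, 10, 20])))  # [2, 0, 1]

to_edges = [(0, 1), (1, 0), (0, 2), (2, 0)]
att = [0.1, 0.3, 0.5, 0.9]
# Transposing reverses every edge but keeps the values in place.
from_edges = [(v, u) for u, v in to_edges]
print(reorder_like_py(from_edges, to_edges, att))  # [0.3, 0.1, 0.9, 0.5]
```

Each output value is the attention of the reverse edge, so averaging it with att gives symmetric weights.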

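You can use it like this:

```python
from torch_sparse import transpose
# Transpose without coalescing: rows are swapped, values stay in their original slots.
trans_idx, trans_val = transpose(data.edge_index, att, None, None, coalesced=False)
# Align each transposed value with the matching edge of data.edge_index.
trans_val_perm = reorder_like(trans_idx, data.edge_index, trans_val)
edge_att = (att + trans_val_perm) / 2
```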
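After applying the fix, a quick sanity check is that every edge now carries the same weight as its reverse. Below is a plain-Python sketch; the helper is_symmetric is hypothetical. With tensors, one could first convert edge_index and edge_att to Python lists, e.g. with .tolist().

```python
import math


def is_symmetric(edge_index, edge_att, rel_tol=1e-6):
    """Return True if edge_att[(u, v)] matches edge_att[(v, u)] for all edges.

    Hypothetical check; an edge whose reverse is missing makes it fail.
    """
    src, dst = edge_index
    weight = {(u, v): a for u, v, a in zip(src, dst, edge_att)}
    return all(
        (v, u) in weight and math.isclose(a, weight[(v, u)], rel_tol=rel_tol)
        for (u, v), a in weight.items()
    )


edge_index = [[0, 1, 0, 2], [1, 0, 2, 0]]
print(is_symmetric(edge_index, [0.2, 0.2, 0.7, 0.7]))  # True
print(is_symmetric(edge_index, [0.1, 0.3, 0.5, 0.9]))  # False
```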
siqim commented 2 years ago

Thanks a lot for the suggestion! I created a new PR #8 and updated the code as you suggested. :)