simoons95 closed 2 years ago
Thanks a lot for spotting and reporting this issue. After a quick check of the code, I found that this is indeed a bug introduced by PR #3. It occurs when edge_index is not properly sorted, and I will fix it soon. Thank you very much!
Can the input indices really be unsorted (the reason for PR #3), given that they come from a dataloader?
If I remember correctly, I made PR #3 because I found that data.edge_index is not sorted for the mutag dataset even though it comes from a dataloader, i.e., it gives something like [[0, 1, 0, 2, 0], [1, 0, 2, 0, 3]], where the src tensor would be [0, 0, 0, 1, 2] if it were sorted.
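To make the ordering concrete, here is a small library-free sketch that sorts the example above by source node and then by destination node. It uses plain Python lists rather than tensors, which is an assumption made only for illustration, and the helper name `sort_edges` is hypothetical.

```python
def sort_edges(edge_index):
    """Return edge_index sorted by (source, destination), as plain lists."""
    src, dst = edge_index
    # Tuples compare lexicographically: source first, then destination.
    pairs = sorted(zip(src, dst))
    return [[s for s, _ in pairs], [d for _, d in pairs]]


edge_index = [[0, 1, 0, 2, 0], [1, 0, 2, 0, 3]]
sorted_index = sort_edges(edge_index)
print(sorted_index[0])  # [0, 0, 0, 1, 2]
print(sorted_index[1])  # [1, 2, 3, 0, 0]
```

Any per-edge values computed against the sorted order no longer line up with the original, unsorted edge_index unless they are permuted back.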
I just fixed this issue in PR #7, and the code should now work properly. Thanks again for spotting this issue, and feel free to let us know if you encounter any more!
I see you go to cpu but never come back to gpu, which may lead to some bugs.
Maybe the following function could help:
from torch_geometric.utils import sort_edge_index


def reorder_like(from_edge_index, to_edge_index, values):
    # Sort the source edges (and their values) by (row, col).
    from_edge_index, values = sort_edge_index(from_edge_index, values)
    # Rank of each target edge in sorted (row, col) order: argsort of argsort.
    ranking_score = to_edge_index[0] * (to_edge_index.max() + 1) + to_edge_index[1]
    ranking = ranking_score.argsort().argsort()
    # Both edge sets must be identical, otherwise the values cannot be matched.
    if not (from_edge_index[:, ranking] == to_edge_index).all():
        raise ValueError("Edges in from_edge_index and to_edge_index are different, impossible to match both.")
    return values[ranking]
You can use it like this:
trans_idx, trans_val = transpose(data.edge_index, att, None, None, coalesced=False)
trans_val_perm = reorder_like(trans_idx, data.edge_index, trans_val)
edge_att = (att + trans_val_perm) / 2
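The reordering step can be traced by hand on a tiny graph. The following is a library-free sketch in plain Python lists (an assumption for illustration; the snippet above works on tensors), with hypothetical helper names `argsort` and `reorder_like_lists`. It sorts the transposed edges together with their values, then maps the values back to the original edge order using the same rank trick (argsort of argsort).

```python
def argsort(seq):
    """Indices that would sort seq (stable)."""
    return sorted(range(len(seq)), key=lambda i: seq[i])


def reorder_like_lists(from_edges, to_edges, values):
    """Reorder values (aligned with from_edges) to follow the order of to_edges.

    Edges are lists of (src, dst) tuples; both lists must hold the same edges.
    """
    # Sort the source edges together with their values.
    order = argsort(from_edges)
    from_sorted = [from_edges[i] for i in order]
    values_sorted = [values[i] for i in order]
    # Rank of each target edge in sorted order: argsort of argsort.
    ranking = argsort(argsort(to_edges))
    if [from_sorted[r] for r in ranking] != to_edges:
        raise ValueError("Edge sets differ, impossible to match both.")
    return [values_sorted[r] for r in ranking]


# Unsorted, symmetric toy graph with one attention value per directed edge.
edges = [(0, 1), (1, 0), (0, 2), (2, 0)]
att = [0.25, 0.75, 0.5, 1.0]

# Transposed edges carry the same values but point the other way.
trans_edges = [(d, s) for s, d in edges]
trans_val_perm = reorder_like_lists(trans_edges, edges, att)
edge_att = [(a + t) / 2 for a, t in zip(att, trans_val_perm)]
print(edge_att)  # [0.5, 0.5, 0.75, 0.75]
```

After reordering, both directions of each edge receive the same averaged weight, which is the symmetry the averaging line is meant to produce.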
Thanks a lot for the suggestion! I created a new PR #8 and updated the code as you suggested. :)
Hello,
First of all, thank you for your paper and your code, it is a pleasure to work with it.
However, I have a question about the following line: https://github.com/Graph-COM/GSAT/blob/ea900dbe8d27fa64c30f2fe46ab8b5a68ef719ca/src/run_gsat.py#L79 When I run it, this line does not seem to do anything more than edge_att = (att + att) / 2. As a result, edge weights differ depending on the direction of the edge (0 to 1 != 1 to 0). Have I missed anything?