DeepGraphLearning / torchdrug

A powerful and flexible machine learning platform for drug discovery
https://torchdrug.ai/
Apache License 2.0

G2G model labeling nodes and edges #129

Closed: DimGorr closed this issue 2 years ago

DimGorr commented 2 years ago

While implementing the G2G model on the USPTO-50k dataset, I looked at the functions `_get_difference` and `_get_reaction_center`. In the second one you define `edge_label`, which is used to train the reaction-center identification part of retrosynthesis. The problem is that the first function converts the graph to a directed one, so each edge is counted once, but the second function does not. As a result, the second function assigns 1 to edge [a, b] if it was added but does not assign 1 to the reverse edge [b, a]. Let me know if I misunderstood something. If I'm correct, I would suggest adding this piece of code to the second function where `edge_label` is defined:

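```python
# Also label the reverse direction [b, a] of the single added bond, so both
# directed copies of the bond in `product` receive label 1. Assumes `edge_added`
# has shape (1, 2) and `any` is the (1, 1) tensor of -1 (wildcard bond type)
# defined earlier in _get_reaction_center.
reverse_edge = edge_added[:, [1, 0]]  # [[a, b]] -> [[b, a]]
pattern = torch.cat([reverse_edge, any], dim=-1)
index, num_match = product.match(pattern)
assert num_match.item() == 1
edge_label[index] = 1
```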
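To illustrate the concern, here is a minimal toy sketch in plain Python, not TorchDrug code. It assumes an undirected graph is stored as a list of directed (node_in, node_out) pairs, one per direction, and all helper names are hypothetical. Labeling only [a, b] leaves the reverse copy [b, a] at 0, while the suggested fix labels both.

```python
# Toy model (hypothetical, not TorchDrug code): an undirected graph stored as
# directed (node_in, node_out) pairs, one entry per direction of each bond.
def undirected_edge_list(bonds):
    """Expand each bond (u, v) into the two directed edges (u, v) and (v, u)."""
    edges = []
    for u, v in bonds:
        edges.append((u, v))
        edges.append((v, u))
    return edges


def label_added_bond(edges, a, b, both_directions):
    """Return 0/1 labels marking the added bond (a, b) in `edges`."""
    targets = {(a, b), (b, a)} if both_directions else {(a, b)}
    return [1 if edge in targets else 0 for edge in edges]


edges = undirected_edge_list([(0, 1), (1, 2)])  # bonds 0-1 and 1-2
one_way = label_added_bond(edges, 1, 2, both_directions=False)
two_way = label_added_bond(edges, 1, 2, both_directions=True)
print(edges)    # [(0, 1), (1, 0), (1, 2), (2, 1)]
print(one_way)  # [0, 0, 1, 0]  -> the reverse copy (2, 1) stays unlabeled
print(two_way)  # [0, 0, 1, 1]
```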
DimGorr commented 2 years ago

Sorry, I didn't notice that the graph is actually converted to a directed one in the `target` function of G2G itself, so everything is fine.
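Why the directed conversion resolves the issue can be sketched with the same toy representation. This is an illustration under an assumption, not TorchDrug's implementation: it supposes the directed view keeps only edges with node_in < node_out, so each bond survives exactly once (check TorchDrug's `Graph.directed` documentation for the actual rule). As long as the labeled direction is the one that survives, the one-direction labels already cover every bond.

```python
# Toy model (hypothetical, not TorchDrug code). Assumption for illustration:
# the directed view keeps only edges with node_in < node_out.
def directed_mask(edges):
    """Return True for the one copy of each bond that the directed view keeps."""
    return [u < v for u, v in edges]


edges = [(0, 1), (1, 0), (1, 2), (2, 1)]  # bonds 0-1 and 1-2, both directions
one_way = [0, 0, 1, 0]                     # only (1, 2) labeled

mask = directed_mask(edges)
directed_edges = [e for e, keep in zip(edges, mask) if keep]
directed_labels = [y for y, keep in zip(one_way, mask) if keep]
print(directed_edges)   # [(0, 1), (1, 2)]
print(directed_labels)  # [0, 1]  -> the added bond is labeled once, nothing missing
```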