recursionpharma / gflownet

GFlowNet library specialized for graph & molecular data
MIT License

Interpretation of edge_attr in FragMolBuildingEnvContext #58

Closed: timgaripov closed this issue 1 year ago

timgaripov commented 1 year ago

Hi!

Thanks for the great work on GFlowNets. I am going through the code to understand and reproduce the multi-objective molecule generation experiments (from the Multi-objective GFlowNets paper).

I am a little confused by the interpretation of the edge attributes created by FragMolBuildingEnvContext.graph_to_Data.

The edge_attr tensor is initialized as

edge_attr = torch.zeros((len(g.edges) * 2, self.num_edge_dim))

and then filled as follows

for i, e in enumerate(g.edges):
    ad = g.edges[e]
    a, b = e
    for n, offset in zip(e, [0, self.num_stem_acts]):
        idx = ad.get(f'{int(n)}_attach', -1) + 1
        edge_attr[i * 2, idx] = 1
        edge_attr[i * 2 + 1, idx] = 1
        ...

I do not completely understand how I should interpret the features edge_attr[i * 2, :] and edge_attr[i * 2 + 1, :] of edge i.

If I understand correctly, in the code above idx takes value 0 when the attribute f'{a}_attach' (f'{b}_attach') is not set. Otherwise, idx is the 1-based index of the stem of node a (b). In particular, if both attributes f'{a}_attach' and f'{b}_attach' are set to 0, then edge_attr[i * 2, :] = [0, 1, 0, 0, 0, ...].
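
To check my reading, here is a minimal, self-contained sketch of the indexing logic (plain Python rather than the library's actual code; num_stem_acts = 4 is an assumed value for illustration). It reproduces the case where both attach attributes are set to 0, and shows that offset is never applied, so both rows of an edge end up identical:

```python
num_stem_acts = 4   # assumed: most_stems + 1 slots, index 0 meaning "unset"
num_edge_dim = num_stem_acts * 2

# A single toy edge (0, 1) with both attach attributes set to stem 0.
edges = {(0, 1): {"0_attach": 0, "1_attach": 0}}

edge_attr = [[0] * num_edge_dim for _ in range(len(edges) * 2)]
for i, e in enumerate(edges):
    ad = edges[e]
    for n, offset in zip(e, [0, num_stem_acts]):
        idx = ad.get(f"{int(n)}_attach", -1) + 1
        # offset is computed but never used, so the two rows are identical
        edge_attr[i * 2][idx] = 1
        edge_attr[i * 2 + 1][idx] = 1

print(edge_attr[0])                  # [0, 1, 0, 0, 0, 0, 0, 0]
print(edge_attr[0] == edge_attr[1])  # True
```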

Questions:

self.num_edge_dim = (most_stems + 1) * 2

...

edge_attr = torch.zeros((len(g.edges) * 2, self.num_edge_dim))

...

for i, e in enumerate(g.edges):
    ad = g.edges[e]
    a, b = e
    for n, offset in zip(e, [0, self.num_stem_acts]):
        idx = ad.get(f'{int(n)}_attach', -1) + 1
        edge_attr[i * 2, idx + offset] = 1
        edge_attr[i * 2 + 1, idx + offset] = 1
        ...

bengioe commented 1 year ago

Hi @timgaripov, I think you're correct 🤔. By fixing a bug in the encoding in this PR and removing the offset, I actually introduced another.

For the fragment environment, the edges are actually directed, so we should encode the one-hot of (a, b) for one edge and (b, a) for the other. I think something like this should be the correct loop:

for i, e in enumerate(g.edges):
    ad = g.edges[e]
    a, b = e
    for j, (n, offset) in enumerate(zip(e, [0, self.num_stem_acts])):
        idx = ad.get(f'{int(n)}_attach', -1) + 1
        edge_attr[i * 2, idx + self.num_stem_acts * j] = 1
        edge_attr[i * 2 + 1, idx + self.num_stem_acts * (1 - j)] = 1
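
As a sanity check, here is a standalone sketch of this corrected loop (plain Python, num_stem_acts = 4 assumed for illustration; the unused offset variable is dropped). The two rows now encode the two directions, with the halves of the feature vector swapped:

```python
num_stem_acts = 4   # assumed value for illustration
edges = {(0, 1): {"0_attach": 0}}   # '1_attach' left unset

edge_attr = [[0] * (num_stem_acts * 2) for _ in range(len(edges) * 2)]
for i, e in enumerate(edges):
    ad = edges[e]
    for j, n in enumerate(e):
        idx = ad.get(f"{int(n)}_attach", -1) + 1
        # row i*2 encodes direction (a, b); row i*2 + 1 encodes (b, a)
        edge_attr[i * 2][idx + num_stem_acts * j] = 1
        edge_attr[i * 2 + 1][idx + num_stem_acts * (1 - j)] = 1

print(edge_attr[0])  # [0, 1, 0, 0, 1, 0, 0, 0]
print(edge_attr[1])  # [1, 0, 0, 0, 0, 1, 0, 0]
```

Note how edge_attr[1] is edge_attr[0] with its two halves exchanged, i.e. the (b, a) view of the same edge.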

Good catch! Thanks! If you'd like the attribution feel free to open a PR.

Ironically, this may not make a big difference in terms of performance; deep learning is magical that way. I can rerun our benchmarks to see whether this breaks/improves anything.

timgaripov commented 1 year ago

@bengioe, thanks for the prompt response! I am glad that my understanding is correct. I just sent a PR for this.

I was also thinking that edge feature encoding probably wouldn't have a significant effect on the final result. However, taking into account the magical ways of deep learning, I wouldn't be too surprised to see any imaginable (or unimaginable) outcome :grin: