recursionpharma / gflownet

GFlowNet library specialized for graph & molecular data

Fix GraphTransformer forward function #97

Closed: hohyun312 closed this issue 1 year ago

hohyun312 commented 1 year ago

I think I found an inconsistency between the GraphTransformer docstring and its implementation, both located in gflownet/models/graph_transformer.py. I don't think it is a critical flaw, but it may impact performance, since the current behavior is not what was intended.

Here is what the current docstring says:

The per node outputs are the concatenation of the final (post graph-convolution) node embeddings and of the final virtual node embedding of the graph each node corresponds to.

The per graph outputs are the concatenation of a global mean pooling operation, of the final virtual node embeddings, and of the conditional information embedding.

And here is the current implementation of GraphTransformer:

class GraphTransformer(nn.Module):
    ...
    def forward(self, g: gd.Batch, cond: torch.Tensor):
        ...
        # o holds the node embeddings followed by one virtual node embedding per graph;
        # c is the embedded conditional information (one row per graph).
        glob = torch.cat([gnn.global_mean_pool(o[: -c.shape[0]], g.batch), o[-c.shape[0] :]], 1)
        o_final = torch.cat([o[: -c.shape[0]]], 1)
        return o_final, glob

There are two problems:

1. The per node outputs (o_final) contain only the final node embeddings; the final virtual node embeddings are not concatenated, contrary to the docstring.
2. The per graph outputs (glob) do not include the conditional information embedding, contrary to the docstring.

One option is to fix the code for the first problem and fix the docstring for the second (i.e. keep the per graph code as it is and drop the conditional information embedding from its description). Under that option, the code should be fixed to something like this:

class GraphTransformer(nn.Module):
    ...
    def forward(self, g: gd.Batch, cond: torch.Tensor):
        ...
        n_final = o[: -c.shape[0]] # final node embeddings (without virtual nodes)
        v_final = o[-c.shape[0] :] # final virtual node embeddings
        glob = torch.cat([gnn.global_mean_pool(n_final, g.batch), v_final], 1)
        o_final = torch.cat(
            [n_final, v_final.repeat_interleave(torch.bincount(g.batch), dim=0)], 1
        )
        return o_final, glob
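
For intuition, here is a small standalone sketch (toy tensors only, not the actual GraphTransformer) of how repeat_interleave with torch.bincount(g.batch) aligns the per-graph virtual node embeddings with the per-node embeddings; the batch vector and dimensions below are made up for illustration:

import torch

# Toy batch: two graphs with 3 and 2 nodes; `batch` plays the role of g.batch.
batch = torch.tensor([0, 0, 0, 1, 1])
n_final = torch.randn(5, 8)   # final node embeddings     [num_nodes, hidden]
v_final = torch.randn(2, 8)   # virtual node embeddings   [num_graphs, hidden]

# Repeat each graph's virtual embedding once per node of that graph, then concatenate.
v_per_node = v_final.repeat_interleave(torch.bincount(batch), dim=0)  # [5, 8]
o_final = torch.cat([n_final, v_per_node], 1)                         # [5, 16]
print(o_final.shape)  # torch.Size([5, 16])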

And the corresponding docstring for the per graph outputs should be something like:

The per graph outputs are the concatenation of a global mean pooling operation over the final node embeddings and of the final virtual node embeddings.

bengioe commented 1 year ago

Hi @hohyun312, thanks for opening this issue. You're right that the docstring does not match the code; this is probably due to me iterating on the architecture and not updating the doc.

You're welcome to open a PR that fixes the docstring, or possibly makes the concatenation of the conditional information embedding (c) an optional flag. If not, I'll get to it eventually :)
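
For concreteness, here is a hypothetical sketch of what making that concatenation optional could look like at the tensor level; the function name, signature, and the concat_cond flag are illustrative assumptions, not the library's actual interface:

import torch
from torch_geometric.nn import global_mean_pool

def per_graph_output(n_final, v_final, c, batch, concat_cond=False):
    # Mean-pooled node embeddings next to the per-graph virtual node embeddings.
    glob = torch.cat([global_mean_pool(n_final, batch), v_final], 1)
    if concat_cond:
        # Optionally append the conditional information embedding,
        # as the original docstring describes.
        glob = torch.cat([glob, c], 1)
    return glob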

bengioe commented 1 year ago

Fixed by #104