abieler opened 9 months ago
Thanks, this is a nice start. A few comments:
No need to introduce the `DotProductAttention` type; we can use `MultiHeadAttention` from Flux. According to tables A.2-A.5 in the paper, multi-head attention is the preferred choice. We should have an `nheads` argument in the constructor. See also https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GPSConv.html for the kind of flexibility we could try to achieve (in several PRs).
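A minimal sketch of the constructor shape being suggested, where the attention module is built from an `nheads` keyword instead of a hand-rolled `DotProductAttention` type. `AttentionStub` stands in for `Flux.MultiHeadAttention` so the snippet is self-contained; all names and defaults here are illustrative assumptions, not the final API.

```julia
# Stand-in for Flux.MultiHeadAttention (assumption, for illustration only).
struct AttentionStub
    dim::Int
    nheads::Int
end

# Hypothetical GPSConv holding a local message-passing layer and a
# global attention module.
struct GPSConv{L}
    local_layer::L        # any message-passing layer, e.g. a GNN conv
    attn::AttentionStub   # global attention (MultiHeadAttention in Flux)
end

# Keyword constructor: the user passes channel size and nheads; the
# attention module is constructed internally.
GPSConv(ch::Int, local_layer; nheads::Int = 1) =
    GPSConv(local_layer, AttentionStub(ch, nheads))

layer = GPSConv(64, identity; nheads = 4)
```

With the real `Flux.MultiHeadAttention`, the same keyword would be forwarded as `MultiHeadAttention(ch; nheads)`.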
Part of the contribution of the paper is the discussion of different types of embeddings. This package lacks many of these embeddings. I hope they will be added in the future but in any case, it is ok for this PR to only implement the layer.
I think the current order of operations is wrong, see comment https://github.com/CarloLucibello/GraphNeuralNetworks.jl/pull/355#discussion_r1438471782
`BatchNorm` should be used instead of `LayerNorm`.
In the paper it is not clear whether we should apply a residual connection after the MLP. From figure D.1 it seems there is one, but there is none according to Eq. 11.
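The points above (order of operations, normalization, and the residual after the MLP) can be sketched as a forward pass. Plain closures stand in for the trained sublayers so the snippet runs on its own; in the real layer these would be a GNN conv, `Flux.MultiHeadAttention`, an MLP, and `Flux.BatchNorm`. Everything here is an illustrative assumption of the intended structure, not the final implementation.

```julia
# Stand-ins for the trained sublayers (assumptions, for illustration only).
local_mpnn(x)  = 2 .* x   # local message-passing conv
global_attn(x) = x .+ 1   # global attention
mlp(x)         = x ./ 2   # feed-forward block
norm(x)        = x        # BatchNorm placeholder (identity here)

function gps_block(x)
    # Local and global branches, each followed by a residual and a norm.
    xl = norm(local_mpnn(x) .+ x)
    xg = norm(global_attn(x) .+ x)
    # Combine branches, then MLP; whether the final residual belongs here
    # is exactly the figure-D.1-vs-Eq.-11 ambiguity discussed above.
    h = xl .+ xg
    return norm(mlp(h) .+ h)
end
```

Resolving the ambiguity amounts to deciding whether the last line reads `norm(mlp(h) .+ h)` or `norm(mlp(h))`.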
Thanks for the comments. I'll be going over the authors' codebase, the paper, and the PyTorch implementation for implementation details for the next version.
This is only a first "mock" version of a `GPSConv` layer to see if we would want it in the repo in that form, i.e.:

- Adds a `DotProductAttention` layer that uses `NNlib.dot_product_attention()`
- Adds a `GPSConv` layer with `DotProductAttention` as the global attention layer
- Not sure about the `GNNChain()` implementation, whether it should stay where it is or move into the struct?
- `JuliaFormatter()` got a bit too greedy and made some changes here and there; I can revert those of course
- Did not check the implementation for correctness yet
Let me know what you think and I can adjust / keep going from here.
Closes #351