datamol-io / graphium

Graphium: Scaling molecular GNNs to infinity.
https://graphium-docs.datamol.io/
Apache License 2.0

Implement Transformer-M from GPS++ #172

Closed · DomInvivo closed this 1 year ago

DomInvivo commented 1 year ago

https://arxiv.org/abs/2210.01765

Only the 3D input masking part of the paper is needed.
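
A minimal sketch of what the 3D input masking could look like, assuming the 2D and 3D information enter the attention as per-head bias terms (function and argument names are hypothetical, not Graphium API, and the masking probabilities are placeholder values, not the paper's):

```python
import torch


def combine_biases_with_mode_masking(
    bias_2d: torch.Tensor,     # [batch, heads, n_atoms, n_atoms], bias from the 2D graph
    bias_3d: torch.Tensor,     # [batch, heads, n_atoms, n_atoms], bias from 3D distances
    p_2d_only: float = 0.25,   # prob. of dropping the 3D channel (placeholder value)
    p_3d_only: float = 0.25,   # prob. of dropping the 2D channel (placeholder value)
    training: bool = True,
) -> torch.Tensor:
    """Combine 2D and 3D attention biases with Transformer-M style random channel masking."""
    if not training:
        # At inference, keep every channel that is available (here: both).
        return bias_2d + bias_3d

    batch = bias_2d.shape[0]
    u = torch.rand(batch, device=bias_2d.device)
    # Disjoint intervals of u decide the mode of each molecule in the batch:
    #   u < p_2d_only                          -> 2D only (3D channel masked)
    #   p_2d_only <= u < p_2d_only + p_3d_only -> 3D only (2D channel masked)
    #   otherwise                              -> both channels kept
    drop_3d = u < p_2d_only
    drop_2d = (u >= p_2d_only) & (u < p_2d_only + p_3d_only)
    keep_2d = (~drop_2d).float().view(-1, 1, 1, 1)
    keep_3d = (~drop_3d).float().view(-1, 1, 1, 1)
    return keep_2d * bias_2d + keep_3d * bias_3d
```

Masking per molecule (rather than per batch) is what lets the trained model still make predictions when only the 2D graph is available at inference time.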

DomInvivo commented 1 year ago

To make it work, implement the part highlighted in yellow in the attached image.

Change the ATTENTION_LAYERS_DICT in the file goli.nn.pyg_layers.gps_pyg to add a transformer-M option.
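
A hedged sketch of the registration step (not the actual contents of gps_pyg.py; the existing keys/values and the new class name are assumptions, and the placeholder classes stand in for the real ones defined in goli):

```python
# ATTENTION_LAYERS_DICT maps the attention option chosen in the config to the class
# instantiated by the GPS layer.

class MultiheadAttentionMup:                 # placeholder for the real muP attention class in goli
    ...

class TransformerMAttentionMup(MultiheadAttentionMup):
    """Hypothetical Transformer-M attention (3D bias + input masking) built on MultiheadAttentionMup."""
    ...

ATTENTION_LAYERS_DICT = {
    "full-attention": MultiheadAttentionMup,    # existing option (exact key assumed)
    "transformer-M": TransformerMAttentionMup,  # new option requested in this issue
}
```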

IMPORTANT: Make sure the implementation uses MultiheadAttentionMup, not MultiheadAttention, so that muTransfer works as expected.
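
For context, the muP variant matters because muTransfer requires the attention logits to be scaled by 1/d_head rather than the usual 1/sqrt(d_head), which keeps activation scales stable as width grows so hyperparameters tuned on a small proxy model transfer to the full-size model. A minimal illustration of the difference (generic muP recipe, not Graphium's exact implementation):

```python
import torch


def attention_weights(q: torch.Tensor, k: torch.Tensor, mup: bool = True) -> torch.Tensor:
    """q, k: [batch, heads, n, d_head]; returns attention weights of shape [batch, heads, n, n]."""
    d_head = q.shape[-1]
    # muP scales the logits by 1/d_head; standard attention uses 1/sqrt(d_head).
    scale = 1.0 / d_head if mup else 1.0 / (d_head ** 0.5)
    return torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1)
```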