ORNL / HydraGNN

Distributed PyTorch implementation of multi-headed graph convolutional neural networks
BSD 3-Clause "New" or "Revised" License

Implementation of DimeNet++ #184

Closed: JustinBakerMath closed this pull request 1 year ago.

JustinBakerMath commented 1 year ago

Introduces the key components of the DimeNet++ model to the HydraGNN library by modifying the PyG implementation.

The core addition is DIMEStack, which inherits from Base and provides convolutional layers on request through the _get_conv() method.

Notice: Because the PyG implementation uses Glorot initialization, the hidden dimension must be larger than 1 to avoid a division by zero. Only input_dim and output_dim are passed to _get_conv(), and nothing constrains the dimension of the convolution. DIMEStack therefore uses input_dim as the hidden dimension if it is larger than 1, and output_dim otherwise. If neither is larger than 1, an assertion fails.
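The fallback rule for the hidden dimension can be sketched as a small standalone function. This is a hypothetical helper (`select_hidden_dim` is not a HydraGNN name); it only restates the rule described above and does not reproduce the actual _get_conv() code.

```python
def select_hidden_dim(input_dim: int, output_dim: int) -> int:
    """Pick a hidden dimension for a DimeNet++ convolution (hypothetical helper).

    Prefer input_dim when it is larger than 1, otherwise fall back to
    output_dim, and fail loudly if neither is larger than 1 (Glorot-style
    initialization of a size-1 hidden layer would divide by zero).
    """
    hidden_dim = input_dim if input_dim > 1 else output_dim
    assert hidden_dim > 1, (
        "DimeNet++ layers need a hidden dimension > 1 "
        f"(got input_dim={input_dim}, output_dim={output_dim})"
    )
    return hidden_dim


print(select_hidden_dim(64, 1))  # input_dim is usable, so 64 is chosen
print(select_hidden_dim(1, 32))  # input_dim is 1, so it falls back to 32
```

Checking input_dim first keeps the hidden width tied to the incoming features whenever possible; output_dim is only a fallback for layers whose input is one-dimensional.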

DIMEStack also computes the arguments the convolutions need:
- rbf (tensor): the output of the radial basis function.
- sbf (tensor): the output of the spherical basis function.
- i: the message-passing target nodes.
- j: the message-passing source nodes.
- idx_kj: for each bond triplet k→j→i, the index of the k→j edge.
- idx_ji: for each bond triplet k→j→i, the index of the j→i edge.
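To illustrate what rbf encodes, here is a plain-Python sketch of the radial Bessel basis from the DimeNet paper, e_n(d) = sqrt(2/c) · sin(nπd/c) / d for n = 1..num_radial, where c is the cutoff radius. This is an assumption-labeled illustration rather than the layer used in this PR: it works on a single distance instead of a tensor and omits the smooth cutoff envelope that the PyG layer applies.

```python
import math
from typing import List


def bessel_rbf(d: float, num_radial: int, cutoff: float) -> List[float]:
    """Radial Bessel basis (DimeNet paper form), without the cutoff envelope.

    Returns [e_1(d), ..., e_num_radial(d)] with
    e_n(d) = sqrt(2 / cutoff) * sin(n * pi * d / cutoff) / d.
    """
    if not 0.0 < d <= cutoff:
        raise ValueError("distance must lie in (0, cutoff]")
    prefactor = math.sqrt(2.0 / cutoff)
    return [
        prefactor * math.sin(n * math.pi * d / cutoff) / d
        for n in range(1, num_radial + 1)
    ]


# One interatomic distance expanded into three radial features.
print(bessel_rbf(1.0, num_radial=3, cutoff=2.0))
```

Each distance becomes a fixed-length feature vector, which is why rbf is passed to the convolutions as a tensor with one column per basis function.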

Notice: DimeNet uses the triplet() method to compute the bond angles between triplets of nodes. These triplets are recomputed on every forward call; computing them once and updating them incrementally leaves room for a speedup.
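The triplet enumeration and the resulting bond angle can be sketched in plain Python. The function names `triplets` and `bond_angle` and the (source, target) edge-tuple layout are assumptions for illustration; the PyG/HydraGNN code works on edge-index tensors instead. The sketch returns (idx_kj, idx_ji) pairs for every path k→j→i with k ≠ i, then measures the angle at the shared atom j with atan2, which stays numerically stable near 0 and π.

```python
import math
from typing import Dict, List, Sequence, Tuple

Edge = Tuple[int, int]  # directed edge (source, target), e.g. (j, i) for j -> i
Vec = Tuple[float, float, float]


def triplets(edges: Sequence[Edge]) -> List[Tuple[int, int]]:
    """Return (idx_kj, idx_ji) edge-index pairs for every path k -> j -> i with k != i."""
    incoming: Dict[int, List[int]] = {}  # node -> indices of edges ending at that node
    for e, (_, dst) in enumerate(edges):
        incoming.setdefault(dst, []).append(e)
    pairs = []
    for idx_ji, (j, i) in enumerate(edges):
        for idx_kj in incoming.get(j, []):
            k = edges[idx_kj][0]
            if k != i:  # drop backtracking paths i -> j -> i
                pairs.append((idx_kj, idx_ji))
    return pairs


def bond_angle(pos: Sequence[Vec], edges: Sequence[Edge], idx_kj: int, idx_ji: int) -> float:
    """Angle at node j between the bonds j-i and j-k, in radians."""
    k, j = edges[idx_kj]
    _, i = edges[idx_ji]
    a = [pos[i][d] - pos[j][d] for d in range(3)]  # vector from j to i
    b = [pos[k][d] - pos[j][d] for d in range(3)]  # vector from j to k
    dot = sum(x * y for x, y in zip(a, b))
    cross = (a[1] * b[2] - a[2] * b[1],
             a[2] * b[0] - a[0] * b[2],
             a[0] * b[1] - a[1] * b[0])
    return math.atan2(math.sqrt(sum(c * c for c in cross)), dot)


# Three atoms bonded 0-1-2, with both edge directions stored.
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]
pos = [(1.0, 0.0, 0.0), (0.0, 0.0, 0.0), (0.0, 1.0, 0.0)]
for idx_kj, idx_ji in triplets(edges):
    print(idx_kj, idx_ji, math.degrees(bond_angle(pos, edges, idx_kj, idx_ji)))
```

The triplet list depends only on the edge set, not on positions, so when the graph connectivity is fixed it could be cached and only the angles recomputed, which is one way to read the speedup suggested above.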

The stack also modifies the Embedding block: the baseline atomic-feature embedding is removed, and node embedding is handled by HydraGNN instead.

Eight hyperparameters are introduced.

Other updates include documentation and testing.