JustinBakerMath closed 1 year ago
@pzhanggit @jychoi-hpc I would like to give you some context for this PR.
Justin Baker (@JustinBakerMath) is a graduate student at the University of Utah. He works with me and Cory Hauck on an ASCR project. One of the goals of the ASCR project is to develop message passing layers for graph neural networks. My role in the project is to make sure that the message passing layers developed in it are included in and supported by HydraGNN. :-)
Different message passing layers require different inputs. For instance, all the layers we currently support in HydraGNN have a signature of this type:
`conv(data.x, data.edge_index)`
However, there are also other types of convolutional layers (e.g., SchNet) with a different signature. For instance, the signature of the `conv` method in SchNet is
`conv(z, data.pos)`
where `z` contains the atomic numbers and `data.pos` contains the XYZ coordinates of the atoms.
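To make the mismatch concrete, here is a minimal pure-Python sketch. `gcn_like_conv` and `schnet_like_conv` are hypothetical stand-ins, not the PyTorch Geometric layers; `edge_index` is simplified to a list of (source, target) pairs rather than a 2×E tensor, and storing the atomic numbers in `data.z` is an assumption made for illustration. A single hard-coded call such as `conv(data.x, data.edge_index)` cannot serve both layers.

```python
import math
from types import SimpleNamespace


def gcn_like_conv(x, edge_index):
    """Hypothetical stand-in for conv(data.x, data.edge_index)."""
    # Each target node sums the features of its source neighbors.
    out = [0.0] * len(x)
    for src, dst in edge_index:
        out[dst] += x[src]
    return out


def schnet_like_conv(z, pos):
    """Hypothetical stand-in for conv(z, data.pos): no edge_index needed."""
    # Toy geometric interaction: atomic number divided by
    # (1 + distance of the atom from the origin).
    return [zi / (1.0 + math.sqrt(sum(c * c for c in p))) for zi, p in zip(z, pos)]


data = SimpleNamespace(
    x=[1.0, 2.0, 3.0],
    edge_index=[(0, 1), (1, 2), (2, 0)],
    z=[1, 6, 8],  # atomic numbers (H, C, O); the attribute name is an assumption
    pos=[(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 2.0, 0.0)],
)

print(gcn_like_conv(data.x, data.edge_index))  # [3.0, 1.0, 2.0]
print(schnet_like_conv(data.z, data.pos))
```

The two layers read disjoint attributes of the same data object, which is why the call site in the model has to become configurable.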
This PR aims to generalize the interface of the `Base` class inside HydraGNN to support a broader class of message passing layers.
Separate the handling of the convolutional arguments generated from the data object. The function `_conv_args` preserves the functionality of `use_edge_attr`. Furthermore, through inheritance, each convolutional stack can extract `data.pos` as needed (e.g., for SchNet).
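A minimal sketch of the design described above, assuming a simplified `Base` class rather than HydraGNN's actual implementation: `forward` asks `_conv_args` which arguments to pass, the default includes `edge_attr` only when `use_edge_attr` is set, and a hypothetical `SchNetStack` subclass overrides only `_conv_args` to pass `data.pos`. The layers `edge_conv` and `pos_conv` are toy, hypothetical stand-ins operating on plain lists.

```python
import math
from types import SimpleNamespace


class Base:
    """Hypothetical, simplified stand-in for HydraGNN's Base class."""

    def __init__(self, convs, use_edge_attr=False):
        self.convs = convs
        self.use_edge_attr = use_edge_attr

    def _conv_args(self, data):
        # Default: node features plus connectivity; edge attributes are
        # passed only when the model is configured with use_edge_attr.
        kwargs = {"edge_index": data.edge_index}
        if self.use_edge_attr:
            kwargs["edge_attr"] = data.edge_attr
        return data.x, kwargs

    def forward(self, data):
        # The loop is signature-agnostic: it forwards whatever
        # _conv_args extracted from the data object.
        x, conv_kwargs = self._conv_args(data)
        for conv in self.convs:
            x = conv(x, **conv_kwargs)
        return x


class SchNetStack(Base):
    """Hypothetical subclass that only changes which arguments are extracted."""

    def _conv_args(self, data):
        return data.x, {"pos": data.pos}


# Toy layers (hypothetical), one per signature family.
def edge_conv(x, edge_index, edge_attr=None):
    # Each target node adds its (optionally weighted) source neighbors.
    out = list(x)
    for k, (src, dst) in enumerate(edge_index):
        weight = 1.0 if edge_attr is None else edge_attr[k]
        out[dst] += weight * x[src]
    return out


def pos_conv(x, pos):
    # Scales each feature by 1 / (1 + distance of the atom from the origin).
    return [xi / (1.0 + math.sqrt(sum(c * c for c in p))) for xi, p in zip(x, pos)]


data = SimpleNamespace(
    x=[1.0, 2.0, 3.0],
    edge_index=[(0, 1), (1, 2)],
    edge_attr=[0.5, 2.0],
    pos=[(0.0, 0.0, 0.0), (3.0, 4.0, 0.0), (0.0, 0.0, 1.0)],
)

print(Base([edge_conv]).forward(data))                      # [1.0, 3.0, 5.0]
print(Base([edge_conv], use_edge_attr=True).forward(data))  # [1.0, 2.5, 7.0]
print(SchNetStack([pos_conv]).forward(data))
```

Because `forward` never names layer-specific inputs, supporting another layer family in this sketch only requires overriding `_conv_args` in a subclass.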