ORNL / HydraGNN

Distributed PyTorch implementation of multi-headed graph convolutional neural networks
BSD 3-Clause "New" or "Revised" License

Separate handling of conv args #170

Closed. JustinBakerMath closed this 1 year ago.

JustinBakerMath commented 1 year ago

Separate the handling of the convolutional arguments generated by the data object. The function _conv_args preserves the existing use_edge_attr behavior.

Furthermore, because the convolutional stacks inherit this function, each stack can extract data.pos when it needs it, e.g., for SchNet.
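The pattern described above can be sketched in plain Python. This is a minimal illustration, not HydraGNN's actual code. Only the names _conv_args and use_edge_attr come from this PR. The classes BaseStack and SchNetStack, the choice to return a dictionary of keyword arguments, and the SimpleNamespace stand-in for a PyTorch Geometric Data object are all hypothetical.

```python
from types import SimpleNamespace


class BaseStack:
    """Hypothetical base class that builds the extra inputs for each conv layer."""

    def __init__(self, use_edge_attr=False):
        self.use_edge_attr = use_edge_attr

    def _conv_args(self, data):
        # Default inputs for layers of the form conv(x, edge_index[, edge_attr])
        conv_args = {"edge_index": data.edge_index}
        if self.use_edge_attr:
            # Forward edge attributes only when the model is configured to use them
            conv_args["edge_attr"] = data.edge_attr
        return conv_args


class SchNetStack(BaseStack):
    """Hypothetical subclass whose layers need atomic positions instead."""

    def _conv_args(self, data):
        # Overriding the inherited hook lets this stack pull data.pos
        return {"pos": data.pos}


# Stand-in for a graph with two atoms (assumption: not a real PyG Data object)
data = SimpleNamespace(
    x=[[1.0], [6.0]],
    edge_index=[[0, 1], [1, 0]],
    edge_attr=[[0.5], [0.5]],
    pos=[[0.0, 0.0, 0.0], [1.1, 0.0, 0.0]],
)

print(sorted(BaseStack()._conv_args(data)))                    # ['edge_index']
print(sorted(BaseStack(use_edge_attr=True)._conv_args(data)))  # ['edge_attr', 'edge_index']
print(sorted(SchNetStack()._conv_args(data)))                  # ['pos']
```

The design choice here is that the caller never needs to know which attributes a layer consumes: the stack class owns that knowledge, and a subclass changes it by overriding a single method.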

allaffa commented 1 year ago

@pzhanggit @jychoi-hpc I would like to give you some context for this PR.

Justin Baker @JustinBakerMath is a graduate student at the University of Utah. He works with me and Cory Hauck on an ASCR project. One of the goals of the ASCR project is to develop message passing layers for graph neural networks. My contribution to the project is to make sure that the message passing layers developed in the project are included and supported inside HydraGNN. :-)

Different message passing layers require different inputs. All the layers we currently support in HydraGNN have a signature of this type: conv(data.x, data.edge_index).

However, other types of convolutional layers (e.g., SchNet) have a different signature. For instance, the conv method in SchNet has the signature conv(z, data.pos), where z holds the atomic numbers and data.pos holds the XYZ coordinates of the atoms.

This PR aims to generalize the interface of the Base class in HydraGNN so that it supports a broader class of message passing layers.
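To show why a per-stack set of conv arguments generalizes the interface, here is a hypothetical sketch of a forward loop that calls every layer the same way, conv(x, **conv_args), whether the layer expects edge_index or pos. The toy functions edge_conv and position_conv stand in for real PyTorch Geometric layers, and forward is not HydraGNN's method; all three names are assumptions made for illustration.

```python
from types import SimpleNamespace


def edge_conv(x, edge_index):
    # Toy stand-in for conv(data.x, data.edge_index):
    # each node receives the sum of its incoming neighbors' features
    out = [0.0] * len(x)
    for src, dst in zip(edge_index[0], edge_index[1]):
        out[dst] += x[src]
    return out


def position_conv(z, pos):
    # Toy stand-in for conv(z, data.pos):
    # scale each node's value by its distance from the origin
    return [zi * sum(c * c for c in p) ** 0.5 for zi, p in zip(z, pos)]


def forward(layers, x, conv_args):
    # Every layer is called identically; conv_args decides which extra inputs it gets
    for conv in layers:
        x = conv(x, **conv_args)
    return x


data = SimpleNamespace(
    x=[1.0, 2.0, 3.0],
    edge_index=[[0, 1, 2], [1, 2, 0]],
    pos=[[0.0, 0.0, 0.0], [3.0, 4.0, 0.0], [0.0, 0.0, 2.0]],
)

print(forward([edge_conv], data.x, {"edge_index": data.edge_index}))  # [3.0, 1.0, 2.0]
print(forward([position_conv], data.x, {"pos": data.pos}))            # [0.0, 10.0, 6.0]
```

The loop itself never changes; supporting a new family of layers only requires supplying a different conv_args dictionary.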