This PR introduces the new API for NNPs that @chrisiacovella and I discussed with @jchodera.
The updated version contains an `AbstractBaseNNP` that mimics the previous `BaseNNP` class. It has two methods:

- `forward`, implemented as an abstract method that provides type hints for the preferred input signature, `Dict[str, torch.Tensor]`, along with documentation; and
- `input_checks`, which currently checks that the two required keys, `atomic_numbers` and `positions`, are present and that the input tensors have the expected dimensionality.
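The checks described for `input_checks` could look roughly like the minimal sketch below. It is not the implementation in this PR: the exact dimensionality rules (2-D `atomic_numbers`, `positions` of shape `(n_atoms, 3)`) are assumptions taken from the `forward` docstring, and raising `ValueError` is a choice made here for illustration.

```python
from typing import Dict

import torch


def input_checks(inputs: Dict[str, torch.Tensor]) -> None:
    """Hypothetical sketch: validate required keys and tensor dimensionality."""
    # the two keys named in this PR as mandatory
    for key in ("atomic_numbers", "positions"):
        if key not in inputs:
            raise ValueError(f"Missing required input key: '{key}'")

    # assumed shapes, taken from the forward() docstring
    atomic_numbers = inputs["atomic_numbers"]
    if atomic_numbers.dim() != 2:
        raise ValueError(
            f"'atomic_numbers' must be 2-D (nr_systems, nr_atoms), got {atomic_numbers.dim()}-D"
        )
    positions = inputs["positions"]
    if positions.dim() != 2 or positions.shape[1] != 3:
        raise ValueError(
            f"'positions' must have shape (n_atoms, 3), got {tuple(positions.shape)}"
        )


# Usage: a single three-atom system passes silently
input_checks({
    "atomic_numbers": torch.tensor([[8, 1, 1]]),
    "positions": torch.zeros(3, 3),
})
```

Failing fast with a descriptive message keeps shape errors out of the neighbor-list and model code, where they would surface as much harder to read indexing errors.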
The former `BaseNNP` is now a more specialized implementation of an NNP. It currently splits off the calculation of the neighbor pairs (now done in its `forward` method) and then calls an NNP-specific `_forward` method, which is responsible for the data representation, the atom interactions, and the calculation of the atom-wise contributions and their accumulation into a global prediction.
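The split between a generic `forward` (neighbor pairs) and a model-specific `_forward` can be sketched as follows. This is a hypothetical illustration, not the actual `BaseNNP`: the class names `PairlistNNPSketch` and `ToyPairNNP`, the `cutoff` parameter, and the brute-force O(n_atoms^2) pair search are assumptions; only the `pairlist`/`r_ij`/`d_ij` keys, their shapes, and the optional `atomic_subsystem_index` follow the `_forward` docstring in this PR.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional

import torch
from torch import nn


class PairlistNNPSketch(nn.Module, ABC):
    """Hypothetical stand-in for BaseNNP: neighbor pairs in forward(), model in _forward()."""

    def __init__(self, cutoff: float = 5.0):
        super().__init__()
        self.cutoff = cutoff  # assumed cutoff, in the same length unit as positions

    def _neighbor_pairs(
        self, positions: torch.Tensor, subsystem_index: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        # brute-force search over all unique pairs i < j
        n_atoms = positions.shape[0]
        idx_i, idx_j = torch.triu_indices(n_atoms, n_atoms, offset=1)
        d_ij = positions[idx_j] - positions[idx_i]  # displacement vectors, (n_pairs, 3)
        r_ij = d_ij.norm(dim=1, keepdim=True)  # distances, (n_pairs, 1)
        keep = r_ij.squeeze(1) < self.cutoff
        if subsystem_index is not None:
            # never pair atoms that belong to different systems
            keep &= subsystem_index[idx_i] == subsystem_index[idx_j]
        return {
            "pairlist": torch.stack([idx_i[keep], idx_j[keep]], dim=1),  # (n_pairs, 2)
            "r_ij": r_ij[keep],
            "d_ij": d_ij[keep],
        }

    def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        # generic part shared by all NNPs: build the neighbor pairs ...
        pairs = self._neighbor_pairs(
            inputs["positions"], inputs.get("atomic_subsystem_index")
        )
        # ... then delegate the model-specific work to _forward
        return self._forward({**inputs, "pairlist": pairs})

    @abstractmethod
    def _forward(self, inputs: Dict[str, Any]) -> torch.Tensor:
        """Representation, interaction and atom-wise accumulation go here."""


class ToyPairNNP(PairlistNNPSketch):
    """Toy model (not physical): sums 1/r over all neighbor pairs."""

    def _forward(self, inputs: Dict[str, Any]) -> torch.Tensor:
        return (1.0 / inputs["pairlist"]["r_ij"]).sum()


model = ToyPairNNP(cutoff=5.0)
inputs = {
    "atomic_numbers": torch.tensor([[1, 1, 1]]),
    "positions": torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [10.0, 0.0, 0.0]]),
}
# only the (0, 1) pair lies inside the cutoff, so the toy output is 1 / 1.0
energy = model(inputs)
```

Keeping the pair search in the base class means each concrete NNP only implements `_forward` and always receives the same pair data; a production version would replace the quadratic search with a proper neighbor list.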
The input signature for the `AbstractBaseNNP` (and the NNPs derived from it) is:
```python
from abc import ABC, abstractmethod
from typing import Dict

import torch
from torch import nn


# Abstract Base Class for every NNP
class AbstractBaseNNP(nn.Module, ABC):
    """
    Abstract base class for neural network potentials.

    Methods
    -------
    forward(inputs: Dict[str, torch.Tensor]) -> torch.Tensor
        Abstract method for neighbor list calculation and forward pass in neural network potentials.
    _forward(inputs: Dict[str, torch.Tensor]) -> torch.Tensor
        Abstract method for the forward pass in an NNP.
    input_checks(inputs: Dict[str, torch.Tensor])
        Perform input checks to validate the input dictionary.
    """

    def __init__(self):
        """
        Initialize the AbstractBaseNNP class.
        """
        super().__init__()

    @abstractmethod
    def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Abstract method for neighbor list calculation and forward pass in neural network potentials.

        Parameters
        ----------
        inputs : Dict[str, torch.Tensor]
            - 'atomic_numbers', shape (nr_systems, nr_atoms), 0 indicates non-interacting atoms that will be masked
            - 'total_charge', shape (nr_systems, 1)
            - 'positions', shape (n_atoms, 3)
            - 'boxvectors', shape (3, 3)

        Returns
        -------
        torch.Tensor
            Calculated output; shape is implementation-dependent.
        """
        pass

    @abstractmethod
    def _forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Abstract method for the forward pass in neural network potentials.
        This method is called by `forward`.

        Parameters
        ----------
        inputs : Dict[str, torch.Tensor]
            - 'pairlist': Dict[str, torch.Tensor], contains:
                - pairlist, shape (n_pairs, 2)
                - r_ij, shape (n_pairs, 1)
                - d_ij, shape (n_pairs, 3)
            - 'atomic_subsystem_index' (optional), shape (n_atoms,)
        """
        pass
```
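Because both `forward` and `_forward` are declared with `@abstractmethod`, a concrete NNP can only be instantiated once it implements both. The sketch below uses a condensed copy of `AbstractBaseNNP` (docstrings and `input_checks` left out) plus two hypothetical subclasses, `ForwardOnlyNNP` and `MinimalNNP`, to show that Python raises `TypeError` otherwise.

```python
from abc import ABC, abstractmethod
from typing import Dict

import torch
from torch import nn


class AbstractBaseNNP(nn.Module, ABC):
    """Condensed copy of AbstractBaseNNP: docstrings and input_checks omitted."""

    @abstractmethod
    def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        pass

    @abstractmethod
    def _forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        pass


class ForwardOnlyNNP(AbstractBaseNNP):
    """Hypothetical subclass that forgets to implement _forward."""

    def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        return inputs["positions"].sum()


class MinimalNNP(ForwardOnlyNNP):
    """Hypothetical subclass that implements both abstract methods."""

    def _forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        return torch.zeros(())


def can_instantiate(cls) -> bool:
    try:
        cls()
    except TypeError:  # raised for classes with unimplemented abstract methods
        return False
    return True


print(can_instantiate(AbstractBaseNNP))  # False
print(can_instantiate(ForwardOnlyNNP))   # False
print(can_instantiate(MinimalNNP))       # True
```

The check happens at construction time, so a model that is missing `_forward` fails immediately rather than partway through a training or inference run.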