choderalab / modelforge

Infrastructure to implement and train NNPs
https://modelforge.readthedocs.io/en/latest/
MIT License

infrastructure work #34

Closed wiederm closed 8 months ago

wiederm commented 8 months ago

This PR introduces the new API for NNPs that @chrisiacovella and I discussed with @jchodera.

The updated version contains an AbstractBaseNNP that mimics the previous BaseNNP class. It provides forward, an abstract method whose type hints and docstring document the preferred input signature (Dict[str, torch.Tensor]), and input_checks, which currently verifies that the two required keys, atomic_numbers and positions, are present and that the tensors have the expected dimensionality.
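A minimal sketch of what such an input check could look like, written as a free function for brevity. The two required keys come from the description above; the concrete dimensionality rules (atomic_numbers 2-D, positions of shape (n_atoms, 3)) are assumptions read off the forward docstring shapes, not the checks actually implemented in this PR.

```python
from typing import Dict

import torch


def input_checks(inputs: Dict[str, torch.Tensor]) -> None:
    """Hypothetical sketch: raise if the NNP input dictionary is malformed."""
    # The two keys that must always be present.
    for key in ("atomic_numbers", "positions"):
        if key not in inputs:
            raise ValueError(f"Missing required input: '{key}'")
        if not isinstance(inputs[key], torch.Tensor):
            raise TypeError(f"Input '{key}' must be a torch.Tensor")

    # Assumed dimensionality rules, read off the forward docstring.
    atomic_numbers = inputs["atomic_numbers"]
    if atomic_numbers.dim() != 2:
        raise ValueError(
            "'atomic_numbers' must be 2-D (nr_systems, nr_atoms), "
            f"got {tuple(atomic_numbers.shape)}"
        )
    positions = inputs["positions"]
    if positions.dim() != 2 or positions.shape[1] != 3:
        raise ValueError(
            f"'positions' must have shape (n_atoms, 3), got {tuple(positions.shape)}"
        )
```

Raising early with the offending shape in the message makes a malformed batch fail at the model boundary instead of deep inside the interaction layers.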

The former BaseNNP is now a more specialized NNP implementation. Its forward method computes the neighbor pairs and then calls an NNP-specific _forward method, which is responsible for the data representation, the atomic interactions, and the calculation of the atom-wise contributions and their accumulation into the global prediction.
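The split between forward and _forward follows the template-method pattern: the base class owns the neighbor-pair computation and delegates the model-specific part to the subclass. The sketch below assumes a brute-force O(n^2) all-pairs search with a distance cutoff on a single, unbatched system; PairNNPSketch, ToyPairEnergy and the cutoff parameter are hypothetical names used only for illustration, not part of modelforge.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict

import torch
import torch.nn as nn


class PairNNPSketch(nn.Module, ABC):
    """Hypothetical base: forward builds neighbor pairs, _forward is model-specific."""

    def __init__(self, cutoff: float = 5.0):
        super().__init__()
        self.cutoff = cutoff  # assumed distance cutoff, same units as positions

    def forward(self, inputs: Dict[str, Any]) -> torch.Tensor:
        positions = inputs["positions"]  # (n_atoms, 3), single system assumed
        n_atoms = positions.shape[0]
        # Brute-force enumeration of all unique pairs i < j.
        i, j = torch.triu_indices(n_atoms, n_atoms, offset=1, device=positions.device)
        d_ij = positions[j] - positions[i]  # (n_pairs, 3)
        r_ij = d_ij.norm(dim=-1, keepdim=True)  # (n_pairs, 1)
        # Keep only pairs closer than the cutoff.
        keep = r_ij.squeeze(-1) < self.cutoff
        pairlist = {
            "pairlist": torch.stack([i, j], dim=1)[keep],  # (n_pairs, 2)
            "r_ij": r_ij[keep],
            "d_ij": d_ij[keep],
        }
        # Delegate representation, interaction and readout to the subclass.
        return self._forward({**inputs, "pairlist": pairlist})

    @abstractmethod
    def _forward(self, inputs: Dict[str, Any]) -> torch.Tensor:
        """Model-specific part; receives the inputs plus the 'pairlist' dict."""


class ToyPairEnergy(PairNNPSketch):
    """Toy model: energy is the sum of exp(-r_ij) over all retained pairs."""

    def _forward(self, inputs: Dict[str, Any]) -> torch.Tensor:
        r_ij = inputs["pairlist"]["r_ij"]
        return torch.exp(-r_ij).sum()


positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 2.0, 0.0]])
inputs = {"atomic_numbers": torch.tensor([[1, 1, 1]]), "positions": positions}
energy = ToyPairEnergy(cutoff=5.0)(inputs)
```

Because _forward stays abstract, instantiating PairNNPSketch directly raises a TypeError; a concrete model only implements _forward and receives ready-made r_ij and d_ij tensors.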

The input signature for the AbstractBaseNNP (and the NNPs derived from it):

from abc import ABC, abstractmethod
from typing import Dict

import torch
import torch.nn as nn

# Abstract Base Class for every NNP
class AbstractBaseNNP(nn.Module, ABC):
    """
    Abstract base class for neural network potentials.

    Methods
    -------
    forward(inputs: Dict[str, torch.Tensor]) -> torch.Tensor
        Abstract method for neighbor list calculation and forward pass in neural network potentials.
    _forward(inputs: Dict[str, torch.Tensor]) -> torch.Tensor
        Abstract method for the NNP-specific forward pass.
    input_checks(inputs: Dict[str, torch.Tensor])
        Perform input checks to validate the input dictionary.
    """

    def __init__(self):
        """
        Initialize the AbstractBaseNNP class.
        """
        super().__init__()

    @abstractmethod
    def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Abstract method for neighbor list calculation and forward pass in neural network potentials.

        Parameters
        ----------
        inputs : Dict[str, torch.Tensor]
            - 'atomic_numbers', shape (nr_systems, nr_atoms), 0 indicates non-interacting atoms that will be masked
            - 'total_charge', shape (nr_systems, 1)
            - 'positions', shape (n_atoms, 3)
            - 'boxvectors', shape (3, 3)

        Returns
        -------
        torch.Tensor
            Calculated output; shape is implementation-dependent.

        """
        pass

    @abstractmethod
    def _forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Abstract method for forward pass in neural network potentials.
        This method is called by `forward`.

        Parameters
        ----------
        inputs : Dict[str, torch.Tensor]
            - 'pairlist': Dict[str, torch.Tensor], contains:
                - pairlist, shape (n_pairs, 2)
                - r_ij, shape (n_pairs, 1)
                - d_ij, shape (n_pairs, 3)
                - 'atomic_subsystem_index' (optional), shape n_atoms

        """
        pass
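
The optional atomic_subsystem_index suggests that several systems can be concatenated along the atom axis. Assuming it holds, for each atom, the index of the system that atom belongs to, neighbor pairs must never connect atoms from different systems. A hypothetical helper (subsystem_pairlist is not part of this PR) that builds the pairlist dictionary with the documented shapes could look like this:

```python
from typing import Dict

import torch


def subsystem_pairlist(
    positions: torch.Tensor,
    atomic_subsystem_index: torch.Tensor,
    cutoff: float,
) -> Dict[str, torch.Tensor]:
    """Hypothetical helper: neighbor pairs restricted to atoms of the same system.

    positions: (n_atoms, 3); atomic_subsystem_index: (n_atoms,), system id per atom.
    """
    n_atoms = positions.shape[0]
    # Candidate pairs i < j over the whole concatenated batch.
    i, j = torch.triu_indices(n_atoms, n_atoms, offset=1, device=positions.device)
    # Discard pairs that would couple two different systems.
    same_system = atomic_subsystem_index[i] == atomic_subsystem_index[j]
    i, j = i[same_system], j[same_system]
    d_ij = positions[j] - positions[i]  # (n_pairs, 3)
    r_ij = d_ij.norm(dim=-1, keepdim=True)  # (n_pairs, 1)
    # Apply the distance cutoff.
    within = r_ij.squeeze(-1) < cutoff
    return {
        "pairlist": torch.stack([i[within], j[within]], dim=1),  # (n_pairs, 2)
        "r_ij": r_ij[within],
        "d_ij": d_ij[within],
    }


# Two systems of two atoms each, concatenated along the atom axis.
positions = torch.tensor(
    [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.5, 0.0, 0.0], [0.5, 1.0, 0.0]]
)
atomic_subsystem_index = torch.tensor([0, 0, 1, 1])
pairs = subsystem_pairlist(positions, atomic_subsystem_index, cutoff=5.0)
```

Atom 2 sits only 0.5 away from atom 0, but because it belongs to the other system the pair (0, 2) never appears in the pairlist. The O(n^2) enumeration is chosen for clarity; large batches would typically need a cell list or a similar neighbor-search structure.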

codecov-commenter commented 8 months ago

Codecov Report

Merging #34 (62195e0) into main (51b1322) will increase coverage by 4.28%. The diff coverage is 92.30%.
