choderalab / modelforge

Infrastructure to implement and train NNPs
https://modelforge.readthedocs.io/en/latest/
MIT License

Refactoring #52

Closed: wiederm closed this 5 months ago

wiederm commented 5 months ago

Description

This PR cleans up the interface of the BaseNNP class, removes the abstract base class, and restructures the forward call as follows:

class BaseNNP:

    def forward(self, inputs):
        # adjust the dtype of the input tensors to match the model parameters
        self._set_dtype()
        # perform input checks (implemented in BaseNNP)
        inputs = self._input_checks(inputs)
        # prepare the input for the forward pass (implemented in the NNP subclasses)
        inputs = self.prepare_inputs(inputs)
        # perform the forward pass (implemented in the NNP subclasses)
        output = self._forward(inputs)
        # compute the target from the output (self._readout is implemented in the NNP subclasses)
        return self._readout(output)

This makes further development much easier, because the base class makes no assumptions about either the structure of the input data or the readout function (i.e., the target property).
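The forward() structure in the description follows the template-method pattern: the base class fixes the order of the steps, and each subclass fills in the hooks. Below is a minimal, framework-free sketch of that idea, assuming plain Python dicts and lists in place of tensors. The subclasses `ToyPerAtomNNP` and `ToyTotalEnergyNNP`, the `atomic_numbers` key, and the toy per-atom rule are hypothetical illustrations, not part of modelforge.

```python
from typing import Any, Dict, List


class BaseNNP:
    """Hypothetical sketch of the forward() structure; not modelforge's code."""

    def forward(self, inputs: Dict[str, Any]) -> Any:
        self._set_dtype()                      # shared step (base class)
        inputs = self._input_checks(inputs)    # shared step (base class)
        inputs = self.prepare_inputs(inputs)   # subclass hook
        output = self._forward(inputs)         # subclass hook
        return self._readout(output)           # subclass hook

    def _set_dtype(self) -> None:
        # Placeholder: a real model would cast inputs to its parameter dtype.
        pass

    def _input_checks(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        # Deliberately generic: only the container type is checked, not its
        # keys, so the base class assumes nothing about the input layout.
        if not isinstance(inputs, dict):
            raise TypeError("inputs must be a dict")
        return inputs

    def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        raise NotImplementedError

    def _forward(self, inputs: Dict[str, Any]) -> Any:
        raise NotImplementedError

    def _readout(self, output: Any) -> Any:
        raise NotImplementedError


class ToyPerAtomNNP(BaseNNP):
    """Hypothetical model: per-atom value = 0.5 * atomic number."""

    def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        # Each subclass chooses its own internal representation.
        return {"Z": list(inputs["atomic_numbers"])}

    def _forward(self, inputs: Dict[str, Any]) -> List[float]:
        return [0.5 * z for z in inputs["Z"]]

    def _readout(self, per_atom: List[float]) -> List[float]:
        # Target: the per-atom values, returned unchanged.
        return per_atom


class ToyTotalEnergyNNP(ToyPerAtomNNP):
    """Same preparation and forward pass, different readout: a summed total."""

    def _readout(self, per_atom: List[float]) -> float:
        return sum(per_atom)


water = {"atomic_numbers": [8, 1, 1]}
print(ToyPerAtomNNP().forward(water))      # [4.0, 0.5, 0.5]
print(ToyTotalEnergyNNP().forward(water))  # 5.0
```

Only `_readout` differs between the two toy models, so changing the target leaves the base class untouched, which is the kind of flexibility the refactoring aims for.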

Further improvements include:

Status