This PR cleans up the interface of the `BaseNNP` class, removes the abstract base class, and restructures the forward call as follows:
```python
class BaseNNP:
    def forward(self, inputs):
        # adjust the dtype of the input tensors to match the model parameters
        self._set_dtype()
        # perform input checks (implemented in BaseNNP)
        inputs = self._input_checks(inputs)
        # prepare the input for the forward pass (implemented in the NNP subclasses)
        inputs = self.prepare_inputs(inputs)
        # perform the forward pass (implemented in the NNP subclasses)
        output = self._forward(inputs)
        # compute the final output (self._readout is implemented in the NNP subclasses)
        return self._readout(output)
```
This makes further development much easier, because `BaseNNP` makes no assumptions about the input data structure or the readout function (or target).
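The following is a minimal sketch of how a subclass might plug into the forward structure above. It uses plain Python lists instead of tensors. To make the example run on its own, it restates `BaseNNP` in simplified form. `ToyNNP`, the dictionary input format, and the bodies of `_set_dtype` and `_input_checks` are hypothetical. They only illustrate how the work is split between the base class and its subclasses.

```python
from typing import Any, Dict, List


class BaseNNP:
    """Simplified stand-in: fixes only the order of the forward steps."""

    def forward(self, inputs: Dict[str, Any]) -> Any:
        self._set_dtype()
        inputs = self._input_checks(inputs)
        inputs = self.prepare_inputs(inputs)
        output = self._forward(inputs)
        return self._readout(output)

    def _set_dtype(self) -> None:
        # Hypothetical placeholder: a real model would match input dtypes to its parameters.
        pass

    def _input_checks(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        # Hypothetical generic check shared by all subclasses.
        if not isinstance(inputs, dict):
            raise TypeError("inputs must be a dict")
        return inputs

    # The three hooks below have no default; each subclass must provide them.
    def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        raise NotImplementedError

    def _forward(self, inputs: Dict[str, Any]) -> Any:
        raise NotImplementedError

    def _readout(self, output: Any) -> Any:
        raise NotImplementedError


class ToyNNP(BaseNNP):
    """Hypothetical subclass: per-atom value is a scaled atomic number; readout sums them."""

    def __init__(self, scale: float = 0.5) -> None:
        self.scale = scale

    def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        # The subclass chooses its own input layout: here, a list of floats.
        return {"z": [float(z) for z in inputs["atomic_numbers"]]}

    def _forward(self, inputs: Dict[str, Any]) -> List[float]:
        return [self.scale * z for z in inputs["z"]]

    def _readout(self, output: List[float]) -> float:
        # The subclass also chooses the readout: here, a sum over atoms.
        return sum(output)


model = ToyNNP()
print(model.forward({"atomic_numbers": [1, 1, 8]}))  # 0.5 * (1 + 1 + 8) = 5.0
```

In this sketch, the base class owns only the order of the calls. Input layout, model body, and readout are left entirely to the subclass.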
Further improvements include:
- Renamed ambiguous variables: every differentiable module now has a `_module` suffix. Previously, it was unclear whether the `cutoff` variable referred to the cutoff module that implements a cutoff smoothing function or to the cutoff radius. The cutoff function is now named `cutoff_module`.
- Made feature naming consistent: previously we used `n_features` in some places and `nr_features` in others. The `nr_` prefix is now used everywhere.
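The following is a minimal sketch of the naming convention. The cutoff radius is a plain number called `cutoff`, and the smoothing function lives in an attribute whose name ends in `_module`. The cosine cutoff shape, the `CosineCutoff` class, and the `Representation` class are hypothetical choices for illustration, not the PR's actual implementation. Plain Python stands in for a differentiable module.

```python
import math
from typing import List


class CosineCutoff:
    """Hypothetical smoothing function: 0.5 * (cos(pi * r / r_c) + 1) for r < r_c, else 0."""

    def __init__(self, cutoff: float) -> None:
        self.cutoff = cutoff  # cutoff radius: a number, not a module

    def __call__(self, distances: List[float]) -> List[float]:
        return [
            0.5 * (math.cos(math.pi * d / self.cutoff) + 1.0) if d < self.cutoff else 0.0
            for d in distances
        ]


class Representation:
    """Hypothetical layer showing the naming convention."""

    def __init__(self, cutoff: float, nr_features: int) -> None:
        self.cutoff = cutoff                       # radius: plain value, no suffix
        self.nr_features = nr_features             # feature count: always the `nr_` prefix
        self.cutoff_module = CosineCutoff(cutoff)  # smoothing function: `_module` suffix


rep = Representation(cutoff=5.0, nr_features=32)
print(rep.cutoff_module([0.0, 2.5, 5.0, 6.0]))
```

With this convention, `rep.cutoff` can only be the radius and `rep.cutoff_module` can only be the callable.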