ML4OPF currently requires that subclasses of `BasicNeuralNet` define the `model_cls` attribute. For example, the `DCPBasicNeuralNet` has:

https://github.com/AI4OPT/ML4OPF/blob/461269ae66f5307787042a7b4c2b8e9740a022f0/ml4opf/models/basic_nn/dcp_basic_nn.py#L46-L48

We should be able to avoid this repetition by defining the `model_cls` attribute like:

```python
from typing import get_type_hints
from abc import ABC


class BasicNeuralNet(ABC):
    @property
    def model_cls(self):
        # Pass the class rather than the instance: get_type_hints then walks the
        # class MRO and resolves string annotations against each class's module globals.
        return get_type_hints(type(self))["model"]
```

This preserves type hints in the IDE and removes the repetition, though there is no mechanism AFAIK to enforce that a subclass has defined the `model` type hint. It also means `model_cls` is resolved at runtime, so the sanity check on `model_cls` in `load_from_checkpoint` needs to happen after the `BasicNeuralNet` is initialized (but before the checkpoint weights are loaded).

Note that this change makes the `model` type hint required when using the following functionality of `BasicNeuralNet`:

- Automatic key/type checking of `config` in `__init__`
- Creation of the model in `make_training_model`
- Checkpointing with `save_checkpoint` and `load_from_checkpoint`
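A minimal sketch of how the proposal could fit together, assuming a simplified `BasicNeuralNet` whose `__init__` only stores a `config` dict and whose `make_training_model` calls `self.model_cls(**config)`. `MLP`, `MyNet`, and `UntypedNet` are hypothetical names, not ML4OPF classes, and the check in `__init__` is just one possible way to report a missing `model` hint early; it is not existing ML4OPF behavior.

```python
from abc import ABC
from typing import Any, Dict, get_type_hints


class BasicNeuralNet(ABC):
    """Hypothetical, simplified stand-in for ML4OPF's BasicNeuralNet."""

    def __init__(self, config: Dict[str, Any]):
        # This does not check the hint when the subclass is *defined*; it reports
        # a missing `model` hint as soon as the subclass is *instantiated*.
        if "model" not in get_type_hints(type(self)):
            raise TypeError(
                f"{type(self).__name__} must declare a `model: <ModelClass>` type hint"
            )
        self.config = config

    @property
    def model_cls(self) -> type:
        # The `model` annotation is the single source of truth for the model class.
        return get_type_hints(type(self))["model"]

    def make_training_model(self) -> None:
        # Hypothetical: build the model from the config via the resolved class.
        self.model = self.model_cls(**self.config)


class MLP:
    """Hypothetical model class, used only for illustration."""

    def __init__(self, hidden_sizes):
        self.hidden_sizes = hidden_sizes


class MyNet(BasicNeuralNet):
    model: MLP  # IDEs infer self.model as MLP; model_cls resolves to MLP


class UntypedNet(BasicNeuralNet):
    pass  # forgot the `model` type hint


net = MyNet({"hidden_sizes": [64, 64]})
net.make_training_model()
print(net.model_cls.__name__)  # MLP
print(net.model.hidden_sizes)  # [64, 64]

try:
    UntypedNet({})
except TypeError as err:
    print(err)  # UntypedNet must declare a `model: <ModelClass>` type hint
```

Because `get_type_hints` merges annotations across the MRO, a further subclass of `MyNet` would inherit the `model: MLP` hint without repeating it.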
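The ordering constraint in `load_from_checkpoint` can be sketched in the same way. Everything here is hypothetical: the checkpoint is an in-memory dict rather than a file, `LinearModel` is a toy class that borrows PyTorch-style `state_dict`/`load_state_dict` method names, and the sanity check simply compares class names.

```python
from typing import Any, Dict, get_type_hints


class LinearModel:
    """Hypothetical toy model with PyTorch-style state_dict/load_state_dict names."""

    def __init__(self, n_weights: int):
        self.weights = [0.0] * n_weights

    def state_dict(self) -> Dict[str, Any]:
        return {"weights": list(self.weights)}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self.weights = list(state["weights"])


class BasicNeuralNet:
    """Hypothetical, simplified stand-in; only the checkpoint logic is shown."""

    def __init__(self, config: Dict[str, Any]):
        self.config = config

    @property
    def model_cls(self) -> type:
        return get_type_hints(type(self))["model"]

    def make_training_model(self) -> None:
        self.model = self.model_cls(**self.config)

    def save_checkpoint(self) -> Dict[str, Any]:
        # Hypothetical in-memory checkpoint format (a plain dict, not a file).
        return {
            "config": dict(self.config),
            "model_cls": self.model_cls.__name__,
            "state": self.model.state_dict(),
        }

    @classmethod
    def load_from_checkpoint(cls, checkpoint: Dict[str, Any]) -> "BasicNeuralNet":
        # 1. Initialize first: with model_cls defined as an instance property,
        #    it can only be read once an instance exists.
        net = cls(checkpoint["config"])
        # 2. Sanity-check the model class before any weights are loaded.
        if net.model_cls.__name__ != checkpoint["model_cls"]:
            raise ValueError(
                f"checkpoint was saved with {checkpoint['model_cls']}, "
                f"but {cls.__name__} expects {net.model_cls.__name__}"
            )
        # 3. Only now build the model and load the saved weights.
        net.make_training_model()
        net.model.load_state_dict(checkpoint["state"])
        return net


class LinearNet(BasicNeuralNet):
    model: LinearModel


net = LinearNet({"n_weights": 3})
net.make_training_model()
net.model.weights = [1.0, 2.0, 3.0]
restored = LinearNet.load_from_checkpoint(net.save_checkpoint())
print(restored.model.weights)  # [1.0, 2.0, 3.0]
```

In this design, constructing `net` is the earliest point at which the `model_cls` check can run, and deferring `load_state_dict` until after it means a mismatched checkpoint is rejected before any weights are loaded.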