AI4OPT / ML4OPF

Machine Learning for Optimal Power Flow
MIT License

Redundant `model` typehint and `model_cls` attribute in `BasicNeuralNet` #4

Closed: klamike closed this issue 4 months ago

klamike commented 4 months ago

ML4OPF currently requires that subclasses of BasicNeuralNet define the model_cls attribute. For example, the DCPBasicNeuralNet has:

https://github.com/AI4OPT/ML4OPF/blob/461269ae66f5307787042a7b4c2b8e9740a022f0/ml4opf/models/basic_nn/dcp_basic_nn.py#L46-L48
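For readers without the linked file open, the pattern can be reduced to the hypothetical sketch below. `DCPModel`, the `model_cls: type` declaration on the base class, and the class bodies are illustrative stand-ins, not the actual ML4OPF code. The point is that the same class is named twice: once in the `model` type hint and once in `model_cls`.

```python
from abc import ABC


class DCPModel:
    """Hypothetical stand-in for the network class built by the subclass."""


class BasicNeuralNet(ABC):
    # Hypothetical: under the current convention, subclasses must supply model_cls.
    model_cls: type


class DCPBasicNeuralNet(BasicNeuralNet):
    model: DCPModel       # type hint, so the IDE knows what self.model is
    model_cls = DCPModel  # the same class again, so BasicNeuralNet can use it at runtime
```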

We should be able to avoid this repetition by deriving model_cls from the model type hint, like:

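from abc import ABC
from typing import get_type_hints

class BasicNeuralNet(ABC):
    @property
    def model_cls(self) -> type:
        # Look up the `model: SomeModel` hint declared on the concrete subclass.
        # Pass type(self), not self: on a class, get_type_hints merges hints across
        # the MRO and resolves string annotations in each class's defining module.
        return get_type_hints(type(self))["model"]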
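Below is a self-contained sketch of the proposed property in use. `DCPModel`, `OtherModel`, `TunedDCPNet` and `OtherNet` are hypothetical classes made up for illustration. It assumes `get_type_hints` is called on `type(self)`, which merges annotations along the MRO, so a subclass that does not redeclare `model` inherits its parent's hint and one that does overrides it.

```python
from abc import ABC
from typing import get_type_hints


class BasicNeuralNet(ABC):
    @property
    def model_cls(self) -> type:
        # Resolve the "model" annotation of the concrete class (merged across the MRO).
        return get_type_hints(type(self))["model"]


class DCPModel:
    """Hypothetical network class."""


class OtherModel:
    """Another hypothetical network class."""


class DCPBasicNeuralNet(BasicNeuralNet):
    model: DCPModel  # the only place the model class is named


class TunedDCPNet(DCPBasicNeuralNet):
    pass  # no new hint: inherits model: DCPModel


class OtherNet(DCPBasicNeuralNet):
    model: OtherModel  # overrides the parent's hint


print(DCPBasicNeuralNet().model_cls.__name__)  # DCPModel
print(TunedDCPNet().model_cls.__name__)        # DCPModel
print(OtherNet().model_cls.__name__)           # OtherModel
```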

This preserves type hints in the IDE and removes the repetition, though AFAIK there is no mechanism to enforce that a subclass defines the type hint; a subclass without it is accepted and only fails when model_cls is first accessed. It also means model_cls is resolved at runtime on an instance (accessed on the class itself, it is just the property object), so the sanity check on model_cls in load_from_checkpoint needs to happen after the BasicNeuralNet is initialized (but before the checkpoint weights are loaded).
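The ordering constraint can be sketched as follows. The dict-based checkpoint, its `model_cls_name` and `weights` keys, the `weights` attribute, and the `MissingHintNet` class are all hypothetical, chosen only to keep the example self-contained; this is not ML4OPF's real `load_from_checkpoint` signature or checkpoint format. The sketch constructs the instance first, checks `model_cls`, and only then loads weights. It also shows that a subclass without the hint is accepted at class-definition time and fails only when `model_cls` is accessed.

```python
from abc import ABC
from typing import get_type_hints


class BasicNeuralNet(ABC):
    def __init__(self) -> None:
        self.weights = {}  # hypothetical placeholder for real parameters

    @property
    def model_cls(self) -> type:
        return get_type_hints(type(self))["model"]

    @classmethod
    def load_from_checkpoint(cls, checkpoint: dict) -> "BasicNeuralNet":
        # At this point cls.model_cls is the property object, not a class,
        # so the model class can only be resolved once an instance exists.
        net = cls()
        # Sanity check: after construction, before any weights are loaded.
        saved = checkpoint["model_cls_name"]  # hypothetical checkpoint key
        if net.model_cls.__name__ != saved:
            raise ValueError(f"checkpoint is for {saved}, not {net.model_cls.__name__}")
        net.weights = dict(checkpoint["weights"])  # stand-in for loading a state dict
        return net


class DCPModel:
    """Hypothetical network class."""


class DCPBasicNeuralNet(BasicNeuralNet):
    model: DCPModel


class MissingHintNet(BasicNeuralNet):
    pass  # no "model" hint, yet defining the class raises nothing


print(isinstance(DCPBasicNeuralNet.model_cls, property))  # True
net = DCPBasicNeuralNet.load_from_checkpoint({"model_cls_name": "DCPModel", "weights": {"w": 1.0}})
print(net.model_cls.__name__, net.weights)  # DCPModel {'w': 1.0}

try:
    MissingHintNet().model_cls
except KeyError as err:
    print("missing hint:", err)  # missing hint: 'model'
```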

Note that this change means that the model typehint is now required if using the functionality in BasicNeuralNet, namely: