nmichlo / disent

🧶 Modular VAE disentanglement framework for python built with PyTorch Lightning ▸ Including metrics and datasets ▸ With strongly supervised, weakly supervised and unsupervised methods ▸ Easily configured and run with Hydra config ▸ Inspired by disentanglement_lib
https://disent.michlo.dev
MIT License

Adopt a Lightning-Flash Style API for Frameworks #19

Closed · nmichlo closed this 3 years ago

nmichlo commented 3 years ago

The current instantiation of Frameworks is terrible: it requires two callables, one that returns a new optimizer instance and one that returns a new model instance. This is bad for tracking hyper-parameters and for overall usability.

Instead, opt for a Lightning-Flash-style API:

# from: https://github.com/PyTorchLightning/lightning-flash/blob/master/flash/image/classification/model.py
def __init__(
    self,
    ...
    backbone: Union[str, Tuple[nn.Module, int]] = "resnet18",
    backbone_kwargs: Optional[Dict] = None,
    ...
    optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer] = torch.optim.Adam,
    optimizer_kwargs: Optional[Dict[str, Any]] = None,
    scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None,
    scheduler_kwargs: Optional[Dict[str, Any]] = None,
    ...
    learning_rate: float = 1e-3,
    ...
):
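
For illustration, here is a minimal sketch of how a framework could adopt this pattern; the class and attribute names below are hypothetical, not the actual disent API. The optimizer class and its kwargs are stored as plain values (and so can be logged as hyper-parameters), and the optimizer instance is only constructed inside configure_optimizers:

# hypothetical sketch of the Flash-style pattern, not the real disent API
from typing import Any, Dict, Optional, Type, Union

import pytorch_lightning as pl
import torch


class ExampleFramework(pl.LightningModule):

    def __init__(
        self,
        model: torch.nn.Module,
        optimizer: Union[str, Type[torch.optim.Optimizer]] = torch.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
        learning_rate: float = 1e-3,
    ):
        super().__init__()
        # resolve a string like "Adam" to the class on torch.optim
        if isinstance(optimizer, str):
            optimizer = getattr(torch.optim, optimizer)
        self.model = model
        self._optimizer_cls = optimizer
        self._optimizer_kwargs = dict(optimizer_kwargs or {})
        self._learning_rate = learning_rate
        # plain classes and dicts are logged cleanly, unlike factory callables
        self.save_hyperparameters(ignore=['model'])

    def configure_optimizers(self):
        # the optimizer is only instantiated here, from the tracked values
        return self._optimizer_cls(
            self.model.parameters(),
            lr=self._learning_rate,
            **self._optimizer_kwargs,
        )

Constructed this way, something like ExampleFramework(model=..., optimizer=torch.optim.Adagrad, optimizer_kwargs=dict(weight_decay=1e-5)) keeps every hyper-parameter visible to loggers and Hydra configs, whereas a factory lambda is opaque.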
nmichlo commented 3 years ago

This has mostly been implemented in new development versions.