mle-infrastructure / mle-toolbox

Lightweight Tool to Manage Distributed ML Experiments 🛠
https://mle-infrastructure.github.io/mle_toolbox/toolbox/
MIT License

Simplified experiment setup w/o confusing function imports #46

Closed RobertTLange closed 3 years ago

RobertTLange commented 3 years ago

Currently, the main code that generates the experiment results has to be Python-based. Furthermore, we have to use toolbox utilities to fix the random seeds, and the 'archiving' of results relies on DeepLogger. As a result, every Python-based experiment has to start with:

from mle_toolbox.utils import get_configs_ready, DeepLogger, set_random_seeds

Rename/restructure get_configs_ready and DeepLogger so that they are more intuitive. This could, for example, involve a class called MLExperimenter that takes care of all of these things at initialization. E.g.:

class MLExperimenter(object):
    def __init__(self,
                 config_fname: str = "configs/base_config.json",
                 auto_setup: bool = True):
        ''' Load the job configs for the MLE experiment. '''
        # Load the different configurations for the experiment.
        train_config, net_config, log_config = get_configs_ready(config_fname)
        self.train_config = train_config
        self.net_config = net_config
        self.log_config = log_config

        # Make initial setup optional so that configs can be modified ad-hoc
        if auto_setup:
            self.setup()

    def setup(self):
        ''' Set the random seed & initialize the logger. '''
        # Set the random seeds for all random number generation
        set_random_seeds(self.train_config.seed_id)

        # Initialize the logger for the experiment
        # (MLE_Logger: proposed new name for DeepLogger)
        self.mle_log = MLE_Logger(**self.log_config)

    def update_log(self,
                   clock_tick: list,
                   stats_tick: list,
                   model=None,
                   plot_to_tboard=None,
                   save=False):
        ''' Update the MLE_Logger instance with stats, model params & save. '''
        self.mle_log.update(clock_tick, stats_tick, model,
                            plot_to_tboard, save)
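
To illustrate why setup is optional, here is a minimal, self-contained sketch of the deferred-setup pattern. SketchExperimenter is a hypothetical stand-in for the class above: it hard-codes a config instead of calling get_configs_ready, seeds Python's random module instead of calling set_random_seeds, and the num_epochs field and is_setup flag are illustrative assumptions only.

```python
import random
from types import SimpleNamespace


class SketchExperimenter:
    """Hypothetical stand-in for the proposed MLExperimenter."""

    def __init__(self, auto_setup: bool = True):
        # Assumption: hard-coded defaults replace get_configs_ready(config_fname).
        self.train_config = SimpleNamespace(seed_id=0, num_epochs=10)
        self.is_setup = False

        # Setup stays optional so that configs can be modified ad hoc.
        if auto_setup:
            self.setup()

    def setup(self):
        # Assumption: stand-in for set_random_seeds(self.train_config.seed_id).
        random.seed(self.train_config.seed_id)
        self.is_setup = True


# Defer the setup, override the seed, then finish the setup manually.
exp = SketchExperimenter(auto_setup=False)
exp.train_config.seed_id = 42
exp.setup()
```

With auto_setup=True, the seed from the config file would already be applied inside the constructor, so a later override of seed_id would have no effect unless setup() is called again.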

Furthermore, I think we should rename all net_config instances to model_config, both to keep things more general and to stay consistent with how we store checkpoints with DeepLogger/MLE_Logger. With MLExperimenter, the user's imports reduce to a single one:

from mle_toolbox import MLExperimenter

def main(mle_experiment):
    ...

if __name__ == "__main__":
    mle_experiment = MLExperimenter()
    main(mle_experiment)

In principle this should also be possible using a decorator.
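
A rough sketch of what the decorator variant might look like. The name with_experimenter is hypothetical, and a SimpleNamespace stands in for the MLExperimenter instance so the snippet runs without the toolbox installed; the real version would construct MLExperimenter(config_fname) inside the wrapper.

```python
import functools
from types import SimpleNamespace


def with_experimenter(config_fname: str = "configs/base_config.json"):
    """Hypothetical decorator that builds the experiment object for main."""

    def decorator(main_fn):
        @functools.wraps(main_fn)
        def wrapper(*args, **kwargs):
            # Assumption: stand-in for MLExperimenter(config_fname), which
            # would load the configs, set the seeds and create the logger.
            experiment = SimpleNamespace(config_fname=config_fname)
            # Pass the experiment object as the first argument of main.
            return main_fn(experiment, *args, **kwargs)

        return wrapper

    return decorator


@with_experimenter(config_fname="configs/base_config.json")
def main(experiment):
    # The user code only sees the ready-made experiment object.
    return experiment.config_fname


if __name__ == "__main__":
    main()
```

The user would then call main() without arguments, and the experiment object would be created fresh on each call.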

RobertTLange commented 3 years ago

Addressed in #49.