mle-infrastructure / mle-toolbox

Lightweight Tool to Manage Distributed ML Experiments 🛠
https://mle-infrastructure.github.io/mle_toolbox/toolbox/
MIT License

Model/Network reload functionality #28

Closed: RobertTLange closed this issue 3 years ago

RobertTLange commented 3 years ago

Thanks to some recent changes, both the meta_log and the hyper_log store the path to the generated checkpoints/saved models. I want to add a simple, universal function that reloads the model from the stored path, i.e., one that works for torch, JAX, and sklearn models. So something like this, but a bit more bullet-proof:

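import pickle

def reload_model_from_ckpt(ckpt_path, model_type, model=None):
    if model_type == "torch":
        import torch  # only needed when reloading torch checkpoints
        if model is None:  # a state_dict needs an instantiated model
            raise ValueError("model must be provided for torch checkpoints")
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(checkpoint)
    elif model_type in ["jax", "sklearn"]:
        # JAX params / sklearn estimators are stored as pickled objects
        with open(ckpt_path, "rb") as fid:
            model = pickle.load(fid)
    else:
        raise ValueError(f"Unknown model_type: {model_type}")
    return model
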
RobertTLange commented 3 years ago

First try in 01ef796. What is still missing? Probably some tests for JAX and sklearn that check everything works as expected.
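A minimal sketch of such a test, assuming JAX and sklearn checkpoints are plain pickle files as in the snippet above. To keep it dependency-free, it defines a stand-in `reload_model_from_ckpt` that covers only the pickle branch, and it uses a nested dict of floats in place of a real JAX parameter pytree or a fitted sklearn estimator. The helper `save_and_reload` and the file name `model.pkl` are hypothetical. A real test would pickle actual `jax.numpy` arrays or a fitted estimator instead.

```python
import os
import pickle
import tempfile


def reload_model_from_ckpt(ckpt_path, model_type, model=None):
    """Stand-in covering only the pickle branch used for jax/sklearn."""
    if model_type in ["jax", "sklearn"]:
        with open(ckpt_path, "rb") as fid:
            return pickle.load(fid)
    raise ValueError(f"Unknown model_type: {model_type}")


def save_and_reload(obj, model_type):
    """Hypothetical helper: pickle obj to a temporary checkpoint, then reload it."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        ckpt_path = os.path.join(tmp_dir, "model.pkl")  # hypothetical file name
        with open(ckpt_path, "wb") as fid:
            pickle.dump(obj, fid)
        return reload_model_from_ckpt(ckpt_path, model_type)


def test_pickle_roundtrip():
    # Nested dict of floats standing in for a JAX parameter pytree or
    # the state of a fitted sklearn estimator.
    params = {"encoder": {"w": [[0.1, 0.2]], "b": [0.0]}, "decoder": {"w": [[1.0]]}}
    for model_type in ["jax", "sklearn"]:
        assert save_and_reload(params, model_type) == params


def test_unknown_model_type():
    # An unsupported model_type should fail loudly instead of returning None.
    try:
        save_and_reload({}, "tensorflow")
    except ValueError:
        return
    raise AssertionError("expected ValueError for an unknown model_type")


if __name__ == "__main__":
    test_pickle_roundtrip()
    test_unknown_model_type()
    print("all reload tests passed")
```

The test functions follow the `test_*` naming convention, so a runner such as pytest would also discover them without the `__main__` block.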

RobertTLange commented 3 years ago

Added an example for a JAX VAE in a266366.