mle-infrastructure / mle-toolbox

Lightweight Tool to Manage Distributed ML Experiments 🛠
https://mle-infrastructure.github.io/mle_toolbox/toolbox/
MIT License

Improve DeepLogger - Checkpoints + tboard + model formats #12

Closed RobertTLange closed 3 years ago

RobertTLange commented 3 years ago

Point 1: Currently we only have two options to store checkpoints:

  1. Store only most recent checkpoint. Always overwrite .pt file.
  2. Store checkpoint every k steps. Give ckpt an index (e.g. 0, 1, 2, 3).

a) Allow for the option to store the top-k checkpoints, ranked by a logged variable that measures performance. b) Is there a better way to index checkpoints so we can figure out when each snapshot was taken? Should that be stored in the corresponding .hdf5 log file?
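A top-k retention scheme as in a) could track the kept checkpoints in a min-heap keyed by the logged score, evicting the worst file when a better one arrives. A minimal sketch (class and method names are illustrative, not the toolbox API):

```python
import heapq
import os


class TopKCheckpointer:
    """Keep only the k best checkpoints according to a logged score."""

    def __init__(self, k: int, ckpt_dir: str):
        self.k = k
        self.ckpt_dir = ckpt_dir
        # Min-heap of (score, step, path); the worst kept score sits at [0].
        self._heap = []

    def update(self, score: float, step: int, save_fn) -> bool:
        """save_fn(path) writes the checkpoint; returns True if it was kept."""
        path = os.path.join(self.ckpt_dir, f"ckpt_{step}.pt")
        if len(self._heap) < self.k:
            save_fn(path)
            heapq.heappush(self._heap, (score, step, path))
            return True
        if score > self._heap[0][0]:
            save_fn(path)
            # Push the new entry and pop the now-worst one in a single step.
            _, _, worst_path = heapq.heappushpop(self._heap, (score, step, path))
            if os.path.exists(worst_path):
                os.remove(worst_path)
            return True
        return False
```

Indexing by step (as in the filename above) also answers part of b): the step number ties each snapshot back to the row logged at that step.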

Point 2: Network stats logging currently only works if the network is torch-based. Let's add an option to differentiate between torch and JAX, e.g. based on whether the network is a torch.nn.Module or a FrozenDict (JAX/flax).
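One way to differentiate without hard-importing either framework is to inspect the class's module path, so the check works even when torch or flax is not installed. A minimal sketch (function name is illustrative):

```python
def detect_model_type(model) -> str:
    """Best-effort detection of the framework a model object belongs to.

    Walks the MRO and checks each class's defining module, so neither
    torch nor flax needs to be importable for the check itself.
    """
    mro_modules = [cls.__module__ for cls in type(model).__mro__]
    if any(m.startswith("torch") for m in mro_modules):
        return "torch"  # torch.nn.Module subclasses live under torch.*
    if any(m.startswith(("flax", "jax")) for m in mro_modules):
        return "jax"  # e.g. flax.core.frozen_dict.FrozenDict
    if any(m.startswith("sklearn") for m in mro_modules):
        return "sklearn"
    if any(m.startswith(("tensorflow", "keras")) for m in mro_modules):
        return "tensorflow"
    return "unknown"
```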

RobertTLange commented 3 years ago

Ideally, the model logging/storage is even more flexible and also supports simpler sklearn-style setups. The log_config dictionary could include a model_type variable in [sklearn, torch, tensorflow, jax], and the logger would then store/log the different statistics/checkpoints/model states accordingly.
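The dispatch on model_type could look roughly like this. A sketch under assumed per-framework conventions (torch.save for state dicts, keras's model.save, pickle for the rest), not the DeepLogger implementation:

```python
import pickle


def save_model(model, model_type: str, path: str) -> None:
    """Store a model checkpoint based on the model_type from log_config."""
    if model_type == "torch":
        import torch  # only required for torch models
        torch.save(model.state_dict(), path)
    elif model_type == "tensorflow":
        model.save(path)  # keras-style save API
    elif model_type in ("sklearn", "jax"):
        # sklearn estimators pickle directly; flax FrozenDict params are
        # assumed picklable here as plain pytree containers.
        with open(path, "wb") as f:
            pickle.dump(model, f)
    else:
        raise ValueError(f"Unknown model_type: {model_type}")
```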

TIL: sklearn models/trained classifiers can simply be stored via pickle: https://stackoverflow.com/questions/10592605/save-classifier-to-disk-in-scikit-learn
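The round-trip is just dump/load from the standard library. A minimal sketch using a plain dict as a stand-in for a fitted classifier, so it runs without sklearn installed:

```python
import os
import pickle
import tempfile

# Stand-in for a fitted sklearn classifier; any picklable object works.
model = {"coef": [0.3, -1.2], "intercept": 0.5}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pkl")
    with open(path, "wb") as f:
        pickle.dump(model, f)  # persist the trained model
    with open(path, "rb") as f:
        restored = pickle.load(f)  # restore for later predictions
```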

RobertTLange commented 3 years ago

Addressed in PR #29.