replicate / keepsake

Version control for machine learning
https://keepsake.ai
Apache License 2.0

Animated plots #397

Status: open. bfirsh opened this issue 3 years ago.

bfirsh commented 3 years ago

Why

It'd be neat if plots automatically updated so that you could watch the results of training live.

How

It probably depends on #292, and would then require some communication between Python and JS. It would also change the semantics of the Python API a little: experiments and checkpoints are currently loaded once and are static after that.
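The issue does not say how Python and JS would talk. One possible transport, assumed here rather than proposed in the issue, is Server-Sent Events: the Python side streams each new checkpoint as a small text message that a browser can receive with `EventSource`. A minimal sketch of the message format follows. The `checkpoint` event name and the payload fields are hypothetical, not keepsake APIs.

```python
from __future__ import annotations

import json
from typing import Any


def sse_event(event: str, payload: dict[str, Any]) -> str:
    """Format one Server-Sent Events message.

    A message is a set of "field: value" lines terminated by a blank line.
    json.dumps without indent never emits raw newlines, so a single
    "data:" line is enough for the whole payload.
    """
    body = json.dumps(payload, separators=(",", ":"))
    return f"event: {event}\ndata: {body}\n\n"


# Hypothetical update: one new checkpoint with a single metric value.
message = sse_event(
    "checkpoint",
    {"experiment": "abc123", "step": 10, "metrics": {"loss": 0.25}},
)
print(message, end="")
```

In the browser, `new EventSource(url).addEventListener("checkpoint", ...)` would receive each message. Serving the stream over HTTP (with `Content-Type: text/event-stream`) is not shown.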

See also

VIVelev commented 3 years ago

Hey, I created a simple script that visualizes checkpoint metrics live using Streamlit.

Here it is:

```python
from collections import defaultdict
from time import sleep
from typing import Generator, Optional

import keepsake
from keepsake.checkpoint import CheckpointList
from keepsake.experiment import Experiment
import numpy as np
import streamlit as st

def get_data(chkpts: CheckpointList, metric: Optional[str] = None) -> list[Optional[float]]:
    """Return one value per checkpoint for `metric` (None where it is missing)."""
    if metric is None:
        metric = chkpts.primary_metric()

    data: list[Optional[float]] = []
    for chk in chkpts:
        if chk.metrics and metric in chk.metrics:
            data.append(chk.metrics[metric])
        else:
            data.append(None)

    return data

def experiments() -> Generator[Experiment, None, None]:
    """Poll forever, yielding each experiment whose latest checkpoint changed."""
    latest_chk: dict[str, str] = defaultdict(lambda: "")
    while True:
        for e in keepsake.experiments.list():
            if (chk := e.latest()) and chk.id != latest_chk[e.id]:
                latest_chk[e.id] = chk.id
                yield e

        # Workaround (private attributes): tear down keepsake's daemon so the
        # next list() call does not return stale data.
        keepsake.default_project._daemon().cleanup()
        del keepsake.default_project._daemon_instance
        keepsake.default_project._daemon_instance = None

        sleep(0.5)

if __name__ == "__main__":
    plots: dict[str, "st.line_chart"] = {}  # type: ignore

    for e in experiments():
        if e.id in plots:
            # Known experiment: append the newest checkpoint's value (NaN if missing).
            metrics = e.latest().metrics or {}
            plots[e.id].add_rows(np.array([metrics.get(e.primary_metric(), np.nan)]))
        else:
            # New experiment: draw its full history so far.
            plots[e.id] = st.line_chart(get_data(e.checkpoints))
```

It tracks both checkpoints and experiments: when a new experiment is created, it automatically gets its own plot. It all works pretty well, but I was wondering if there is a better way to pull up-to-date information about the project.

Unfortunately, experiment.refresh() did not work on my machine, and even if it had, as far as I am aware it would not pick up newly created experiments (am I right?). More generally, the Python script was not getting the latest checkpoints, so I ended up resetting _daemon_instance, which doesn't seem like the best solution. I suppose constantly terminating and starting processes is quite expensive, right? What do you think?
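A related gap in the script: it remembers only the latest checkpoint ID per experiment, so if several checkpoints land between two polls, only the newest value reaches `add_rows`. One way around that, sketched under assumptions: remember every checkpoint ID already plotted and, on each poll, return all unseen ones. A new experiment is just an ID with no history yet, so it falls out of the same logic. `CheckpointInfo` and `NewCheckpointTracker` are hypothetical stand-ins, not keepsake classes, and snapshots are plain dicts so the sketch runs without keepsake. It does not answer how to get fresh data from keepsake without resetting the daemon; it only covers what to do with each snapshot.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Iterable


@dataclass
class CheckpointInfo:
    """Hypothetical stand-in for a keepsake checkpoint: an ID plus metrics."""

    id: str
    metrics: dict[str, float]


@dataclass
class NewCheckpointTracker:
    """Remembers which checkpoint IDs have already been handled.

    diff() returns every checkpoint that is new since the previous call,
    in snapshot order, so none are skipped when several land between polls.
    """

    seen: dict[str, set[str]] = field(default_factory=dict)

    def diff(
        self, snapshot: dict[str, Iterable[CheckpointInfo]]
    ) -> list[tuple[str, CheckpointInfo]]:
        new: list[tuple[str, CheckpointInfo]] = []
        for exp_id, checkpoints in snapshot.items():
            # An experiment seen for the first time starts with no history.
            seen_ids = self.seen.setdefault(exp_id, set())
            for chk in checkpoints:
                if chk.id not in seen_ids:
                    seen_ids.add(chk.id)
                    new.append((exp_id, chk))
        return new


# Usage with hand-made snapshots (hypothetical data).
tracker = NewCheckpointTracker()
first = tracker.diff({"exp1": [CheckpointInfo("c1", {"loss": 1.0})]})
second = tracker.diff(
    {
        "exp1": [
            CheckpointInfo("c1", {"loss": 1.0}),
            CheckpointInfo("c2", {"loss": 0.8}),
            CheckpointInfo("c3", {"loss": 0.6}),
        ],
        "exp2": [CheckpointInfo("d1", {"loss": 2.0})],
    }
)
print([(e, c.id) for e, c in first])   # [('exp1', 'c1')]
print([(e, c.id) for e, c in second])  # [('exp1', 'c2'), ('exp1', 'c3'), ('exp2', 'd1')]
```

In the Streamlit script, each returned `(exp_id, chk)` pair would become one `add_rows` call, or the first `line_chart` for an experiment not yet plotted.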