francocerino / scikit-reducedmodel

Reduced Order Models in a scikit-learn approach.
https://scikit-reducedmodel.readthedocs.io
MIT License

Save state/checkpoint during basis generation #3

Open mattpitkin opened 4 months ago

mattpitkin commented 4 months ago

It would be useful to have a way of saving the state (checkpointing) during reduced basis generation, and of restarting the generation from the checkpointed state. This would help with bases that take a long time to generate, in case the process gets killed before completion. Being able to specify a path to the checkpoint file when setting up the ReducedBasis class would be useful.

See also https://github.com/aaronuv/arby/issues/10.

leliel12 commented 4 months ago

Can you give us an idea of how you imagine this functionality would work programmatically?

Maybe the README example is a good starting point:

rb = ReducedBasis()
rb.fit(training_set=training_set,
       parameters=parameters,
       physical_points=x_set)
mattpitkin commented 4 months ago

I'd think something along the lines of:

rb = ReducedBasis(
    checkpoint_time_interval=1000,
    checkpoint_dir="/home/user/checkpoints",
)
rb.fit(
    training_set=training_set,
    parameters=parameters,
    physical_points=x_set,
)

where checkpoint_time_interval is a float/int giving the time in seconds between automatic checkpoints, and checkpoint_dir is the directory to which the checkpoint file is written (or from which it is read when resuming). You could also have a checkpoint_iter_interval that specifies the number of greedy loop iterations between checkpoints, as an alternative to a time interval.
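The two triggers described above (a time interval in seconds, or a number of greedy iterations) can be kept in one small helper. This is only a sketch: `CheckpointTrigger`, its `due()` method, and the `clock` argument are hypothetical names, not part of scikit-reducedmodel. It assumes the time interval takes precedence when both are given, and it uses `time.monotonic` so that system clock adjustments do not affect the interval.

```python
from time import monotonic


class CheckpointTrigger:
    """Hypothetical helper that decides when a checkpoint is due."""

    def __init__(self, time_interval=None, iter_interval=None, clock=monotonic):
        self.time_interval = time_interval  # seconds between checkpoints
        self.iter_interval = iter_interval  # iterations between checkpoints
        self._clock = clock                 # injectable so it can be tested
        self._t0 = clock()
        self._iters = 0

    def due(self):
        """Call once per greedy iteration; True when a checkpoint is due."""
        self._iters += 1
        if self.time_interval is not None:
            is_due = self._clock() - self._t0 >= self.time_interval
        elif self.iter_interval is not None:
            is_due = self._iters >= self.iter_interval
        else:
            is_due = False  # checkpointing disabled
        if is_due:
            # reset both counters after each checkpoint
            self._t0 = self._clock()
            self._iters = 0
        return is_due


# Iteration-based usage: a checkpoint is due on every third iteration.
trigger = CheckpointTrigger(iter_interval=3)
fired = [trigger.due() for _ in range(7)]
```

Inside the greedy loop this would reduce the bookkeeping to a single `if trigger.due():` followed by the code that writes the checkpoint.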

You could have something like:

from pathlib import Path
from time import time

import numpy as np

class ReducedBasis:
    ...

    def __init__(
        self,
        index_seed_global_rb=0,
        lmax=0,
        nmax=np.inf,
        greedy_tol=1e-12,
        normalize=False,
        integration_rule="riemann",
        checkpoint_time_interval=None,
        checkpoint_iter_interval=None,
        checkpoint_dir=None,
        checkpoint_file_label="reduced_basis.pkl",
    ) -> None:
        ...

        self.checkpoint = False  # default to no checkpointing
        self.checkpoint_iter_interval = checkpoint_iter_interval
        self.checkpoint_time_interval = checkpoint_time_interval
        if checkpoint_dir is not None:
            # perform checkpointing if directory is specified
            self.checkpoint = True
            self.checkpoint_dir = Path(checkpoint_dir)

            # make checkpoint dir if it doesn't exist
            self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

            # path of the checkpoint file (loaded in _fit if it already exists)
            self.checkpoint_file = self.checkpoint_dir / checkpoint_file_label

        ...

    def _fit(
        ...
    ):
        ...

        t0 = None
        itercount = 0
        if self.checkpoint:
            if self.checkpoint_time_interval is not None:
                t0 = time()

        # check for existing checkpoint file and load if present
        if self.checkpoint and self.checkpoint_file.is_file():
            # load file (probably a pickle file containing required information)
            ...

        # greedy loop
        while sigma > self.greedy_tol and self.nmax > nn + 1:
            ...
            if self.checkpoint:
                if t0 is not None:
                    # time-based checkpointing
                    if time() - t0 >= self.checkpoint_time_interval:
                        # write out checkpoint file
                        ...
                        t0 = time()  # reset timer
                elif self.checkpoint_iter_interval is not None:
                    # iteration-based checkpointing
                    itercount += 1
                    if itercount >= self.checkpoint_iter_interval:
                        # write out checkpoint file
                        ...
                        itercount = 0
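
The "write out checkpoint file" and "load file" steps left open above could look roughly like the following. This is a minimal sketch, not the library's implementation: `save_checkpoint`, `load_checkpoint`, and `toy_greedy` are hypothetical names, and the `{"indices": [...]}` dictionary is a placeholder, since the thread does not say which quantities the real greedy loop would need to persist. The file is written to a temporary name and then renamed with `os.replace`, so a process killed mid-write does not leave a truncated checkpoint behind.

```python
import os
import pickle
import tempfile
from pathlib import Path


def save_checkpoint(path, state):
    """Write state to path atomically: temp file first, then rename."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp = path.with_name(path.name + ".tmp")
    with open(tmp, "wb") as f:
        pickle.dump(state, f)
    os.replace(tmp, path)  # atomic when both paths are on one filesystem


def load_checkpoint(path):
    """Return the saved state, or None if no checkpoint exists."""
    path = Path(path)
    if not path.is_file():
        return None
    with open(path, "rb") as f:
        return pickle.load(f)


def toy_greedy(n_total, checkpoint_file, stop_after=None):
    """Stand-in for the greedy loop: adds one index per iteration."""
    state = load_checkpoint(checkpoint_file) or {"indices": []}
    while len(state["indices"]) < n_total:
        if stop_after is not None and len(state["indices"]) >= stop_after:
            return state  # simulate the process being killed
        state["indices"].append(len(state["indices"]))
        save_checkpoint(checkpoint_file, state)  # here: after every iteration
    return state


with tempfile.TemporaryDirectory() as d:
    ckpt = Path(d) / "reduced_basis.pkl"
    partial = toy_greedy(5, ckpt, stop_after=2)  # interrupted run
    resumed = toy_greedy(5, ckpt)                # resumes from the file
```

Pickle files should only be loaded from trusted locations, so a real implementation might prefer a format such as NumPy's `.npz` for the array data.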