mattpitkin opened this issue 4 months ago
Can you give us an idea of how you imagine this functionality would work programmatically?
Maybe the README example is a good start:
```python
rb = ReducedBasis()
rb.fit(
    training_set=training_set,
    parameters=parameters,
    physical_points=x_set,
)
```
I'd think something along the lines of:
```python
rb = ReducedBasis(
    checkpoint_time_interval=1000,
    checkpoint_dir="/home/user/checkpoints",
)
rb.fit(
    training_set=training_set,
    parameters=parameters,
    physical_points=x_set,
)
```
where `checkpoint_time_interval` is a float/int giving the time in seconds between automatic checkpoints, and `checkpoint_dir` is the directory to which the checkpoint file is written (or from which it is read when resuming). You could also have a `checkpoint_iter_interval` argument that specifies the number of greedy-loop iterations between checkpoints, rather than a time interval.
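The choice between the two kinds of interval could be isolated in a small helper that the greedy loop calls once per iteration. This is a sketch under assumptions: `CheckpointTrigger` is a hypothetical class (not part of arby), and, unlike the proposal, it requires exactly one of the two intervals to be set. It uses `time.monotonic` instead of `time.time`, since a monotonic clock cannot jump backwards if the system clock is adjusted; the `clock` argument only exists so the timer can be faked in tests.

```python
from time import monotonic


class CheckpointTrigger:
    """Decide when the greedy loop should write a checkpoint.

    Hypothetical helper, not part of arby. Exactly one of ``time_interval``
    (seconds) or ``iter_interval`` (greedy-loop iterations) must be given.
    """

    def __init__(self, time_interval=None, iter_interval=None, clock=monotonic):
        if (time_interval is None) == (iter_interval is None):
            raise ValueError("give exactly one of time_interval or iter_interval")
        self.time_interval = time_interval
        self.iter_interval = iter_interval
        self.clock = clock
        self.t0 = clock()  # start of the current time interval
        self.itercount = 0  # iterations since the last checkpoint

    def __call__(self):
        """Call once per greedy iteration; return True if a checkpoint is due."""
        if self.time_interval is not None:
            now = self.clock()
            if now - self.t0 >= self.time_interval:
                self.t0 = now  # reset the timer
                return True
            return False
        self.itercount += 1
        if self.itercount >= self.iter_interval:
            self.itercount = 0  # reset the counter
            return True
        return False
```

Inside the greedy loop, the time-based and iteration-based branches would then collapse into a single `if trigger():` check before writing the checkpoint.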
You could have something like:
```python
from pathlib import Path
from time import time

import numpy as np


class ReducedBasis:
    ...

    def __init__(
        self,
        index_seed_global_rb=0,
        lmax=0,
        nmax=np.inf,
        greedy_tol=1e-12,
        normalize=False,
        integration_rule="riemann",
        checkpoint_time_interval=None,
        checkpoint_iter_interval=None,
        checkpoint_dir=None,
        checkpoint_file_label="reduced_basis.pkl",
    ) -> None:
        ...
        self.checkpoint = False  # default to no checkpointing
        self.checkpoint_iter_interval = checkpoint_iter_interval
        self.checkpoint_time_interval = checkpoint_time_interval

        if checkpoint_dir is not None:
            # perform checkpointing if a directory is specified
            self.checkpoint = True
            self.checkpoint_dir = Path(checkpoint_dir)

            # make the checkpoint directory if it doesn't already exist
            self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

            # checkpoint file to write to (or read from when resuming)
            self.checkpoint_file = self.checkpoint_dir / checkpoint_file_label
        ...

    def _fit(
        self,
        ...
    ):
        ...
        t0 = None
        itercount = 0
        if self.checkpoint and self.checkpoint_time_interval is not None:
            t0 = time()  # start the timer for time-based checkpointing

        # check for an existing checkpoint file and load it if present
        if self.checkpoint and self.checkpoint_file.is_file():
            # load file (probably a pickle file containing the required information)
            ...

        # greedy loop
        while sigma > self.greedy_tol and self.nmax > nn + 1:
            ...
            if self.checkpoint:
                if t0 is not None:
                    # time-based: checkpoint once the interval has elapsed
                    t1 = time()
                    if t1 - t0 >= self.checkpoint_time_interval:
                        # write out checkpoint file
                        ...
                        t0 = time()  # reset timer
                else:
                    # iteration-based: checkpoint every checkpoint_iter_interval iterations
                    itercount += 1
                    if itercount == self.checkpoint_iter_interval:
                        # write out checkpoint file
                        ...
                        itercount = 0
```
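The "write out checkpoint file" and "load file" steps could be backed by helpers along the following lines. This is a sketch under assumptions: `save_checkpoint` and `load_checkpoint` are hypothetical names, and `state` stands for whatever the greedy loop needs in order to resume (for example the basis built so far and the greedy indices), which the proposal leaves open. Writing to a temporary file and then calling `os.replace` (atomic on POSIX filesystems) means a run killed mid-write leaves the previous checkpoint intact instead of a truncated pickle.

```python
import os
import pickle
import tempfile
from pathlib import Path


def save_checkpoint(state, checkpoint_file):
    """Pickle ``state`` to ``checkpoint_file`` atomically (hypothetical helper)."""
    checkpoint_file = Path(checkpoint_file)
    # write to a temporary file in the same directory first...
    fd, tmp_name = tempfile.mkstemp(dir=checkpoint_file.parent, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(state, f, protocol=pickle.HIGHEST_PROTOCOL)
            f.flush()
            os.fsync(f.fileno())  # make sure the bytes reach the disk
        # ...then swap it into place in a single rename
        os.replace(tmp_name, checkpoint_file)
    except BaseException:
        Path(tmp_name).unlink(missing_ok=True)  # clean up the partial file
        raise


def load_checkpoint(checkpoint_file):
    """Return the pickled state, or None if no checkpoint file exists yet."""
    checkpoint_file = Path(checkpoint_file)
    if not checkpoint_file.is_file():
        return None
    with open(checkpoint_file, "rb") as f:
        return pickle.load(f)
```

Since `pickle.load` can execute arbitrary code, only checkpoint files written by the user's own runs should be loaded this way.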
It would be useful to have a way of saving the state (checkpointing) during the reduced basis generation and, equivalently, of restarting the basis generation from the checkpointed state. This would help for bases that take a long time to generate, in case the generation process gets killed before completion. Being able to specify a path to the checkpoint file when setting up the `ReducedBasis` class would be useful. See also https://github.com/aaronuv/arby/issues/10.
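To illustrate the save-and-restart idea end to end, here is a toy, self-contained greedy loop that pickles its basis and greedy indices to a checkpoint file and resumes from that file if it exists. It is a hypothetical `greedy_basis` function, not arby's implementation, and it assumes real-valued training functions, an unweighted discrete inner product in place of an integration rule, a plain residual-norm stopping test against `greedy_tol`, and seed index 0. `stop_after` is a testing hook that ends the loop early to stand in for a killed job.

```python
import pickle
from pathlib import Path

import numpy as np


def greedy_basis(training_set, nmax, greedy_tol=1e-12, checkpoint_file=None,
                 checkpoint_iter_interval=1, stop_after=None):
    """Toy greedy reduced basis with checkpoint/resume (hypothetical sketch).

    ``stop_after`` ends the loop after that many iterations, standing in
    for a run that is killed part way through.
    """
    if checkpoint_file is not None and Path(checkpoint_file).is_file():
        # resume from the saved basis and greedy indices
        with open(checkpoint_file, "rb") as f:
            state = pickle.load(f)
        basis, indices = list(state["basis"]), list(state["indices"])
    else:
        # fresh start, seeding the basis with the first training function
        basis = [training_set[0] / np.linalg.norm(training_set[0])]
        indices = [0]

    iterations = 0
    while len(basis) < nmax:
        B = np.array(basis)  # orthonormal basis vectors as rows
        # residual of every training function after projecting onto the basis
        residuals = training_set - (training_set @ B.T) @ B
        errors = np.linalg.norm(residuals, axis=1)
        idx = int(np.argmax(errors))
        if errors[idx] <= greedy_tol:
            break  # every training function is represented to tolerance
        # re-orthogonalise the worst residual once more, then normalise it
        new = residuals[idx] - B.T @ (B @ residuals[idx])
        basis.append(new / np.linalg.norm(new))
        indices.append(idx)
        iterations += 1

        if checkpoint_file is not None and iterations % checkpoint_iter_interval == 0:
            # save everything needed to pick up the loop from this point
            with open(checkpoint_file, "wb") as f:
                pickle.dump({"basis": np.array(basis), "indices": indices}, f)
        if stop_after is not None and iterations >= stop_after:
            break  # simulate the run being killed here

    return np.array(basis), indices
```

Running it once with `stop_after=3` and then again with the same `checkpoint_file` should give the same basis and greedy indices as a single uninterrupted run, because the resumed loop starts from exactly the state that was saved.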