automl / neps

Neural Pipeline Search (NePS): Helps deep learning experts find the best neural pipeline.
https://automl.github.io/neps/
Apache License 2.0

refactor(runtime): Partial update of state #93

Closed eddiebergman closed 2 months ago

eddiebergman commented 2 months ago

Ignore the branch name... Also this is WIP

On the way towards making an ask-and-tell interface for NePS, I needed to isolate some behaviors. As part of this, I implemented a new feature for the workers where they only read in and update the state they actually need.

What this means is that each NePS worker maintains the history of configs it has evaluated, and when it comes time to sample a new config, it checks the trial directories for what's new and updates accordingly, kind of like a git-diff. The overall effect is that a worker only has to read in what the other workers have done in the meantime.

If you start a fresh worker, it reads in the entire history, since it has nothing to begin with.
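
To illustrate the idea (this is a minimal sketch, not the actual NePS code), a worker can keep a cache of trial ids it has already loaded and, on each refresh, only read directories it hasn't seen before. `WorkerView` and the `report.json` file name below are placeholders:

    import json
    from pathlib import Path


    class WorkerView:
        """Illustrative worker-local cache of evaluated trials (not the real NePS class)."""

        def __init__(self, results_dir: Path) -> None:
            self.results_dir = results_dir
            self.evaluated: dict[str, dict] = {}  # trial id -> loaded report

        def refresh(self) -> None:
            """Read in only the trial directories that are not already cached."""
            for trial_dir in self.results_dir.iterdir():
                if not trial_dir.is_dir() or trial_dir.name in self.evaluated:
                    continue  # already seen: skip the disk read entirely
                report_file = trial_dir / "report.json"  # placeholder file name
                if report_file.exists():
                    self.evaluated[trial_dir.name] = json.loads(report_file.read_text())

A fresh worker starts with an empty cache, so its first `refresh()` reads everything; after that, each call only touches directories that appeared since the last one.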


From a brief look at a py-spy profile, this seems to remove YAML loading entirely and to cut file-system overhead by a lot.


This is the meat and potatoes of it:

    def update_from_disk(self) -> None:  # noqa: C901, PLR0912, PLR0915
        """Update the shared state from disk."""
        trial_dirs = (p for p in self.paths.results_dir.iterdir() if p.is_dir())
        trials_on_disk = [TrialOnDisk.from_dir(p) for p in trial_dirs]

        for trial_on_disk in trials_on_disk:
            state = trial_on_disk.state()

            if state in (TrialOnDisk.State.SUCCESS, TrialOnDisk.State.ERROR):
                if trial_on_disk.id in self.evaluated_trials:
                    continue

                # It's been evaluated and we can move it out of pending
                self.pending_trials.pop(trial_on_disk.id, None)
                self.in_progress_trials.pop(trial_on_disk.id, None)

                raw_config, metadata, previous_config_id, result = trial_on_disk.load()

                # NOTE: Assuming that the previous one will always have been
                # evaluated, if there is a previous one.
                previous_report = None
                if previous_config_id is not None:
                    previous_report = self.evaluated_trials[previous_config_id]

                trial = Trial(
                    id=trial_on_disk.id,
                    config=raw_config,
                    pipeline_dir=trial_on_disk.pipeline_dir,
                    previous=previous_report,
                    time_sampled=metadata["time_sampled"],
                    metadata=metadata,
                )
                if isinstance(result, SuccessResult):
                    report = trial.success(result.results, time_end=metadata["time_end"])
                elif isinstance(result, ErrorResult):
                    report = trial.error(
                        result.err,
                        tb=result.tb,
                        time_end=metadata["time_end"],
                    )
                elif result is None:
                    raise RuntimeError(
                        "Result should not have been None, this is a bug!"
                        " Please report this to the developers with some sample"
                        " code if possible."
                    )
                else:
                    raise TypeError(f"Unknown result type {type(result)}")

                self.evaluated_trials[trial_on_disk.id] = report

            elif state is TrialOnDisk.State.PENDING:
                assert trial_on_disk.id not in self.evaluated_trials
                if trial_on_disk.id in self.pending_trials:
                    continue

                raw_config, metadata, previous_config_id, result = trial_on_disk.load()

                # NOTE: Assuming that the previous one will always have been evaluated,
                # if there is a previous one.
                previous_report = None
                if previous_config_id is not None:
                    previous_report = self.evaluated_trials[previous_config_id]

                trial = Trial(
                    id=trial_on_disk.id,
                    config=raw_config,
                    pipeline_dir=trial_on_disk.pipeline_dir,
                    previous=previous_report,
                    time_sampled=metadata["time_sampled"],
                    metadata=metadata,
                )
                self.pending_trials[trial_on_disk.id] = trial

            elif state is TrialOnDisk.State.IN_PROGRESS:
                assert trial_on_disk.id not in self.evaluated_trials
                if trial_on_disk.id in self.in_progress_trials:
                    continue

                # If this was previously in the pending queue, just move
                # it into the in-progress queue
                previously_pending_trial = self.pending_trials.pop(trial_on_disk.id, None)
                if previously_pending_trial is not None:
                    self.in_progress_trials[trial_on_disk.id] = previously_pending_trial
                    continue

                # Otherwise it's the first time we saw it so we have to load it in
                raw_config, metadata, previous_config_id, result = trial_on_disk.load()

                # NOTE: Assuming that the previous one will always have been evaluated,
                # if there is a previous one.
                previous_report = None
                if previous_config_id is not None:
                    previous_report = self.evaluated_trials[previous_config_id]

                trial = Trial(
                    id=trial_on_disk.id,
                    config=raw_config,
                    pipeline_dir=trial_on_disk.pipeline_dir,
                    previous=previous_report,
                    time_sampled=metadata["time_sampled"],
                    metadata=metadata,
                )
                self.in_progress_trials[trial_on_disk.id] = trial  # track it as in progress, matching its on-disk state

            elif state is TrialOnDisk.State.CORRUPTED:
                logger.warning(f"Removing corrupted trial {trial_on_disk.id}")
                try:
                    shutil.rmtree(trial_on_disk.pipeline_dir)
                except Exception as e:
                    logger.exception(e)

            else:
                raise ValueError(f"Unknown state {state} for trial {trial_on_disk.id}")

    @contextmanager
    def sync(self, *, lock: bool = False) -> Iterator[None]:
        """Sync up with what's on disk."""
        if lock:
            _poll, _timeout = get_shared_state_poll_and_timeout()
            ctx = partial(self.lock, poll=_poll, timeout=_timeout)
        else:
            ctx = nullcontext

        with ctx():
            self.update_from_disk()
            yield
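
For context, here is a rough sketch of how a worker-side caller might pair `sync` with the lock before sampling. `shared_state` and `optimizer` are stand-ins for the runtime's actual objects, not the real API:

    def sample_next(shared_state, optimizer):
        # Refresh the cached view from disk under the lock, then sample and
        # register the new trial while still holding it, so that no other
        # worker can claim the same config in the meantime.
        with shared_state.sync(lock=True):
            trial = optimizer.sample(shared_state.evaluated_trials)
            shared_state.pending_trials[trial.id] = trial
        return trial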