automl / amltk

A build-it-yourself AutoML Framework
https://automl.github.io/amltk/
BSD 3-Clause "New" or "Revised" License

feat: CVEarlyStopping #254

Closed: eddiebergman closed this 9 months ago

eddiebergman commented 9 months ago

Alright, another big one.

The major feature is the implementation of CVEarlyStopping. This will need to be documented once the optimization documentation is updated. There also need to be some default variants selectable by keyword, but for now those are still a subject of experimentation.


```python
# `Metric`, `Trial` and `CVEvaluation` come from amltk; `_X`, `_y` (the data)
# and `mlp_classifier` (the pipeline being optimized) are defined elsewhere.

class CVEarlyStopper:
    def __init__(self, metric: Metric, threshold: float):
        super().__init__()
        self.threshold = threshold
        self.metric = metric

    def update(self, report: Trial.Report) -> None:
        pass  # Normally you would update w.r.t. a finished trial

    def should_stop(self, info: CVEvaluation.FoldInfo) -> bool:
        # Called after a fold is evaluated; returning True stops the trial early.
        return info.scores[self.metric.name] < self.threshold

metric = Metric("accuracy", minimize=False)
evaluator = CVEvaluation(_X, _y)
early_stopper = CVEarlyStopper(metric=metric, threshold=0.8)

history = mlp_classifier.optimize(
    target=evaluator.fn,
    metric=metric,
    on_trial_exception="continue",  # Seems to be required to keep early stopping from raising

    # The primary job of the plugin is to establish a comm link between the worker and
    # the main process and use the class above to decide what to do.
    plugins=[evaluator.cv_early_stopping_plugin(strategy=early_stopper)]
)
```
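To make the control flow concrete, here is a minimal, self-contained sketch of what a threshold-based early-stopping check amounts to inside a cross-validation loop. It does not use amltk: `FoldInfo`, `ThresholdStopper` and `run_cv` are hypothetical stand-ins, and the per-fold scores are passed in directly rather than computed by fitting a model.

```python
from dataclasses import dataclass, field


@dataclass
class FoldInfo:
    """Hypothetical stand-in for the information reported after each fold."""
    fold: int
    scores: dict[str, float] = field(default_factory=dict)


class ThresholdStopper:
    """Stop a trial as soon as one fold scores below a threshold."""

    def __init__(self, metric_name: str, threshold: float) -> None:
        self.metric_name = metric_name
        self.threshold = threshold

    def should_stop(self, info: FoldInfo) -> bool:
        return info.scores[self.metric_name] < self.threshold


def run_cv(fold_scores: list[float], stopper: ThresholdStopper) -> tuple[list[float], bool]:
    """Walk through pre-computed fold scores, consulting the stopper after each fold.

    Returns the scores of the folds that were actually evaluated and whether
    the trial was stopped early.
    """
    evaluated: list[float] = []
    for i, score in enumerate(fold_scores):
        evaluated.append(score)
        info = FoldInfo(fold=i, scores={stopper.metric_name: score})
        if stopper.should_stop(info):
            return evaluated, True  # remaining folds are skipped
    return evaluated, False


stopper = ThresholdStopper("accuracy", threshold=0.8)
print(run_cv([0.91, 0.85, 0.88], stopper))  # ([0.91, 0.85, 0.88], False)
print(run_cv([0.91, 0.72, 0.88], stopper))  # ([0.91, 0.72], True)
```

The sketch simply returns a flag. Judging by the comment on `on_trial_exception="continue"`, in the PR an early stop appears to surface as an exception on the trial instead.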

Enabling this required some larger updates to amltk. To illustrate the more explicit callback approach, the code above is roughly equivalent to the following:

```python
scheduler = Scheduler.with_processes(1)
evaluator = CVEvaluation(_X, _y)
metric = Metric("accuracy", minimize=False)

# Notably, **nothing** is passed to `cv_early_stopping_plugin()` now, so nothing
# will get called unless we listen to the `@fold-evaluated` event.
task = scheduler.task(evaluator.fn, plugins=[evaluator.cv_early_stopping_plugin()])

# Methods one and two are alternative ways of doing the same thing.

# Method one: listen to the `@fold-evaluated` callback and return what to do.
@task.on("fold-evaluated")
def should_stop(info: CVEvaluation.FoldInfo) -> bool:
    return info.scores[metric.name] < 0.8

# Method two: explicitly use the `Comm` and `Msg` that are used under the hood.
# This is what `cv_early_stopping_plugin` does with the user's object.
@task.on("comm-request")
def should_stop_2(msg: Msg.Data) -> None:
    fold_info: CVEvaluation.FoldInfo = msg.data
    if fold_info.scores[metric.name] < 0.8:
        msg.respond(True)   # stop the remaining folds
    else:
        msg.respond(False)  # keep going

history = mlp_classifier.optimize(
    target=task,
    metric=metric,
    on_trial_exception="continue",
    scheduler=scheduler,
)
```
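Method two relies on a request/response exchange between the worker and the main process. The sketch below imitates that pattern using only the standard library: a worker thread sends each fold's scores over a `multiprocessing.Pipe` and blocks until the main side answers `True` (stop) or `False` (continue). It is not amltk's `Comm` implementation; the message format, the `None` end-of-stream marker and the `worker`/`serve` names are assumptions for illustration.

```python
import threading
from multiprocessing import Pipe
from multiprocessing.connection import Connection

THRESHOLD = 0.8  # hypothetical early-stopping threshold on accuracy


def worker(conn: Connection, fold_scores: list[float], evaluated: list[float]) -> None:
    """Evaluate folds one by one, asking the other end whether to stop."""
    for score in fold_scores:
        evaluated.append(score)
        conn.send({"accuracy": score})  # request: report this fold's scores
        if conn.recv():                 # response: True means stop
            break
    conn.send(None)  # tell the main side that no more requests are coming
    conn.close()


def serve(conn: Connection) -> None:
    """Answer each fold report with a stop/continue decision."""
    while True:
        scores = conn.recv()
        if scores is None:
            break
        conn.send(scores["accuracy"] < THRESHOLD)
    conn.close()


main_end, worker_end = Pipe()
evaluated: list[float] = []
t = threading.Thread(target=worker, args=(worker_end, [0.9, 0.7, 0.95], evaluated))
t.start()
serve(main_end)
t.join()
print(evaluated)  # [0.9, 0.7]
```

A thread stands in for the worker process to keep the sketch short; the request/response logic over the pipe would be the same with a real `multiprocessing.Process`.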

The notable change is that the `@task.on("fold-evaluated")` callback returns a value, which shields the user from the `Comm` implementation details. This required updating the event system to allow handlers to return values, which in turn meant updating the signature of every existing event in amltk. Hence the large number of changes.
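The core idea behind the event-system change can be sketched in a few lines: handlers whose return values flow back to whoever emitted the event. This is a hypothetical, stripped-down `Emitter`, not amltk's actual event API; the `on`/`emit` names and the choice to return a list with one result per handler are assumptions.

```python
from collections import defaultdict
from typing import Any, Callable


class Emitter:
    """Minimal event emitter whose handlers can return values."""

    def __init__(self) -> None:
        self.handlers: defaultdict[str, list[Callable[..., Any]]] = defaultdict(list)

    def on(self, event: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
        """Decorator that registers a handler for `event`."""
        def register(fn: Callable[..., Any]) -> Callable[..., Any]:
            self.handlers[event].append(fn)
            return fn
        return register

    def emit(self, event: str, *args: Any, **kwargs: Any) -> list[Any]:
        """Call every handler for `event` and collect their return values."""
        return [fn(*args, **kwargs) for fn in self.handlers[event]]


emitter = Emitter()


@emitter.on("fold-evaluated")
def should_stop(scores: dict[str, float]) -> bool:
    return scores["accuracy"] < 0.8


# The emitting side decides how to combine the responses, e.g. stop if any handler says so.
responses = emitter.emit("fold-evaluated", {"accuracy": 0.75})
print(responses)       # [True]
print(any(responses))  # True
```

Collecting one value per handler leaves the policy (first answer, any, all) to the emitter, which is one way a plugin could turn handler results into a single stop/continue decision.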


The last few changes are code simplifications I made while implementing this.

```python
# Before
@task.on("on_result")
def f(*args, **kwargs): ...

@task.on("on_future_submitted")
def f(*args, **kwargs): ...

# After
@task.on("result")
def f(*args, **kwargs): ...

@task.on("future-submitted")
def f(*args, **kwargs): ...

# Note: the comm message event also uses hyphens
@task.on("comm-message")
def f(*args, **kwargs): ...
```

This is a backwards-incompatible change, but since we advertise `task.on_result` as a shorthand, and it keeps working the same way, I'm okay with making the breaking change sooner rather than later. I think we should still advertise the `on_<event>` attributes wherever possible, as they provide a level of type safety we just can't get with strings.

eddiebergman commented 9 months ago

Docs are failing, and rightfully so; I'll update them soonish.

eddiebergman commented 9 months ago

@LennartPurucker Implemented changes to address your comments.