Alright, another big one.

The major feature was implementing `CVEarlyStopping`. This will need to be documented once the optimization documentation gets updated. There also need to be some default variants selectable by keyword, but for now that's a subject of experimentation.
```python
# Assumes `CVEvaluation`, `_X`, `_y` and `mlp_classifier` are already in scope
from amltk.optimization import Metric, Trial


class CVEarlyStopper:
    def __init__(self, metric: Metric, threshold: float):
        self.threshold = threshold
        self.metric = metric

    def update(self, report: Trial.Report) -> None:
        pass  # Normally you would update w.r.t. a finished trial

    def should_stop(self, info: CVEvaluation.FoldInfo) -> bool:
        # Stop the remaining folds once a fold scores below the threshold
        return info.scores[self.metric.name] < self.threshold


metric = Metric("accuracy", minimize=False)
evaluator = CVEvaluation(_X, _y)
early_stopper = CVEarlyStopper(metric=metric, threshold=0.8)

history = mlp_classifier.optimize(
    target=evaluator.fn,
    metric=metric,
    on_trial_exception="continue",  # Seems required to prevent early stopping raising
    # The primary job of the plugin is to establish a comm link between the worker and
    # the master process and use the class above to handle what to do.
    plugins=[evaluator.cv_early_stopping_plugin(strategy=early_stopper)],
)
```
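What the plugin does with `should_stop` can be pictured as a single-process loop: after each fold, ask the strategy whether to continue, and skip the remaining folds once it says stop. `FoldInfo` and `run_cv` below are hypothetical stand-ins, not amltk API; the real plugin makes this call across processes via `Comm`.

```python
from dataclasses import dataclass, field
from typing import Callable


@dataclass
class FoldInfo:
    # Hypothetical stand-in for `CVEvaluation.FoldInfo`, with only the fields used here
    fold: int
    scores: dict[str, float] = field(default_factory=dict)


def run_cv(
    fold_scores: list[float],
    should_stop: Callable[[FoldInfo], bool],
    metric_name: str = "accuracy",
) -> list[float]:
    """Evaluate folds in order, consulting `should_stop` after each one."""
    evaluated: list[float] = []
    for i, score in enumerate(fold_scores):
        evaluated.append(score)
        if should_stop(FoldInfo(fold=i, scores={metric_name: score})):
            break  # Skip the remaining folds
    return evaluated


def stop_below(info: FoldInfo) -> bool:
    # Same rule as `CVEarlyStopper.should_stop` with threshold=0.8
    return info.scores["accuracy"] < 0.8


print(run_cv([0.9, 0.85, 0.7, 0.95, 0.9], stop_below))  # [0.9, 0.85, 0.7]
```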
There were some larger updates needed to enable this, namely:

- This is implemented as a `Comm.Plugin` so that the worker process can communicate with the main process, where the `CVEarlyStopper` lives.
- Communication through `Comm` relies on the worker `request()`'ing whether it should stop and the main process `msg.respond(should_stop)`'ing with what to do.
- To shield users from the details of how this is implemented, we would rather they implement a simple function and use its return value to handle all of this.
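The request/respond exchange can be sketched with the standard library alone. This is not amltk's `Comm` implementation: `worker`, `main` and the dict message shape are hypothetical, `multiprocessing.Pipe` plays the role of the comm link, and a thread stands in for the worker process to keep the sketch portable (the same `Pipe` also works between processes).

```python
import threading
from multiprocessing import Pipe
from multiprocessing.connection import Connection


def worker(conn: Connection, fold_scores: list[float]) -> None:
    """Worker side: after each fold, ask whether to stop and wait for the answer."""
    for score in fold_scores:
        conn.send({"accuracy": score})  # the "request"
        if conn.recv():  # blocks until the main side responds
            break
    conn.send(None)  # nothing more to ask


def main(fold_scores: list[float], threshold: float = 0.8) -> list[dict[str, float]]:
    """Main side: answer every request with a should-stop decision."""
    parent, child = Pipe()
    t = threading.Thread(target=worker, args=(child, fold_scores))
    t.start()
    seen: list[dict[str, float]] = []
    while (scores := parent.recv()) is not None:
        seen.append(scores)
        parent.send(scores["accuracy"] < threshold)  # the "response"
    t.join()
    return seen


print(main([0.9, 0.7, 0.95]))  # [{'accuracy': 0.9}, {'accuracy': 0.7}]
```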
To illustrate the more explicit callback method, the above way of doing things is more or less equivalent to the following:
```python
# Assumes `CVEvaluation`, `Msg`, `_X`, `_y` and `mlp_classifier` are already in scope
from amltk.optimization import Metric
from amltk.scheduling import Scheduler

scheduler = Scheduler.with_processes(1)
evaluator = CVEvaluation(_X, _y)
metric = Metric("accuracy", minimize=False)

# Notably **nothing** is passed to `cv_early_stopping_plugin()` now; nothing will
# get called if we don't listen to the `@fold-evaluated` event.
task = scheduler.task(evaluator.fn, plugins=[evaluator.cv_early_stopping_plugin()])


# Method one: listen to the `@fold-evaluated` callback and return what to do
@task.on("fold-evaluated")
def should_stop(info: CVEvaluation.FoldInfo) -> bool:
    return info.scores[metric.name] < 0.8


# Method two: explicitly use the `Comm` and `Msg` machinery underneath the hood.
# This is what `cv_early_stopping_plugin` does with the user's object.
@task.on("comm-request")
def should_stop_2(msg: Msg.Data) -> None:
    fold_info: CVEvaluation.FoldInfo = msg.data
    msg.respond(fold_info.scores[metric.name] < 0.8)


history = mlp_classifier.optimize(
    target=task,
    metric=metric,
    on_trial_exception="continue",
    scheduler=scheduler,
)
```
The notable change is that the `@task.on("fold-evaluated")` callback returns a value, shielding the user from the `Comm` implementation detail. This required updating the `Event` system to allow handlers to return values, which in turn changed the signature of every existing event in amltk. Hence a lot of changes.
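Roughly, the `Event` change means emitting an event hands back what its handlers returned, so a plugin can act on the answer. The `Emitter` below is a hypothetical sketch, not amltk's `Event` class, and combining several answers with `any()` is an assumption made for illustration.

```python
from collections import defaultdict
from typing import Any, Callable


class Emitter:
    """Hypothetical event emitter whose emit() returns each handler's return value."""

    def __init__(self) -> None:
        self.handlers: dict[str, list[Callable[..., Any]]] = defaultdict(list)

    def on(self, event: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
        def register(fn: Callable[..., Any]) -> Callable[..., Any]:
            self.handlers[event].append(fn)
            return fn

        return register

    def emit(self, event: str, *args: Any) -> list[Any]:
        # Collect return values instead of discarding them
        return [fn(*args) for fn in self.handlers[event]]


emitter = Emitter()


@emitter.on("fold-evaluated")
def should_stop(scores: dict[str, float]) -> bool:
    return scores["accuracy"] < 0.8


# The plugin side turns the collected answers into a single decision
answers = emitter.emit("fold-evaluated", {"accuracy": 0.7})
print(answers, any(answers))  # [True] True
```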
The last few changes were code simplifications I made while implementing this.
The concept of an `Evaluator` as something more than just a callable function was removed, as I saw little benefit while implementing this. It would help us shove more behavior into `pipeline.optimize`, but for now we should not hide the other parts of AMLTK from people; they provide a lot more utility and customization than parameters to a single function can. (Maybe we revisit this at some point, but removing it reduces maintenance complexity.)
The `@events` for the `Scheduler` and `Task` all had strings such as `"on_result"`, while things like the `Comm.Plugin` had events like `"comm-message"`. These were unified to use hyphens (`"-"`) and to drop the `"on"` part, so that user code reads nicer.
```python
# Before
@task.on("on_result")
def f(...): ...

@task.on("on_future_submitted")
def f(...): ...

# After
@task.on("result")
def f(...): ...

@task.on("future-submitted")
def f(...): ...

# Note comm message using hyphens
@task.on("comm-message")
def f(...): ...
```
This is backwards breaking, but given that we advertise `task.on_result` as a shorthand that should work the same, I'm okay with making the breaking change sooner rather than later. I think we should still advertise the `on_<event>` form where possible, as it provides a level of type safety we just can't get with strings.
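One way to picture the type-safety argument: an `on_<event>` attribute can carry the handler's argument type, while `on(str)` accepts any string, so a stale name like `"on_result"` just registers a handler nothing will ever call. `Task` and `Subscriber` here are hypothetical and far simpler than amltk's classes.

```python
from typing import Callable, Generic, TypeVar

T = TypeVar("T")


class Subscriber(Generic[T]):
    """Hypothetical typed subscriber: `task.on_result(fn)` knows fn receives a T."""

    def __init__(self, registry: dict[str, list[Callable]], event: str) -> None:
        self.registry = registry
        self.event = event

    def __call__(self, fn: Callable[[T], None]) -> Callable[[T], None]:
        self.registry.setdefault(self.event, []).append(fn)
        return fn


class Task:
    def __init__(self) -> None:
        self.registry: dict[str, list[Callable]] = {}
        # Typed shorthand: registers under the same key as on("result")
        self.on_result: Subscriber[float] = Subscriber(self.registry, "result")

    def on(self, event: str) -> Subscriber:
        # Untyped: any string is accepted, so typos cannot be caught statically
        return Subscriber(self.registry, event)


task = Task()


@task.on_result
def a(value: float) -> None: ...


@task.on("result")
def b(value: float) -> None: ...


print(task.registry["result"] == [a, b])  # True
```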
Requiring plugins to implement `copy()` is hard to satisfy for stateful plugins. While nice in principle, this was all done to support `task.copy()`, and in all my use of the library so far I've never once needed that feature. Removed for now.
`Metric.compare(v1, v2)` tells you whether a value is `Metric.Comparison.EQUAL`/`WORSE`/`BETTER`, so we don't have to constantly check whether a metric is minimized or not. This also allows `match` statements, which make behaviour routing more explicit, and lets us recommend this as public API for a `Metric`, hiding the details behind the curtain.