tl;dr: this adds a callback hook after each split's evaluation, with access to all the data and the model, plus a callback hook that runs as the last step before reporting back from the worker. Both callbacks are serialized and called inside the worker, so they can access everything there without worrying about serializing data out of the worker.

The tradeoff is that they cannot freely access information from the main process and should be lightweight, since the callbacks themselves get serialized into the worker.
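To make that tradeoff concrete, here's a minimal sketch of the failure mode to avoid (`make_post_split` and the captured dataframe are hypothetical, not part of this PR): anything the callback closes over may get shipped to the worker along with it, depending on the serialization backend.

```python
def make_post_split(big_df):
    # `big_df` is captured by the returned closure, so (depending on the
    # serialization backend) it may be pickled into the worker together
    # with the callback -- defeating the point of keeping callbacks light.
    def callback(trial, split_number, info):
        trial.summary["n_rows"] = len(big_df)
        return info

    return callback
```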
I've also tried to ensure that no data lives longer than it needs to, so split data and models aren't kept around longer than necessary.

I'll flesh out this description once I've tested it a bit more.

@LennartPurucker this refactor is pretty unreviewable from the git diffs, so I hope the tests illustrate what's possible:
```python
from pathlib import Path

from sklearn.tree import DecisionTreeClassifier
from sklearn.utils.validation import check_is_fitted

from amltk.optimization import Metric, Trial
from amltk.pipeline import Component, Node
from amltk.sklearn import CVEvaluation

# `data_for_task_type` is the repo's test-data helper.


# Called after each split inside the worker, just before we release the
# references to the data and model (if we don't need to keep them).
def post_split_callback(
    trial: Trial,
    split_number: int,
    info: CVEvaluation.PostSplitInfo,
) -> CVEvaluation.PostSplitInfo:
    # We should get the test data here, since it's passed in by the test below.
    assert info.X_test is not None
    assert info.y_test is not None
    check_is_fitted(info.model)
    trial.summary[f"post_split_{split_number}"] = split_number
    return info


def test_post_split(tmp_path: Path) -> None:
    pipeline = Component(DecisionTreeClassifier, config={"max_depth": 1})
    x, y = data_for_task_type("binary")

    TEST_SIZE = 2
    x_test, y_test = x[:TEST_SIZE], y[:TEST_SIZE]

    NSPLITS = 3
    evaluator = CVEvaluation(
        x,
        y,
        X_test=x_test,
        y_test=y_test,
        n_splits=NSPLITS,
        working_dir=tmp_path,
        on_error="raise",
        post_split=post_split_callback,  # <- Passed here
    )

    trial = Trial.create("test", bucket=tmp_path / "trial", metrics=Metric("accuracy"))
    report = evaluator.fn(trial, pipeline)

    for i in range(NSPLITS):
        assert f"post_split_{i}" in report.summary
        assert report.summary[f"post_split_{i}"] == i
```
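For a flavour of what this enables, here's a sketch of a `post_split` callback that scores each split's fitted model on the test data and records it per split (only the `PostSplitInfo` fields used in the test above are assumed; `record_test_accuracy` is just an illustrative name):

```python
def record_test_accuracy(
    trial: Trial,
    split_number: int,
    info: CVEvaluation.PostSplitInfo,
) -> CVEvaluation.PostSplitInfo:
    # Score the split's fitted model on the held-out test data, all inside
    # the worker, and stash the result in the trial summary.
    if info.X_test is not None and info.y_test is not None:
        acc = info.model.score(info.X_test, info.y_test)
        trial.summary[f"split_{split_number}:test_accuracy"] = float(acc)
    return info
```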
```python
# Called once the whole evaluation has finished, as the last step inside the worker.
def post_processing_callback(
    report: Trial.Report,
    pipeline: Node,  # noqa: ARG001
    eval_info: CVEvaluation.CompleteEvalInfo,
) -> Trial.Report:
    # We should have no models in our post-processing since we didn't ask
    # for them with `post_processing_requires_models`. This is to prevent
    # holding onto models in memory unless it's explicitly needed.
    assert eval_info.models is None

    # However, we specify `store_models=True` below, so the models should
    # be in storage.
    for i in range(eval_info.max_splits):
        assert f"model_{i}.pkl" in report.storage

    trial = report.trial

    # Delete the models from storage.
    trial.delete_from_storage(
        [f"model_{i}.pkl" for i in range(eval_info.max_splits)],
    )

    # Return some bogus number as the metric value.
    return trial.success(accuracy=0.123)


def test_post_processing_no_models(tmp_path: Path) -> None:
    pipeline = Component(DecisionTreeClassifier, config={"max_depth": 1})
    x, y = data_for_task_type("binary")
    evaluator = CVEvaluation(
        x,
        y,
        working_dir=tmp_path,
        on_error="raise",
        post_processing=post_processing_callback,
        store_models=True,
    )

    trial = Trial.create("test", bucket=tmp_path / "trial", metrics=Metric("accuracy"))
    report = evaluator.fn(trial, pipeline)

    # The post-processing callback overrode the metric value and deleted
    # the stored models.
    assert report.values["accuracy"] == 0.123
    assert len(report.storage) == 0
```
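And the inverse case: if post-processing does need the fitted models, passing `post_processing_requires_models=True` (the flag named in the comment above) should keep them in memory. A sketch, where `eval_info.models` being a sequence of fitted per-split models is my assumption:

```python
def count_models(
    report: Trial.Report,
    pipeline: Node,  # noqa: ARG001
    eval_info: CVEvaluation.CompleteEvalInfo,
) -> Trial.Report:
    # With `post_processing_requires_models=True`, the models are handed to
    # us here (assumption: as a sequence of fitted estimators) instead of
    # being released after each split.
    assert eval_info.models is not None
    report.trial.summary["n_models"] = len(eval_info.models)
    return report
```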