automl / amltk

A build-it-yourself AutoML Framework
https://automl.github.io/amltk/
BSD 3-Clause "New" or "Revised" License
68 stars 6 forks source link

feat(CVEvaluator): Add feature for post_split and post_processing #260

Closed eddiebergman closed 9 months ago

eddiebergman commented 9 months ago

TL;DR: this adds a callback hook that runs after each split evaluation, with access to all of that split's data and the model, and a second callback hook that runs as the last step before reporting back from the worker. Both callbacks are serialized and called inside the worker, so they can access much more data without worrying about serializing it back out of the worker.

The tradeoff is that these callbacks cannot freely access information from the main process. They should also be kept lightweight, since they are serialized into the worker.

I tried to ensure that no data lives longer than it needs to, so that per-split data and models are released as soon as they are no longer needed.

I will flesh out this description once I've tested it a bit more.


@LennartPurucker this refactor is hard to review from the git diffs alone, so I hope the tests below illustrate what's possible:


def post_split_callback(
    trial: Trial,
    split_number: int,
    info: CVEvaluation.PostSplitInfo,
) -> CVEvaluation.PostSplitInfo:
    """Record which split just finished in the trial summary.

    Runs inside the worker after each split, just before the references to
    that split's data and model are released (unless they must be kept).

    Args:
        trial: The trial being evaluated; its ``summary`` is written to.
        split_number: Index of the split that just finished.
        info: The per-split information, including test data and model.

    Returns:
        ``info``, passed through unchanged.
    """
    # The evaluator in `test_post_split` is given X_test / y_test, so both
    # must have been forwarded to this hook.
    for held_out in (info.X_test, info.y_test):
        assert held_out is not None
    check_is_fitted(info.model)

    summary_key = f"post_split_{split_number}"
    trial.summary[summary_key] = split_number
    return info

def test_post_split(tmp_path: Path) -> None:
    """``post_split`` runs once per split and can write into the trial summary."""
    pipeline = Component(DecisionTreeClassifier, config={"max_depth": 1})
    x, y = data_for_task_type("binary")

    n_test = 2
    n_splits = 3

    evaluator = CVEvaluation(
        x,
        y,
        X_test=x[:n_test],
        y_test=y[:n_test],
        n_splits=n_splits,
        working_dir=tmp_path,
        on_error="raise",
        post_split=post_split_callback,  # hook under test
    )
    trial = Trial.create("test", bucket=tmp_path / "trial", metrics=Metric("accuracy"))
    report = evaluator.fn(trial, pipeline)

    # The hook stores each split's index under its own summary key.
    expected = {f"post_split_{split}": split for split in range(n_splits)}
    for key, split in expected.items():
        assert key in report.summary
        assert report.summary[key] == split

def post_processing_callback(
    report: Trial.Report,
    pipeline: Node,  # noqa: ARG001
    eval_info: CVEvaluation.CompleteEvalInfo,
) -> Trial.Report:
    """Delete the stored per-split models and replace the reported result.

    Runs once, as the last step inside the worker, after the whole
    evaluation has finished.

    Args:
        report: The report produced by the evaluation.
        pipeline: The evaluated pipeline (unused here).
        eval_info: Information about the completed evaluation.

    Returns:
        A new successful report whose ``accuracy`` is an arbitrary value.
    """
    # Models are only handed to post-processing when
    # `post_processing_requires_models` is requested, so that they are not
    # held in memory unless explicitly needed.
    assert eval_info.models is None

    model_keys = [f"model_{split}.pkl" for split in range(eval_info.max_splits)]

    # The test asks for `store_models=True`, so every split's model should
    # nevertheless be present in storage.
    for key in model_keys:
        assert key in report.storage

    trial = report.trial
    trial.delete_from_storage(model_keys)

    # Report an arbitrary (bogus) metric value so the test can detect it.
    return trial.success(accuracy=0.123)

def test_post_processing_no_models(tmp_path: Path) -> None:
    """``post_processing`` can clean up stored models and override the report."""
    pipeline = Component(DecisionTreeClassifier, config={"max_depth": 1})
    x, y = data_for_task_type("binary")

    evaluator = CVEvaluation(
        x,
        y,
        working_dir=tmp_path,
        on_error="raise",
        post_processing=post_processing_callback,  # hook under test
        store_models=True,
    )
    trial = Trial.create("test", bucket=tmp_path / "trial", metrics=Metric("accuracy"))
    report = evaluator.fn(trial, pipeline)

    # The callback replaced the metric value and deleted every stored model.
    reported_accuracy = report.values["accuracy"]
    assert reported_accuracy == 0.123
    stored = report.storage
    assert len(stored) == 0