alteryx / evalml

EvalML is an AutoML library written in Python.
https://evalml.alteryx.com
BSD 3-Clause "New" or "Revised" License

Model debugging: Add ability to compute and store graphs/stats on each CV fold during automl search #1111

Open dsherry opened 3 years ago

dsherry commented 3 years ago

Goal: If users want to compute graphs and stats on the models trained during each CV fold, we should design an API which allows them to do so.

I think we should make this general enough to support running any function on each CV fold and storing the results on the automl object. The results would be most naturally stored in the automl results dict, but we should build a public API to access the results.

Proposal: Add a cv_callbacks list as an init arg to AutoMLSearch, defaulting to None/empty. If provided, each callback will be run during the evaluation of each CV fold. The callback will be passed the training split, the validation split, and the fitted model. The results from the callback will end up in each CV split's dict in results, as a list stored under the name cv_callback_results.
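For concreteness, here is a minimal sketch of what the proposed API could look like. This is hypothetical: `cv_callbacks`, the callback signature, and the `cv_callback_results` key come from the proposal above and do not exist in EvalML; the (X, y)-tuple format for each split and the exact shape of the results dict are assumptions.

```python
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.metrics import roc_auc_score

from evalml.automl import AutoMLSearch

# Toy binary classification data standing in for a real dataset.
X, y = make_classification(n_samples=200, random_state=0)
X, y = pd.DataFrame(X), pd.Series(y)

# Hypothetical callback. Per the proposal it receives the training split,
# the validation split, and the fitted model; the (X, y) tuple format for
# each split is an assumption, not a settled design.
def fold_auc(train_split, valid_split, fitted_pipeline):
    X_valid, y_valid = valid_split
    proba = fitted_pipeline.predict_proba(X_valid)
    return {"valid_auc": roc_auc_score(y_valid, proba.iloc[:, 1])}

automl = AutoMLSearch(
    X_train=X,
    y_train=y,
    problem_type="binary",
    cv_callbacks=[fold_auc],  # proposed init arg, default None
)
automl.search()

# Per the proposal, each CV split's dict in the results would carry a
# cv_callback_results list with one entry per callback.
for fold in automl.results["pipeline_results"][0]["cv_data"]:
    print(fold["cv_callback_results"])
```

Having callbacks return a plain dict keeps the mechanism agnostic to what is being computed, in line with the design goal described in the background note below.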

Next Steps: Before we build this, we should ask ourselves why this feature is necessary. Given our current data splitting plan, we want the ability to generate graphs and stats on the 1st CV fold, to correspond with the "validation" score which will be computed on the same data (see #1049). However, if we only need to compute graphs/stats on the first validation fold, or if we decide to create a separate split for model selection, the requirements here will change, so we should settle that splitting decision before building this.

Background: We used to have this ability, but we removed it because we were mixing the Objective abstraction with the graph/stats methods. That's why I advocate for a design which is agnostic to the type of data being computed.

dsherry commented 3 years ago

Status: we're still waiting on a conversation with @gsheni on this :)

dsherry commented 3 years ago

Status: we had the conversation. The use case for computing graphs and stats on training/CV data is debugging models: if a model produces a completely different predicted-vs-actual plot, ROC curve, or other diagnostic across the various CV folds, this could indicate high variance during training, which could mean the model is unstable on this dataset (or that the dataset is of poor quality).

I like this idea. However, I think we should punt on it until we have a good example of the use case, or a request for this capability from users.
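To make the debugging use case above concrete, here is a small self-contained sketch using plain scikit-learn (not EvalML's API): compute a per-fold score and look at the spread across folds, which is exactly what a cv_callback would let users do inside automl search.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import StratifiedKFold

# Synthetic binary classification data standing in for a real dataset.
X, y = make_classification(n_samples=500, n_features=20, random_state=0)

aucs = []
for train_idx, valid_idx in StratifiedKFold(n_splits=3).split(X, y):
    model = RandomForestClassifier(random_state=0)
    model.fit(X[train_idx], y[train_idx])
    proba = model.predict_proba(X[valid_idx])[:, 1]
    aucs.append(roc_auc_score(y[valid_idx], proba))

# A wide spread in per-fold scores is the instability signal described
# above: the model behaves very differently depending on which data it sees.
print(f"per-fold AUC: {np.round(aucs, 3)}  std: {np.std(aucs):.3f}")
```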