Hey @bblodfon,
Graph learners are currently excluded for two reasons. We are working on issue 2 (@mb706). If your pipeline has no preprocessing steps, e.g. only `po("subsample")`, or is a branching pipeline, you can write a custom callback. For this, we only need to find the model in the graph (issue 1).
```r
library(mlr3verse)

learner = lrn("classif.xgboost")
graph_learner = as_learner(po("subsample", frac = 0.5) %>>% learner)
graph_learner$train(tsk("pima"))

# search for the model in the state
graph_learner$state

# path to the XGBoost model
graph_learner$state$model$classif.xgboost$model
```
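If you do not want to hard-code that path, a small helper could scan the state for the pipeop that holds the booster. This is just a sketch of mine, not part of mlr3, and `find_xgboost_id` is a hypothetical name:

```r
# hypothetical helper: return the id of the pipeop whose state contains an
# xgb.Booster, so the path to the model need not be hard-coded
find_xgboost_id = function(graph_learner) {
  states = graph_learner$state$model
  names(states)[vapply(states, function(s) inherits(s$model, "xgb.Booster"), logical(1))]
}

find_xgboost_id(graph_learner)
#> [1] "classif.xgboost"
```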
Next, we write the custom callback, using the default early stopping callback as a template; the changes are marked with comments.
```r
callback = callback_tuning("mlr3tuning.early_stopping_graph_learner",
  on_optimization_begin = function(callback, context) {
    learner = context$instance$objective$learner
    # the checks of the default callback are removed here

    callback$state$store_models = context$instance$objective$store_models
    context$instance$objective$store_models = TRUE
  },

  on_eval_after_benchmark = function(callback, context) {
    callback$state$max_nrounds = mlr3misc::map_dbl(context$benchmark_result$resample_results$resample_result, function(rr) {
      max(mlr3misc::map_dbl(mlr3misc::get_private(rr)$.data$learner_states(mlr3misc::get_private(rr)$.view), function(state) {
        # use the path to the XGBoost model in the graph learner state
        state$model$classif.xgboost$model$best_iteration
      }))
    })
  },

  on_eval_before_archive = function(callback, context) {
    data.table::set(context$aggregated_performance, j = "max_nrounds", value = callback$state$max_nrounds)
    if (!callback$state$store_models) context$benchmark_result$discard(models = TRUE)
  },

  on_result = function(callback, context) {
    # parameters are prefixed with the learner id
    context$result$learner_param_vals[[1]]$classif.xgboost.early_stopping_rounds = NULL
    context$result$learner_param_vals[[1]]$classif.xgboost.nrounds = context$instance$archive$best()$max_nrounds
    context$result$learner_param_vals[[1]]$classif.xgboost.early_stopping_set = "none"
    context$instance$objective$store_models = callback$state$store_models
  }
)
```
Then we tune with the custom callback:
```r
learner = lrn("classif.xgboost",
  eta = to_tune(1e-02, 1e-1, logscale = TRUE),
  early_stopping_rounds = 5,
  nrounds = 100,
  early_stopping_set = "test"
)

graph_learner = as_learner(po("subsample", frac = 0.5) %>>% learner)

instance = tune(
  method = tnr("random_search"),
  task = tsk("pima"),
  learner = graph_learner,
  resampling = rsmp("cv", folds = 3),
  measures = msr("classif.ce"),
  term_evals = 10,
  callbacks = callback
)
```
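After tuning, you can inspect the archive and result. A quick sketch, assuming the column and parameter names produced by the callback above:

```r
# the archive carries the max_nrounds column written in on_eval_before_archive
instance$archive$data[, .(classif.xgboost.eta, classif.ce, max_nrounds)]

# on_result rewrote nrounds to the best early stopped iteration
instance$result_learner_param_vals$classif.xgboost.nrounds
```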
Nice, thanks so much @be-marc, it works!
Btw, is there a more automatic way to get the learner id? I have constructed two possible xgboost graph learners with different ids, `XGBoostCox` and `XGBoostAFT`, so for now I simply check in the above code whether `state$model$XGBoostCox` is `NULL` or not, to decide which id to use in the state path.
Okay, the answer was already there: I can get `learner = context$instance$objective$learner` and therefore always take `learner$id` (of the graph learner), which in my case happens to be the same as the prefix of the parameters (that's how I constructed it in the first place). But of course you wouldn't expect this to work in general!
So, `context$instance$objective$learner$id` is only accessible (not `NULL`) during `on_optimization_begin` and `on_result`. In `on_eval_after_benchmark`, it seems I can't get it that way.
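One workaround I sketched (assuming, as in my setup, that the graph learner id matches the inner pipeop id) is to grab the id once in `on_optimization_begin`, where the objective is reachable, and stash it in `callback$state` for the later stages:

```r
callback = callback_tuning("mlr3tuning.early_stopping_stored_id",
  on_optimization_begin = function(callback, context) {
    # objective$learner is reachable here, so remember its id for later stages
    callback$state$learner_id = context$instance$objective$learner$id
  },

  on_eval_after_benchmark = function(callback, context) {
    id = callback$state$learner_id
    # ... then index the state generically, e.g.
    # state$model[[id]]$model$best_iteration
  }
)
```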
I don't know if I understand your question correctly yet. When we tune a graph learner, the id is always the id of the graph learner. The `branch.selection` parameter holds the id of the learner / the active path. You could access this parameter to get the right prefix in `on_result`. However, you only need this step when you use an `AutoTuner` or do nested resampling; if not, you can remove `on_result`. Ids can be changed like this:
```r
learners = list(
  # change the id of the learner and thereby the prefix in the graph learner's parameter set
  lrn("classif.xgboost", id = "xgboost_1"),
  lrn("classif.xgboost", id = "xgboost_2")
)

graph = ppl("branch", lapply(learners, po))
learner = as_learner(graph)

# change the id of the graph learner
learner$id = "branching"
```
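If you do keep `on_result` for a branching pipeline, here is a rough sketch of how it could read the active path from `branch.selection` (assuming `branch.selection` is present in the result's `learner_param_vals` and the ids from the snippet above):

```r
on_result = function(callback, context) {
  pv = context$result$learner_param_vals[[1]]
  # id of the active path, e.g. "xgboost_1"
  active = pv$branch.selection

  # drop/rewrite the prefixed early stopping parameters of the active learner
  pv[[paste0(active, ".early_stopping_rounds")]] = NULL
  pv[[paste0(active, ".nrounds")]] = context$instance$archive$best()$max_nrounds
  context$result$learner_param_vals[[1]] = pv

  context$instance$objective$store_models = callback$state$store_models
}
```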
Sorry, I know it is confusing! What is happening here is that the `id` of the graph learner is the same as the `id` of the xgboost learner it encapsulates; see below how I create this graph learner object:
```r
# nthreads is defined elsewhere in my script
xgboost_cox = mlr3pipelines::ppl('distrcompositor',
  learner = lrn('surv.xgboost', nthread = nthreads,
    booster = 'gbtree', fallback = lrn('surv.kaplan'),
    objective = 'survival:cox', id = 'XGBoostCox'),
  estimator = 'kaplan',
  form = 'ph',
  overwrite = TRUE,
  graph_learner = TRUE
)
xgboost_cox$id = 'XGBoostCox'
xgboost_cox$label = 'Extreme Gradient Boosting (Cox)'
```
I have a similar xgboost graph learner with AFT prefixes. So I implemented the callback like this and it now works. But I guess `context$instance$objective$learner$id` is the graph learner's id, which just happens to be the same as the prefix of the parameters of the xgboost learner as created by the code above (also see the `on_result` code). It was just a convention I had adopted to build the parameter spaces with the `id` prefix, and it happened to be useful here as well :)
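For reference, the convention looks like this when building the search space by hand; the ranges here are illustrative assumptions, the point is only the prefix matching the id above:

```r
library(paradox)

search_space = ps(
  XGBoostCox.eta       = p_dbl(1e-02, 1e-01, logscale = TRUE),
  XGBoostCox.max_depth = p_int(2, 8)
)
```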
Hi @be-marc,
It would be really cool if we could have the early stopping callback for XGBoost work on a graph learner. I recently adapted the survival XGBoost learner to work with early stopping, but when I tried to use it as part of a graph learner as below, it wouldn't work:
Created on 2023-01-09 with reprex v2.0.2