mlr-org / mlr3tuning

Hyperparameter optimization package of the mlr3 ecosystem
https://mlr3tuning.mlr-org.com/
GNU Lesser General Public License v3.0

Early stopping with XGBoost graph learner [feature request] #376

Closed bblodfon closed 1 year ago

bblodfon commented 1 year ago

Hi @be-marc,

It would be really cool if the early stopping callback for XGBoost could work on a graph learner. I recently adapted the survival XGBoost learner to work with early stopping, but when I tried to use it as part of a graph learner, as shown below, it did not work:

library(mlr3proba)
#> Loading required package: mlr3
library(mlr3extralearners)
library(mlr3tuning)
#> Loading required package: paradox
library(survmob)

s = SurvLPS$new(ids = 'xgboost_cox_early')
dt = s$lrn_tbl()
grlrn = dt$learner[[1L]] # XGBoost Cox with distr prediction graph learner
grlrn
#> <GraphLearner:XGBoostCox>: Extreme Gradient Boosting (Cox)
#> * Model: -
#> * Parameters: XGBoostCox.booster=gbtree,
#>   XGBoostCox.early_stopping_set=test, XGBoostCox.nrounds=1,
#>   XGBoostCox.nthread=16, XGBoostCox.objective=survival:cox,
#>   XGBoostCox.verbose=0, compose_distr.form=ph,
#>   compose_distr.overwrite=TRUE
#> * Packages: mlr3, mlr3pipelines, mlr3proba, survival, distr6,
#>   mlr3extralearners, xgboost
#> * Predict Types:  [crank], distr, lp, response
#> * Feature Types: logical, integer, numeric, character, factor, ordered,
#>   POSIXct
#> * Properties: featureless, hotstart_backward, hotstart_forward,
#>   importance, loglik, missings, oob_error, selected_features, weights
param_set = dt$param_set[[1L]]
param_set
#> <ParamSet>
#>                             id    class     lower       upper nlevels
#> 1:          XGBoostCox.nrounds ParamInt 100.00000 1500.000000    1401
#> 2:              XGBoostCox.eta ParamDbl  -9.21034   -1.203973     Inf
#> 3:        XGBoostCox.max_depth ParamInt   2.00000    8.000000       7
#> 4: XGBoostCox.min_child_weight ParamDbl   0.00000    4.852030     Inf
#>           default value
#> 1: <NoDefault[3]>      
#> 2: <NoDefault[3]>      
#> 3: <NoDefault[3]>      
#> 4: <NoDefault[3]>      
#> Trafo is set.

task = tsk('lung')

xgb_at = AutoTuner$new(
  learner = grlrn,
  resampling = rsmp('holdout'),
  measure = msr('surv.brier'),
  search_space = param_set,
  terminator = trm('evals', n_evals = 5),
  tuner = tnr('random_search'),
  callbacks = clbk("mlr3tuning.early_stopping")
)
xgb_at$train(task)
#> Error: <GraphLearner:XGBoostCox> is incompatible with <CallbackTuning:mlr3tuning.early_stopping>

Created on 2023-01-09 with reprex v2.0.2

be-marc commented 1 year ago

Hey @bblodfon,

Graph Learners are currently excluded for two reasons.

  1. The XGBoost model must be found in the graph.
  2. Preprocessing steps are currently not applied to the test set / early stopping set when training the XGBoost model.

We are working on issue 2 (@mb706). If your pipeline has no preprocessing steps, e.g. only po("subsample"), or is a branching pipeline, you can write a custom callback. For this, we only need to find the model in the graph (issue 1).

learner = lrn("classif.xgboost")
graph_learner = po("subsample", frac = 0.5) %>>% learner
graph_learner$train(tsk("pima"))

# search for model in state
graph_learner$state

# Path to XGBoost model
graph_learner$state$classif.xgboost$model

Next, we write the custom callback. We use the default early stopping callback as a template. I commented the changes.

callback = callback_tuning("mlr3tuning.early_stopping_graph_learner",
  on_optimization_begin = function(callback, context) {
    learner = context$instance$objective$learner

    # Remove checks

    callback$state$store_models = context$instance$objective$store_models
    context$instance$objective$store_models = TRUE
  },

  on_eval_after_benchmark = function(callback, context) {
    callback$state$max_nrounds = mlr3misc::map_dbl(context$benchmark_result$resample_results$resample_result, function(rr) {
        max(mlr3misc::map_dbl(mlr3misc::get_private(rr)$.data$learner_states(mlr3misc::get_private(rr)$.view), function(state) {
          # Use path to XGBoost model in graph learner
          state$model$classif.xgboost$model$best_iteration 
        }))
    })
  },

  on_eval_before_archive = function(callback, context) {
    data.table::set(context$aggregated_performance, j = "max_nrounds", value = callback$state$max_nrounds)
    if (!callback$state$store_models) context$benchmark_result$discard(models = TRUE)
  },

  on_result = function(callback, context) {
    # Prefix parameters with learner id
    context$result$learner_param_vals[[1]]$classif.xgboost.early_stopping_rounds = NULL 
    context$result$learner_param_vals[[1]]$classif.xgboost.nrounds = context$instance$archive$best()$max_nrounds
    context$result$learner_param_vals[[1]]$classif.xgboost.early_stopping_set = "none"
    context$instance$objective$store_models = callback$state$store_models
  }
)

Tune with custom callback.

learner = lrn("classif.xgboost",
  eta = to_tune(1e-02, 1e-1, logscale = TRUE),
  early_stopping_rounds = 5,
  nrounds = 100,
  early_stopping_set = "test")

graph_learner = po("subsample", frac = 0.5) %>>% learner

instance = tune(
  method = tnr("random_search"),
  task = tsk("pima"),
  learner = graph_learner,
  resampling = rsmp("cv", folds = 3),
  measures = msr("classif.ce"),
  term_evals = 10,
  callbacks = callback
)

bblodfon commented 1 year ago

Nice, thanks so much @be-marc, it works!

Btw, is there a more automatic way to get the learner id? I have constructed two possible XGBoost graph learners with different ids, XGBoostCox and XGBoostAFT, so in the above code I am now simply checking whether state$model$XGBoostCox is NULL to decide which one to use in the path to the model.

bblodfon commented 1 year ago

Okay, the answer was already there: I can get learner = context$instance$objective$learner and then take learner$id (the id of the graph learner), which in my case happens to be the same as the prefix of the parameters (that is how I constructed it in the first place). But of course you wouldn't expect this to work in general!

bblodfon commented 1 year ago

So, context$instance$objective$learner$id is only accessible (not NULL) during on_optimization_begin and on_result. In on_eval_after_benchmark, it seems I can't get it that way.
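One possible workaround for on_eval_after_benchmark, sketched here under assumptions: instead of relying on the learner id, look up the XGBoost entry in the graph learner state itself. The helper name find_xgb_id and the mock state are hypothetical. The state$model$<id>$model layout follows the path shown earlier in this thread, and the check assumes the fitted model carries the class "xgb.Booster".

```r
# Hypothetical helper: return the ids of all graph entries whose fitted
# model inherits from "xgb.Booster" (assumed class of the XGBoost model).
find_xgb_id = function(state) {
  ids = as.character(names(state$model))
  is_xgb = vapply(ids, function(id) {
    entry = state$model[[id]]
    # [["model"]] avoids the partial matching that `$` does on lists
    is.list(entry) && inherits(entry[["model"]], "xgb.Booster")
  }, logical(1))
  ids[is_xgb]
}

# Mock state that mimics the structure state$model$<id>$model
mock_state = list(model = list(
  subsample = list(),
  XGBoostAFT = list(
    model = structure(list(best_iteration = 42L), class = "xgb.Booster")
  )
))

xgb_id = find_xgb_id(mock_state)
xgb_id
```

Inside the callback, the returned id could replace the hard-coded classif.xgboost in the path to best_iteration, e.g. state$model[[xgb_id]]$model$best_iteration. This has only been tried on the mock state above, not on a real tuning run.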

be-marc commented 1 year ago

I am not sure I understand your question correctly yet. When we tune a graph learner, the id is always the id of the graph learner. The branch.selection parameter shows the id of the learner / active path. You could access this parameter to get the right prefix in on_result. However, you only need this step when you use an AutoTuner or do nested resampling. If not, you can remove on_result. Ids can be changed like this.


learners = list(
  # Change id of learner and prefix in graph learner parameter set
  kknn = lrn("classif.xgboost", id = "xgboost_1"), 
  svm = lrn("classif.xgboost", id = "xgboost_2")
)

graph = ppl("branch", lapply(learners, po))
learner = as_learner(graph)

# Change id of graph learner
learner$id = "branching"
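To illustrate the prefix idea for on_result, here is a minimal sketch on a plain named list standing in for context$result$learner_param_vals[[1]]. The helper finalize_param_vals and the mock values are hypothetical. It also assumes that the value of branch.selection equals the learner id used as parameter prefix, which only holds if the branch options are named after the learner ids.

```r
# Hypothetical helper: apply the on_result updates from the custom callback
# for an arbitrary parameter prefix instead of a hard-coded learner id.
finalize_param_vals = function(param_vals, prefix, max_nrounds) {
  # Assigning NULL removes the entry (no-op if it is absent)
  param_vals[[paste0(prefix, ".early_stopping_rounds")]] = NULL
  param_vals[[paste0(prefix, ".nrounds")]] = max_nrounds
  param_vals[[paste0(prefix, ".early_stopping_set")]] = "none"
  param_vals
}

# Mock of the tuned parameter values (assumed structure)
param_vals = list(
  branch.selection = "xgboost_1",
  xgboost_1.eta = 0.05,
  xgboost_1.early_stopping_rounds = 5,
  xgboost_1.nrounds = 100,
  xgboost_1.early_stopping_set = "test"
)

final = finalize_param_vals(
  param_vals,
  prefix = param_vals[["branch.selection"]],
  max_nrounds = 37
)
str(final)
```

In the callback, max_nrounds would come from context$instance$archive$best()$max_nrounds, as in the on_result stage shown earlier.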

bblodfon commented 1 year ago

Sorry, I know it is confusing! So what is happening here is that the id of the graph learner is the same as the id of the xgboost learner it encapsulates, i.e. see below how I create this graph learner object:

xgboost_cox = mlr3pipelines::ppl('distrcompositor',
              learner = lrn('surv.xgboost', nthread = nthreads,
                booster = 'gbtree', fallback = lrn('surv.kaplan'),
                objective = 'survival:cox', id = 'XGBoostCox'),
              estimator = 'kaplan',
              form = 'ph',
              overwrite = TRUE,
              graph_learner = TRUE
            )
xgboost_cox$id = 'XGBoostCox'
xgboost_cox$label = 'Extreme Gradient Boosting (Cox)'

I have a similar XGBoost graph learner with AFT prefixes. So I implemented the callback like this, and it now works. I guess context$instance$objective$learner$id is the graph learner's id, which just happens to be the same as the parameter prefix of the XGBoost learner as created in the code above (see also the on_result code). Using the id as a prefix was a convenience I had adopted when building the parameter spaces, and it turned out to be useful here as well :)