tidymodels / bonsai

parsnip wrappers for tree-based models
https://bonsai.tidymodels.org

Feature idea - provide custom validation sets for early stopping #48

Open dfsnow opened 2 years ago

dfsnow commented 2 years ago

Thanks for creating this excellent package. I created a similar fork of treesnip but am planning to replace it with {bonsai} in all our production models.

One feature that I think would be incredibly useful in {bonsai} is the ability to provide custom validation sets for early stopping (instead of using a random split of the training data). This would have a few potential benefits:

  1. More training data. In many cases, you will already have a validation set held out from a classic train/validate/test split. Currently, {bonsai} further splits the training data into a training subset and a validation subset used only for early stopping. It would be better to pass the existing validation set directly, so that all of the training data is used for fitting.
  2. Ability to do more complex cross-validation. Certain cross-validation techniques (rolling origin, spatial, etc.) don't rely on a random sample of the training data and instead partition it by time or geography. Allowing custom validation data would let users apply the "correct" validation set for early stopping when using these methods.
  3. Better integration with tidymodels. Tidymodels supports k-fold and other types of cross-validation. Using the held-out set that is already created for each fold, rather than splitting off a separate validation set just for early stopping, would be much simpler.
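As a rough illustration of benefit 1, the row bookkeeping could look like the base R sketch below. `early_stopping_split()` is a hypothetical helper, not part of {bonsai}: when a validation set is supplied, every training row is kept for fitting; otherwise a random share of the training rows is held out, loosely mimicking the current behavior.

```r
# Hypothetical helper (not part of {bonsai}): decide which rows are used for
# fitting and which rows are used to monitor early stopping.
early_stopping_split <- function(train, valid = NULL, prop = 0.1) {
  if (!is.null(valid)) {
    # Proposed behavior: all of `train` is used for fitting and the
    # user-supplied set drives early stopping.
    return(list(fit = train, monitor = valid))
  }
  # Current-style behavior: hold out a random share of the training rows.
  n <- nrow(train)
  held_out <- sample(n, size = max(1, floor(n * prop)))
  list(
    fit = train[-held_out, , drop = FALSE],
    monitor = train[held_out, , drop = FALSE]
  )
}

set.seed(1)
train <- mtcars[1:24, ]
valid <- mtcars[25:32, ]

random_split <- early_stopping_split(train, prop = 0.25)
custom_split <- early_stopping_split(train, valid = valid)

nrow(random_split$fit)  # 18: six training rows are lost to early stopping
nrow(custom_split$fit)  # 24: every training row is used for fitting
```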

Let me know if this is out-of-scope for this project. If not, I'm happy to contribute if needed.

simonpcouch commented 2 years ago

Thanks for the issue! I'm on board. :)

Related to tidymodels/parsnip#760 and tidymodels/parsnip#765.

My response to the analogous parsnip issues reflects my current thinking about this in bonsai as well.

This is an interesting idea and one that we ought to consider. xgboost and lightgbm's interfaces for validation sets allow for a lot of user control, but we'd need to think carefully about what a tidymodels-esque interface might feel like here.

This won't be at the top of our to-do list for now, but I'll leave this open as a possible future extension. :)

dfsnow commented 2 years ago

Great! Thanks for the quick response. It looks like there's already a PR in {parsnip} for exactly this: https://github.com/tidymodels/parsnip/pull/771. I'll wait for that to be merged and am then happy to help with any further work needed to integrate it into {bonsai}.

jameslamb commented 2 years ago

Whenever you or others here pick this up, @simonpcouch, @ me if you need any help with how to do this in {lightgbm}.

There is a LightGBM-y way to create validation sets that is slightly different from "just subset rows". See https://lightgbm.readthedocs.io/en/latest/R/reference/lgb.Dataset.create.valid.html.
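For reference, a minimal sketch of that approach used directly in {lightgbm}, outside of {bonsai}, assuming the lightgbm R package v4.x. The simulated data and the choice of the last 100 rows as the validation set are made up for illustration; `lgb.Dataset.create.valid()` builds the validation Dataset against the training Dataset so that both share the same feature binning.

```r
# Requires the {lightgbm} R package (v4.x assumed).
library(lightgbm)

# Simulated regression data (illustrative only).
set.seed(42)
n <- 500
x <- matrix(rnorm(n * 3), ncol = 3)
y <- 2 * x[, 1] + x[, 2] + rnorm(n, sd = 0.5)

trn <- 1:400    # rows used for fitting
val <- 401:500  # user-chosen validation rows (e.g. the most recent period)

dtrain <- lgb.Dataset(x[trn, ], label = y[trn])
# Validation Dataset that reuses dtrain's bin mappers.
dvalid <- lgb.Dataset.create.valid(dtrain, x[val, ], label = y[val])

fit <- lgb.train(
  params = list(objective = "regression", metric = "l2", verbose = -1L),
  data = dtrain,
  nrounds = 500L,
  valids = list(validation = dvalid),
  early_stopping_rounds = 10L  # stop once l2 on dvalid stalls for 10 rounds
)

fit$best_iter  # iteration with the best validation l2
```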

diegoperoni commented 3 months ago

Hi, I wrote a simple fix that adds an alternative way to specify a custom validation set through the "validation" param. With this code and bonsai v0.3.0, users can provide either of the following:

Example:

validation = 0.3 # default: hold out a random 30% of the training rows (current behavior)

validation = c(0.7, 0.9) # alternative: hold out the contiguous block of rows from 70% to 90% of the training set, in row order
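To make the two forms concrete, here is a base R sketch of which row indices each one selects. `validation_indices()` is a hypothetical helper written only for this illustration; it mirrors the index logic of the patch rather than calling anything in {bonsai}.

```r
# Hypothetical helper mirroring the patch's index logic (not a bonsai function).
validation_indices <- function(n, validation) {
  if (length(validation) == 2) {
    # Two-element form: contiguous block of rows between the two proportions.
    lower <- floor(n * validation[1]) + 1
    upper <- floor(n * validation[2])
    val <- seq(lower, upper)
    trn <- setdiff(seq_len(n), val)
  } else {
    # Scalar form (existing behavior): random hold-out.
    m <- min(floor(n * (1 - validation)) + 1, n - 1)
    trn <- sample(seq_len(n), size = max(m, 2))
    val <- setdiff(seq_len(n), trn)
  }
  list(trn = trn, val = val)
}

block <- validation_indices(100, c(0.7, 0.9))
range(block$val)   # 71 90
length(block$trn)  # 80: rows 1-70 plus rows 91-100

set.seed(1)
rnd <- validation_indices(100, 0.3)  # randomly chosen, non-overlapping rows
```

Note that with the range form, rows after the upper bound (here 91-100) still end up in the training set.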

Here is the code that replaces the internal function; run it after bonsai 0.3.0 has been loaded.

Hope it is useful

Regards

# Run after library(bonsai): this overwrites bonsai's internal process_data().
utils::assignInNamespace(
  x  = "process_data",
  ns = "bonsai",
  value = function(args, x, y, weights, validation, missing_validation) {

    #                                           trn_index       | val_index
    #                                         ----------------------------------
    #  needs_validation &  missing_validation | 1:n               1:n
    #  needs_validation & !missing_validation | sample(1:n, m)    setdiff(1:n, trn_index)
    # !needs_validation &  missing_validation | 1:n               NULL
    # !needs_validation & !missing_validation | sample(1:n, m)    setdiff(1:n, trn_index)
    #
    # When `validation` has length 2, val_index is the contiguous block l:h
    # and trn_index is setdiff(1:n, val_index).

    n <- nrow(x)
    needs_validation <- !is.null(args$params$early_stopping_round)
    if (!needs_validation) {
      # If early_stopping_round isn't set, clear it from arguments actually
      # passed to LightGBM.
      args$params$early_stopping_round <- NULL
    }

    if (missing_validation) {
      trn_index <- 1:n
      if (needs_validation) {
        val_index <- trn_index
      } else {
        val_index <- NULL
      }
    } else {
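      if (length(validation) == 2) {
        # Two-element form c(lower, upper): hold out the contiguous block of
        # rows between these proportions of the training set, in row order.
        stopifnot(
          is.numeric(validation),
          validation[1] >= 0, validation[2] <= 1,
          validation[1] < validation[2]
        )
        l <- floor(n * validation[1]) + 1
        h <- floor(n * validation[2])
        if (l > h || (h - l + 1) >= n) {
          stop("`validation` must select some, but not all, rows.", call. = FALSE)
        }
        val_index <- l:h
        trn_index <- setdiff(1:n, val_index)
      } else {
        # Scalar form (existing behavior): hold out a random proportion
        # `validation` of the rows.
        m <- min(floor(n * (1 - validation)) + 1, n - 1)
        trn_index <- sample(1:n, size = max(m, 2))
        val_index <- setdiff(1:n, trn_index)
      }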
    }

    data_args <-
      c(
        list(
          data = bonsai:::prepare_df_lgbm(x[trn_index, , drop = FALSE]),
          label = y[trn_index],
          categorical_feature = bonsai:::categorical_columns(x[trn_index, , drop = FALSE]),
          params = c(list(feature_pre_filter = FALSE), args$params),
          weight = weights[trn_index]
        ),
        args$main_args_dataset
      )

    args$main_args_train$data <-
      rlang::eval_bare(
        rlang::call2("lgb.Dataset", !!!data_args, .ns = "lightgbm")
      )

    if (!is.null(val_index)) {
      valids_args <-
        c(
          list(
            data = bonsai:::prepare_df_lgbm(x[val_index, , drop = FALSE]),
            label = y[val_index],
            categorical_feature = bonsai:::categorical_columns(x[val_index, , drop = FALSE]),
            # Same params as the training Dataset (c() keeps them flat).
            params = c(list(feature_pre_filter = FALSE), args$params),
            weight = weights[val_index]
          ),
          args$main_args_dataset
        )

      args$main_args_train$valids <-
        list(
          validation =
            rlang::eval_bare(
              rlang::call2("lgb.Dataset", !!!valids_args, .ns = "lightgbm")
            )
        )
    }

    args
})
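Because the two-element form selects rows by position, it only gives a time-based hold-out if the training data are already in time order. A minimal base R sketch, using a made-up `sales` data frame, of sorting first so that `validation = c(0.8, 1)` would hold out the most recent 20% of rows; the index arithmetic copies the patch above.

```r
# Made-up example data whose rows are not in time order.
set.seed(7)
sales <- data.frame(
  date = sample(seq(as.Date("2023-01-01"), by = "day", length.out = 50)),
  units = rpois(50, lambda = 20)
)

# Sort by date so that row position matches time order.
sales <- sales[order(sales$date), ]

# Rows the patched process_data() would hold out for validation = c(0.8, 1).
n <- nrow(sales)
validation <- c(0.8, 1)
val_index <- (floor(n * validation[1]) + 1):floor(n * validation[2])

range(sales$date[val_index])  # the last 10 days of the series
```

Using an upper bound of 1 also avoids the situation where rows later than the validation block remain in the training set.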