curso-r / treesnip

Parsnip backends for `tree`, `lightGBM` and `Catboost`
https://curso-r.github.io/treesnip
GNU General Public License v3.0
85 stars 13 forks

Error using tune_grid with sample_size for lightgbm #60

Open felxcon opened 2 years ago

felxcon commented 2 years ago

Hi everybody!

The problem

I am following https://www.r-bloggers.com/2020/08/how-to-use-lightgbm-with-tidymodels/. After setting up all the boost_tree() parameters that lightgbm supports, I get an error when trying to tune() sample_size:

Warning message: All models failed. See the .notes column.

lgbm_tuned$.notes[[1]]
# A tibble: 1 x 1
  .notes
1 internal: Error: All unnamed arguments must be length 1

I'm having trouble solving this issue. I tried the approaches others used for tune() + sample_size and the few hints I found for "internal: Error: ...", but none of them solved it. Any suggestions? I would be very glad for some expertise :)

I also wonder where to change the lightgbm default settings (I could not find them in set_engine()). I would like to run lightgbm with:

num_threads = 3,

# #num_leaves = 2^max_depth,
# #early_stopping_round = 0.1*num_iterations,
# boosting = gbdt,
# bagging_freq = 0.5,
# tree_learner = data,
# extra_trees = T,
# monotone_constraints_method = advanced,
# feature_pre_filter = F,
# pre_partition = T

Thanks a lot, Felix

Reproducible example

#
library("lightgbm")
#> Loading required package: R6
# generic lightgbm script
# from https://www.tychobra.com/posts/2020-05-19-xgboost-with-tidymodels/
# Their design and description is very sound and I iterate on their work.
# # data
# library(AmesHousing)
# 
# data cleaning
library(janitor)
#> 
#> Attaching package: 'janitor'
#> The following objects are masked from 'package:stats':
#> 
#>     chisq.test, fisher.test
# data prep
library(dplyr)
#> 
#> Attaching package: 'dplyr'
#> The following object is masked from 'package:lightgbm':
#> 
#>     slice
#> The following objects are masked from 'package:stats':
#> 
#>     filter, lag
#> The following objects are masked from 'package:base':
#> 
#>     intersect, setdiff, setequal, union
# tidymodels
library(recipes)
#> 
#> Attaching package: 'recipes'
#> The following object is masked from 'package:stats':
#> 
#>     step
library(rsample)
library(parsnip)
library(tune)
#> Registered S3 method overwritten by 'tune':
#>   method                   from   
#>   required_pkgs.model_spec parsnip
#> 
#> Attaching package: 'tune'
#> The following object is masked from 'package:recipes':
#> 
#>     tune_args
library(dials)
#> Loading required package: scales
library(workflows)
library(yardstick)
#> For binary classification, the first factor level is assumed to be the event.
#> Use the argument `event_level = "second"` to alter this as needed.
library(treesnip)
# needed later as well
library(ggplot2)

# speed up computation with parallel processing (optional)
# WARNING: THIS RESULTS IN ERRORS ON WINDOWS.
library(doParallel)
#> Loading required package: foreach
#> Loading required package: iterators
#> Loading required package: parallel
all_cores <- 3 # parallel::detectCores(logical = TRUE)
# in xgboost the cores were not optimally used on my Mac, but lgbm is filling
# them to the brim
registerDoParallel(cores = all_cores) 

#import datasheets
#MMMF_all_cleansed_site_properties_work_clean_orig = read.csv("[...].csv")
#MMMF_all_cleansed_site_properties_work_clean = MMMF_all_cleansed_site_properties_work_clean_orig

#mmmf_mixedVAR_simple_qf = MMMF_all_cleansed_site_properties_work_clean[,c("QF2020scenario",
#                                                                          "LULCTypeCLC2018",
#                                                                         "CurveNumber",
#                                                                          "NumberRainEvents2020",
#                                                                          "Precip2020",
#                                                                          "cBedrockPerm",
#                                                                          "cSlope",
#                                                                          "cSoilPerm",
#                                                                          "cTreeline",
#                                                                          "watershed",
#                                                                          "MMMFsize",
#                                                                          "Compactness_IPO",
#                                                                          "cBayKompV",
#                                                                          "BiomassProd",
#                                                                          "cIUCNProtArea",
#                                                                          "cRootDepthSpace",
#                                                                          "cSunnyAspect",
#                                                                          "cWindAspect",
#                                                                          "Soil_pH")] %>% mutate(across(c(2,3,6:13,15:18), factor))

mmmf_mixedVAR_simple_qf  = structure(list(QF2020scenario = c(41.6104850769043, 68.6856307983398, 
57.3654022216797, 28.7580642700195, 45.1602096557617, 47.2106628417969, 
68.6856307983398, 71.7784652709961, 24.4756469726562, 40.3135414123535, 
23.0632152557373, 24.1012935638428, 68.1585693359375, 63.6638069152832, 
41.6104850769043, 35.8073768615723, 68.6856307983398, 54.6796913146973, 
48.6273994445801, 41.6104850769043, 41.6104850769043, 89.0888595581055, 
50.85595703125, 1093.50598144531, 41.8495635986328, 75.2088851928711, 
37.8572235107422, 76.5240020751953, 27.1052436828613, 34.5876998901367, 
28.7580642700195, 42.9002914428711, 27.1052436828613, 29.2019214630127, 
34.5320816040039, 40.723876953125, 57.3654022216797, 54.1622505187988, 
41.5773963928223, 39.8819427490234, 32.0537185668945, 1339.29736328125, 
40.9108467102051, 41.5037727355957, 28.6135368347168, 47.2106628417969, 
68.6856307983398, 36.1073341369629, 40.440845489502, 48.0962562561035, 
74.1079177856445, 23.0632152557373, 58.8290863037109, 50.85595703125, 
24.9130783081055, 87.9564056396484, 65.1510391235352, 28.7580642700195, 
28.6135368347168, 48.6273994445801, 56.6535720825195, 31.4044914245605, 
89.0888595581055, 42.9002914428711, 40.723876953125, 54.8065490722656, 
48.5243873596191, 1413.09387207031, 41.2597579956055, 22.1660785675049, 
39.8819427490234, 27.2910995483398, 56.6535720825195, 40.723876953125, 
41.6104850769043, 58.8290863037109, 37.8572235107422, 34.5320816040039, 
79.2940444946289, 22.6065940856934, 57.3654022216797, 77.4911727905273, 
26.6769046783447, 74.1079177856445, 45.1602096557617, 79.2940444946289, 
36.1073341369629, 28.7580642700195, 68.1585693359375, 46.0501861572266, 
27.2910995483398, 48.1491203308105, 71.7784652709961, 68.0657424926758, 
54.1622505187988, 44.9057807922363, 26.5627517700195, 93.6683654785156, 
23.0632152557373, 38.9476852416992), CurveNumber = structure(c(13L, 
18L, 18L, 45L, 13L, 13L, 18L, 18L, 13L, 40L, 13L, 13L, 3L, 18L, 
28L, 28L, 18L, 13L, 13L, 1L, 1L, 3L, 3L, 18L, 28L, 3L, 1L, 18L, 
28L, 13L, 32L, 18L, 1L, 28L, 45L, 13L, 18L, 28L, 40L, 18L, 13L, 
32L, 18L, 18L, 32L, 28L, 18L, 13L, 18L, 18L, 18L, 32L, 3L, 1L, 
1L, 38L, 3L, 32L, 1L, 3L, 40L, 32L, 3L, 18L, 13L, 13L, 13L, 1L, 
3L, 13L, 18L, 13L, 1L, 13L, 1L, 18L, 13L, 1L, 18L, 13L, 18L, 
38L, 13L, 18L, 13L, 38L, 13L, 32L, 18L, 18L, 1L, 28L, 18L, 3L, 
13L, 13L, 13L, 18L, 1L, 13L), .Label = c("50", "53", "56", "57.5", 
"59.5", "60", "60.5", "61", "63", "63.5", "64", "64.5", "65", 
"66.5", "67.5", "68.5", "69.5", "70", "71", "72", "73", "73.5", 
"74", "74.5", "75.5", "76", "76.5", "77", "77.5", "78", "78.5", 
"79", "80", "80.5", "81", "81.5", "82", "83", "83.5", "84", "84.5", 
"85", "85.5", "86", "87", "87.5", "88", "88.5", "89", "90", "92", 
"93", "95"), class = "factor"), Precip2020 = c(105.631990780906, 
105.631990780906, 113.525026987469, 95.7414173890674, 108.81347918132, 
111.585622257657, 105.631990780906, 107.371452626728, 108.427198425172, 
104.19571208197, 106.32300672077, 108.237226047213, 108.663529055459, 
105.526280009557, 105.631990780906, 103.14488716731, 105.631990780906, 
117.181404984187, 111.968032095167, 105.631990780906, 105.631990780906, 
117.875289985112, 114.660130235884, 91.1255002702986, 105.89557216281, 
117.985597080655, 105.526280009557, 110.971759538802, 113.277256125496, 
103.226915200551, 95.7414173890674, 107.448963218265, 113.277256125496, 
115.204860694825, 102.406323243701, 108.663529055459, 113.525026987469, 
116.909620224483, 106.32300672077, 104.19571208197, 100.017380517627, 
111.608108346424, 105.95153774534, 105.95153774534, 96.1293804077875, 
111.585622257657, 105.631990780906, 117.757825435154, 105.14771898966, 
111.93966234298, 108.81347918132, 106.32300672077, 103.226915200551, 
114.660130235884, 111.252964557163, 116.909620224483, 105.89557216281, 
95.7414173890674, 96.1293804077875, 111.968032095167, 118.493426474314, 
99.5515276136853, 117.875289985112, 107.448963218265, 108.663529055459, 
117.875289985112, 112.081975460053, 117.757825435154, 107.530949713692, 
105.14771898966, 104.19571208197, 113.237426341526, 118.493426474314, 
108.663529055459, 105.631990780906, 103.226915200551, 105.526280009557, 
102.406323243701, 112.081975460053, 105.95153774534, 113.525026987469, 
111.585622257657, 111.93966234298, 108.81347918132, 108.81347918132, 
112.081975460053, 117.757825435154, 95.7414173890674, 108.663529055459, 
111.608108346424, 113.237426341526, 111.900469795106, 107.371452626728, 
108.146229834784, 116.909620224483, 108.206884308467, 112.154623977722, 
119.449564721849, 106.32300672077, 105.89557216281)), row.names = c(NA, 
100L), class = "data.frame")

# set the random seed so we can reproduce any simulated results.
set.seed(1234)

#--------------------------------------------------------------#
#https://github.com/tidymodels/tune/issues/460#issue-1151992023
# -> xgboost works
library(tidymodels)
library(xgboost)

data_train <- mmmf_mixedVAR_simple_qf[-(1:10), ]
data_test  <- mmmf_mixedVAR_simple_qf[  1:10 , ]
folds <- vfold_cv(data_train, v = 3, strata = QF2020scenario)
folds

bt_cls_spec <- 
  boost_tree(trees = 15, sample_size = tune()) %>% 
  set_mode("regression") %>% 
  set_engine("xgboost")

bt_cls_spec %>% 
  tune_grid(QF2020scenario ~ .,
            resamples = folds,
            grid = 5)
#--------------------------------------------------------------#

# clean the column names
mmmf_mixedVAR_simple_janitor_clean = mmmf_mixedVAR_simple_qf %>% janitor::clean_names()

# split into training and testing datasets, stratified by the outcome qf2020scenario
mmmf_mixedVAR_simple_janitor_clean_split <- rsample::initial_split(
  mmmf_mixedVAR_simple_janitor_clean, 
  prop = 0.8, 
  strata = qf2020scenario
)

# Preprocessing
preprocessing_recipe <- 
  recipes::recipe(qf2020scenario ~ ., data = training(mmmf_mixedVAR_simple_janitor_clean_split)) %>%
  #convert categorical variables to factors
  recipes::step_string2factor(all_nominal()) %>%
  # combine low frequency factor levels
  recipes::step_other(all_nominal(), threshold = 0.01) %>%
  # remove near-zero-variance predictors, which provide little predictive information
  recipes::step_nzv(all_nominal()) %>%
  prep()

# Cross validate 
mmmf_mixedVAR_simple_janitor_clean_split_preproc_cv_folds <- 
  recipes::bake(
    preprocessing_recipe, 
    new_data = training(mmmf_mixedVAR_simple_janitor_clean_split)
  ) %>% rsample::vfold_cv(v = 5)

## /// changing from xgboost to lightgbm.
# lightgbm model specification
lightgbm_model <- 
  parsnip::boost_tree(
    min_n = tune(), #min_data_in_leaf
    tree_depth = tune(), #max_depth
    trees = tune(), #num_iterations
    learn_rate = tune(), #learning_rate
    loss_reduction = tune(), #min_gain_to_split
    mtry = tune(), #feature_fraction
    sample_size = tune() #bagging_fraction
  ) %>% set_engine("lightgbm") %>% set_mode("regression")

# ///grid specification by dials package to fill in the model above
# grid specification
lightgbm_params <- 
  dials::parameters(
    min_n(),
    tree_depth(),
    trees(),
    learn_rate(),
    loss_reduction(),
    mtry(),
    sample_size = sample_prop()
  ) %>% update(mtry = finalize(mtry(), mmmf_mixedVAR_simple_janitor_clean %>% select(-qf2020scenario))) #
# mtry and sample_size need ranges: sample_size is the proportion of rows to sample,
# mtry the number of predictors to choose from. finalize() sets the upper bound of
# mtry to the number of predictors, i.e. all columns except the deselected (-) outcome.

# ///and the grid to look in 
# Experimental designs for computer experiments are used
# to construct parameter grids that try to cover the parameter space such that
# any portion of the space has an observed combination that is not too far from
# it.
lgbm_grid <- 
  dials::grid_max_entropy(
    lightgbm_params, 
    size = 7
  )
# To tune our model, we perform a grid search over lgbm_grid's grid space
# to identify the hyperparameter values that have the lowest prediction error.

# Workflow setup
# /// (contains the work)
lgbm_wf <- 
  workflows::workflow() %>%
  add_model(lightgbm_model) %>% 
  add_formula(qf2020scenario ~ .)

# /// so far little to no computation has been performed except for
# /// preprocessing calculations

# Step 7: Tune the Model

# Tuning is where the tidymodels ecosystem of packages really comes together.
# Here is a quick breakdown of the objects passed to the first 4 arguments of
# our call to tune_grid() below:
#
# "object": lgbm_wf, a workflow defined with the parsnip and workflows packages
# "resamples": mmmf_mixedVAR_simple_janitor_clean_split_preproc_cv_folds, defined
#   with the rsample and recipes packages
# "grid": lgbm_grid, the grid space defined with the dials package
# "metrics": the metric set from the yardstick package used to evaluate model
#   performance
# 
# hyperparameter tuning
# //// this is where the machine starts to smoke!
set_dependency("boost_tree", eng = "lightgbm", "lightgbm")
set_dependency("boost_tree", eng = "lightgbm", "treesnip")
lgbm_tuned <- tune::tune_grid(
  object = lgbm_wf,
  resamples = mmmf_mixedVAR_simple_janitor_clean_split_preproc_cv_folds,
  grid = lgbm_grid,
  metrics = yardstick::metric_set(rmse, rsq, mae),
  control = tune::control_grid(verbose = TRUE)
)
#> Warning: All models failed. See the `.notes` column.

Created on 2022-02-26 by the reprex package (v2.0.1)

Session info

``` r
sessionInfo()
#> R version 4.1.2 (2021-11-01)
#> Platform: x86_64-w64-mingw32/x64 (64-bit)
#> Running under: Windows 10 x64 (build 19044)
#> 
#> Matrix products: default
#> 
#> locale:
#> [1] LC_COLLATE=German_Germany.1252 LC_CTYPE=German_Germany.1252
#> [3] LC_MONETARY=German_Germany.1252 LC_NUMERIC=C
#> [5] LC_TIME=German_Germany.1252
#> 
#> attached base packages:
#> [1] parallel stats graphics grDevices utils datasets methods
#> [8] base
#> 
#> other attached packages:
#>  [1] doParallel_1.0.17 iterators_1.0.14 foreach_1.5.2
#>  [4] ggplot2_3.3.5 treesnip_0.1.0.9000 yardstick_0.0.9
#>  [7] workflows_0.2.4 dials_0.1.0 scales_1.1.1
#> [10] tune_0.1.6 parsnip_0.1.7 rsample_0.1.1
#> [13] recipes_0.2.0 dplyr_1.0.8 janitor_2.1.0
#> [16] lightgbm_3.3.2 R6_2.5.1
#> 
#> loaded via a namespace (and not attached):
#>  [1] tidyr_1.2.0 jsonlite_1.7.3 splines_4.1.2 prodlim_2019.11.13
#>  [5] assertthat_0.2.1 highr_0.9 GPfit_1.0-8 yaml_2.3.5
#>  [9] globals_0.14.0 ipred_0.9-12 pillar_1.7.0 lattice_0.20-45
#> [13] glue_1.6.1 pROC_1.18.0 digest_0.6.29 snakecase_0.11.0
#> [17] hardhat_0.2.0 colorspace_2.0-2 plyr_1.8.6 htmltools_0.5.2
#> [21] Matrix_1.4-0 timeDate_3043.102 pkgconfig_2.0.3 lhs_1.1.4
#> [25] DiceDesign_1.9 listenv_0.8.0 purrr_0.3.4 gower_1.0.0
#> [29] lava_1.6.10 tibble_3.1.6 generics_0.1.2 ellipsis_0.3.2
#> [33] withr_2.4.3 furrr_0.2.3 nnet_7.3-17 cli_3.2.0
#> [37] survival_3.2-13 magrittr_2.0.2 crayon_1.5.0 evaluate_0.15
#> [41] fs_1.5.2 future_1.24.0 fansi_1.0.2 parallelly_1.30.0
#> [45] MASS_7.3-55 class_7.3-20 tools_4.1.2 data.table_1.14.2
#> [49] lifecycle_1.0.1 stringr_1.4.0 munsell_0.5.0 reprex_2.0.1
#> [53] compiler_4.1.2 rlang_1.0.1 grid_4.1.2 rstudioapi_0.13
#> [57] rmarkdown_2.11 gtable_0.3.0 codetools_0.2-18 DBI_1.1.2
#> [61] lubridate_1.8.0 knitr_1.37 fastmap_1.1.0 future.apply_1.8.1
#> [65] utf8_1.2.2 stringi_1.7.6 Rcpp_1.0.8 vctrs_0.3.8
#> [69] rpart_4.1.16 tidyselect_1.1.2 xfun_0.29
```
felxcon commented 2 years ago

Hi!

I think the issue lies with bagging_fraction (a.k.a. sample_size) being "protected"! See https://github.com/tidymodels/parsnip/issues/136#issuecomment-695845233 (.notes: The following arguments cannot be manually modified and were removed: bagging_fraction.)
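The .notes message suggests a concrete rule: parsnip treats bagging_fraction as the lightgbm engine's name for the main argument sample_size, so it is protected and cannot be passed through set_engine(). Below is a minimal sketch contrasting the two routes, assuming parsnip and treesnip (with its lightgbm engine) are installed. Whether this also resolves the "All unnamed arguments must be length 1" error is not established in this thread.

```r
library(parsnip)
library(treesnip)  # registers the "lightgbm" engine for boost_tree()

# Dropped: bagging_fraction is the engine-side name of the main argument
# sample_size. parsnip removes it (with the "cannot be manually modified"
# message) when the spec is translated for fitting.
protected_spec <- boost_tree() %>%
  set_engine("lightgbm", bagging_fraction = 0.8) %>%
  set_mode("regression")

# Kept: control row subsampling through the main argument, here marked for tuning.
main_arg_spec <- boost_tree(sample_size = tune::tune()) %>%
  set_engine("lightgbm") %>%
  set_mode("regression")
```

Only the second form lets tune_grid() vary the sampled proportion, because tune() placeholders are looked up among the main arguments.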

For information: lightgbm arguments can be set with set_engine() (see https://www.rebeccabarter.com/blog/2020-03-25_machine_learning/). String values then need to be quoted, e.g. tree_learner = "data".
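Applied to the wish list from the original post, a sketch could look like this. It assumes treesnip forwards engine arguments it does not map itself on to lightgbm's parameter list, which is not verified here. Character options are quoted and logicals are written as TRUE/FALSE. One value was adjusted: LightGBM documents bagging_freq as an integer (bagging every k iterations, 0 disables it), and bagging_fraction only takes effect when bagging_freq is non-zero, so 0.5 is replaced by 1. num_leaves = 2^max_depth and early_stopping_round = 0.1*num_iterations are left out because they depend on arguments that are being tuned.

```r
library(parsnip)
library(treesnip)  # registers the "lightgbm" engine for boost_tree()

# Engine-specific lightgbm settings are named arguments of set_engine().
lgbm_spec <- boost_tree() %>%
  set_engine(
    "lightgbm",
    num_threads = 3,
    boosting = "gbdt",              # strings must be quoted
    bagging_freq = 1,               # integer; 0 disables bagging
    tree_learner = "data",
    extra_trees = TRUE,
    monotone_constraints_method = "advanced",
    feature_pre_filter = FALSE,
    pre_partition = TRUE
  ) %>%
  set_mode("regression")

# parsnip stores these as quosures until the model is fit.
names(lgbm_spec$eng_args)
```

The tuned main arguments (sample_size, mtry and so on) stay in boost_tree(); only settings without a parsnip main argument belong in set_engine().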

Kind regards