tidymodels / dials

Tools for creating tuning parameter values
https://dials.tidymodels.org/

Request support for scale_pos_weight() in boost_tree() #152

Closed: joeycouse closed this issue 3 years ago

joeycouse commented 3 years ago

Feature

I've found the scale_pos_weight parameter very useful when using xgb.train(). It would be awesome if it could be tuned using the same syntax as mtry() or the other tunable boost_tree() parameters. Thanks!
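For context, the xgboost parameter documentation suggests sum(negative instances) / sum(positive instances) as a typical starting value for scale_pos_weight. Below is a minimal base R sketch of that heuristic. The helper name pos_weight_ratio() and the toy outcome are hypothetical. Which factor level xgboost ends up treating as label 1 depends on how the modeling interface encodes the outcome, so check that before relying on the ratio.

```r
# Hypothetical helper (not part of xgboost or tidymodels): the ratio of
# negative to positive cases, the starting value the xgboost docs suggest
# for scale_pos_weight.
pos_weight_ratio <- function(y, positive) {
  y <- as.character(y)
  n_pos <- sum(y == positive)
  n_neg <- sum(y != positive)
  if (n_pos == 0) stop("`y` contains no positive cases")
  n_neg / n_pos
}

# Toy outcome, made up for illustration: 3 "pos" and 6 "neg" cases.
y <- factor(c(rep("pos", 3), rep("neg", 6)))
pos_weight_ratio(y, positive = "pos")
#> [1] 2
```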

library(tidyverse)
library(tidymodels)
library(mlbench)

data("PimaIndiansDiabetes")

set.seed(24)

df <- PimaIndiansDiabetes %>%
  mutate(diabetes = fct_relevel(diabetes, 'pos'))

xgb_model_1 <- 
  boost_tree(trees = 150,
             tree_depth = 3
             ) %>%
  set_engine('xgboost', scale_pos_weight = 0.01, eval_metric = 'auc') %>%
  set_mode('classification')

xgb_model_1 %>%
  fit(diabetes ~ . , df)

##### Different Value of Scale Positive Weight

xgb_model_2 <- 
  boost_tree(trees = 150,
             tree_depth = 3
             ) %>%
  set_engine('xgboost', scale_pos_weight = 0.1, eval_metric = 'auc') %>%
  set_mode('classification')

xgb_model_2 %>%
  fit(diabetes ~ . , df)
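The two fits above differ only in scale_pos_weight. In xgboost's binary logistic objective, that value multiplies the gradient and hessian of every training instance whose label is 1. Values below 1 therefore down-weight that class, and values above 1 up-weight it. The sketch below reproduces that per-instance arithmetic in base R. The function name logistic_grad_hess() is hypothetical, and it ignores any separately supplied instance weights, which xgboost would also multiply in.

```r
# Hypothetical illustration of how xgboost's binary logistic objective
# applies scale_pos_weight: instances with label 1 get their gradient and
# hessian multiplied by scale_pos_weight; label-0 instances are unchanged.
logistic_grad_hess <- function(y, p, scale_pos_weight = 1) {
  w <- ifelse(y == 1, scale_pos_weight, 1)  # per-instance weight
  list(
    grad = w * (p - y),       # first derivative of log loss w.r.t. the margin
    hess = w * p * (1 - p)    # second derivative
  )
}

# One label-1 and one label-0 instance, both currently predicted at p = 0.5.
gh <- logistic_grad_hess(y = c(1, 0), p = c(0.5, 0.5), scale_pos_weight = 0.1)
gh$grad
#> [1] -0.05  0.50
gh$hess
#> [1] 0.025 0.250
```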

juliasilge commented 3 years ago

Looks like we don't have other engine-specific tuning parameters for the "xgboost" engine set up yet. Are there others we should consider adding with this one?

topepo commented 3 years ago

Yes, we can make dials objects for these and set up the other bits in tune to make this more seamless.

In the meantime, you can tune it by explicitly specifying the grid:

library(tidyverse)
library(tidymodels)
#> ── Attaching packages ────────────────────────────────────── tidymodels 0.1.1 ──
#> ✓ broom     0.7.0           ✓ recipes   0.1.15.9000
#> ✓ dials     0.0.9.9000      ✓ rsample   0.0.8.9000 
#> ✓ infer     0.5.2           ✓ tune      0.1.2.9000 
#> ✓ modeldata 0.1.0           ✓ workflows 0.2.1      
#> ✓ parsnip   0.1.4.9000      ✓ yardstick 0.0.7.9000
#> ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
#> x scales::discard() masks purrr::discard()
#> x dplyr::filter()   masks stats::filter()
#> x recipes::fixed()  masks stringr::fixed()
#> x dplyr::lag()      masks stats::lag()
#> x yardstick::spec() masks readr::spec()
#> x recipes::step()   masks stats::step()
library(mlbench)

data("PimaIndiansDiabetes")

set.seed(24)

df <- PimaIndiansDiabetes %>%
  mutate(diabetes = fct_relevel(diabetes, 'pos'))

xgb_model_1 <- 
  boost_tree(trees = 150,
             tree_depth = 3
  ) %>%
  set_engine('xgboost', scale_pos_weight = tune(), eval_metric = 'auc') %>%
  set_mode('classification')

set.seed(1)
xgb_model_1_res <- 
  tune_grid(xgb_model_1, diabetes ~., resamples = vfold_cv(df),
            grid = tibble(scale_pos_weight = 10^c(-3:-1)))

collect_metrics(xgb_model_1_res)
#> # A tibble: 6 x 7
#>   scale_pos_weight .metric  .estimator  mean     n std_err .config             
#>              <dbl> <chr>    <chr>      <dbl> <int>   <dbl> <chr>               
#> 1            0.001 accuracy binary     0.349    10  0.0137 Preprocessor1_Model1
#> 2            0.001 roc_auc  binary     0.5      10  0      Preprocessor1_Model1
#> 3            0.01  accuracy binary     0.487    10  0.0130 Preprocessor1_Model2
#> 4            0.01  roc_auc  binary     0.814    10  0.0132 Preprocessor1_Model2
#> 5            0.1   accuracy binary     0.698    10  0.0112 Preprocessor1_Model3
#> 6            0.1   roc_auc  binary     0.798    10  0.0181 Preprocessor1_Model3

Created on 2021-01-19 by the reprex package (v0.3.0)
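The grid in the reprex above contains three values spaced a factor of ten apart. For a wider or denser search, the same explicit-grid approach should work with any data frame whose column name matches the tuned argument. The range 1e-3 to 10 and the nine points below are illustrative choices, not recommendations.

```r
# Candidate values for scale_pos_weight, evenly spaced on the log10 scale
# (range and number of points are arbitrary illustrative choices).
spw_grid <- data.frame(scale_pos_weight = 10^seq(-3, 1, length.out = 9))
spw_grid$scale_pos_weight
```

This data frame could then be passed as grid = spw_grid in place of the tibble in the tune_grid() call above.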

joeycouse commented 3 years ago

Thanks for the help! Additionally, support for lambda (the L2 regularization term on weights) within boost_tree() would be great.
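For reference, xgboost's lambda is the L2 penalty on leaf weights. With G and H the sums of the gradients and hessians of the instances in a leaf, the optimal leaf weight is -G / (H + lambda), so larger values shrink leaf outputs toward zero. Below is a minimal base R sketch of that formula, assuming the L1 term alpha is 0. The function name leaf_weight() and the gradient/hessian values are hypothetical.

```r
# Hypothetical helper: optimal xgboost leaf weight with L2 penalty lambda
# (the L1 penalty alpha is assumed to be 0).
leaf_weight <- function(grad, hess, lambda = 1) {
  -sum(grad) / (sum(hess) + lambda)
}

# Made-up gradients and hessians for three instances falling in one leaf.
g <- c(-0.5, -0.5, 0.5)
h <- c(0.25, 0.25, 0.25)
leaf_weight(g, h, lambda = 0)
#> [1] 0.6666667
leaf_weight(g, h, lambda = 1)
#> [1] 0.2857143
```

Raising lambda from 0 to 1 pulls this leaf's output from about 0.67 down to about 0.29.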

github-actions[bot] commented 3 years ago

This issue has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.