tidymodels / parsnip

A tidy unified interface to models
https://parsnip.tidymodels.org

Account for possibility of custom objective function in XGBoost `boost_tree()` #459

Closed smingerson closed 2 years ago

smingerson commented 3 years ago

Currently, passing a custom objective function causes an error downstream when predicting. The error comes from `xgb_pred()`, which calls `switch()` on the objective (usually a character string) to transform the output of `predict.xgb.Booster()`. When the objective is a function, `switch()` rejects it.

``` r
library(xgboost)
library(parsnip)
library(workflows)

mod <- boost_tree("regression") %>% 
  set_engine(
    "xgboost",
    objective = function(preds, dtrain) {
      truth <- as.numeric(getinfo(dtrain, "label"))
      error <- truth - preds
      gradient <- -2 * error
      hess <- rep.int(2, length(preds))
      list(grad = gradient, hess = hess)
    }
  )

dt <- data.frame(x = rnorm(15))
dt$y <- dt$x + rnorm(15, 0, .05)

wf <- workflow() %>% 
  add_model(mod) %>% 
  add_formula(y ~ x)

fitted <- fit(wf, data = dt)
predict(fitted, new_data = dt)
#> Error in switch(object$params$objective, `binary:logitraw` = stats::binomial()$linkinv(res), : EXPR must be a length 1 vector
```
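
For illustration, here is a minimal base R sketch of why the call fails and one way it could be guarded. `transform_by_objective()` is a hypothetical stand-in, not parsnip's actual `xgb_pred()` code, and returning the raw predictions for a non-string objective is an assumption about what a reasonable fallback might be.

``` r
# Hypothetical stand-in for the post-processing step; NOT parsnip's actual code.
transform_by_objective <- function(res, objective) {
  # A custom objective (a function) or a missing objective (NULL) has no known
  # link function, so return the raw predictions instead of calling switch().
  if (!is.character(objective) || length(objective) != 1) {
    return(res)
  }
  switch(
    objective,
    "binary:logitraw" = stats::binomial()$linkinv(res),
    res  # default: leave predictions untouched
  )
}

# A dummy custom objective, only used here as a non-string value.
custom_obj <- function(preds, dtrain) list(grad = preds, hess = preds)

# switch() itself refuses a function as EXPR; capture the message.
err <- tryCatch(
  switch(custom_obj, a = 1),
  error = function(e) conditionMessage(e)
)

raw <- c(-1, 0, 1)
out_custom <- transform_by_objective(raw, custom_obj)         # unchanged
out_null   <- transform_by_objective(raw, NULL)               # unchanged
out_logit  <- transform_by_objective(raw, "binary:logitraw")  # inverse logit applied
```

With a guard like this, anyone using a custom objective would have to apply their own inverse link to the predictions, since nothing can be inferred from the function itself.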

jcpsantiago commented 3 years ago

I also see this error when using `parsnip::set_engine("xgboost", params = list(eval_metric = "aucpr"))` without setting the objective argument. I ran into it after updating parsnip from 0.1.4 to 0.1.5, when `tune::tune_grid()` started failing. (tidymodels and the other individual packages, namely {workflows} and {tune}, were updated at the same time.)

This test is passing: https://github.com/tidymodels/parsnip/blob/cb086385a90227eacfce2f06ed58ff2d4e17bb29/tests/testthat/test_boost_tree_xgboost.R#L169

``` r
spec <-
  boost_tree() %>%
  set_engine("xgboost", objective = "reg:pseudohubererror") %>%
  set_mode("regression")

xgb_fit <- spec %>% fit(mpg ~ ., data = mtcars)
```

However, if objective is not a string (I guess this is why the issue is labelled a feature request rather than a bug), it fails as in the OP's code. Additionally, adding anything else to `set_engine()` fails with the same error -- are the `...` arguments all added to the same vector?

``` r
library(xgboost)
library(parsnip)
library(workflows)

mod <- boost_tree("classification") %>% 
  set_engine(
    "xgboost", 
    objective = "binary:logistic",
    params = list(eval_metric = "aucpr") # <- added this and changed the data to be a classification problem
  )

dt <- data.frame(
  x = rnorm(15),
  y = rnorm(15) + rnorm(15, 0, .05),
  target = as.factor(rbinom(15, 1, 0.5))
)

wf <- workflow() %>% 
  add_model(mod) %>% 
  add_formula(target ~ x + y)

fitted <- fit(wf, data = dt)
predict(fitted, new_data = dt)
#> Error in switch(object$params$objective, `binary:logitraw` = stats::binomial()$linkinv(res), : EXPR must be a length 1 vector
```

Created on 2021-04-20 by the reprex package (v2.0.0)

Session info

``` r
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.0.2 (2020-06-22)
#>  os       macOS 10.16
#>  system   x86_64, darwin17.0
#>  ui       X11
#>  language (EN)
#>  collate  en_US.UTF-8
#>  ctype    en_US.UTF-8
#>  tz       Europe/Berlin
#>  date     2021-04-20
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  !  package     * version date       lib source
#>     assertthat    0.2.1   2019-03-21 [1] CRAN (R 4.0.2)
#>  P  cli           2.4.0   2021-04-05 [?] CRAN (R 4.0.2)
#>  P  codetools     0.2-18  2020-11-04 [3] CRAN (R 4.0.2)
#>  P  crayon        1.4.1   2021-02-08 [?] CRAN (R 4.0.2)
#>  P  data.table    1.14.0  2021-02-21 [?] CRAN (R 4.0.2)
#>  P  DBI           1.1.1   2021-01-15 [?] CRAN (R 4.0.2)
#>     digest        0.6.27  2020-10-24 [1] CRAN (R 4.0.2)
#>  P  dplyr         1.0.5   2021-03-05 [?] CRAN (R 4.0.2)
#>     ellipsis      0.3.1   2020-05-15 [1] CRAN (R 4.0.2)
#>  P  evaluate      0.14    2019-05-28 [?] CRAN (R 4.0.0)
#>  P  fansi         0.4.2   2021-01-15 [?] CRAN (R 4.0.2)
#>  P  fs            1.5.0   2020-07-31 [?] CRAN (R 4.0.2)
#>  P  generics      0.1.0   2020-10-31 [?] CRAN (R 4.0.2)
#>     globals       0.14.0  2020-11-22 [1] CRAN (R 4.0.2)
#>     glue          1.4.2   2020-08-27 [1] CRAN (R 4.0.2)
#>  P  hardhat       0.1.5   2020-11-09 [?] CRAN (R 4.0.2)
#>  P  highr         0.9     2021-04-16 [?] CRAN (R 4.0.2)
#>  P  htmltools     0.5.1.1 2021-01-22 [?] CRAN (R 4.0.2)
#>  P  knitr         1.32    2021-04-14 [?] CRAN (R 4.0.2)
#>  P  lattice       0.20-41 2020-04-02 [3] CRAN (R 4.0.2)
#>  P  lifecycle     1.0.0   2021-02-15 [?] CRAN (R 4.0.2)
#>     magrittr      2.0.1   2020-11-17 [1] CRAN (R 4.0.2)
#>  P  Matrix        1.3-2   2021-01-06 [?] CRAN (R 4.0.2)
#>  P  parsnip     * 0.1.5   2021-01-19 [?] CRAN (R 4.0.2)
#>  P  pillar        1.6.0   2021-04-13 [?] CRAN (R 4.0.2)
#>     pkgconfig     2.0.3   2019-09-22 [1] CRAN (R 4.0.2)
#>     purrr         0.3.4   2020-04-17 [1] CRAN (R 4.0.2)
#>     R6            2.5.0   2020-10-28 [1] CRAN (R 4.0.2)
#>  P  reprex        2.0.0   2021-04-02 [?] CRAN (R 4.0.2)
#>  P  rlang         0.4.10  2020-12-30 [?] CRAN (R 4.0.2)
#>  P  rmarkdown     2.7     2021-02-19 [?] CRAN (R 4.0.2)
#>     rstudioapi    0.13    2020-11-12 [1] CRAN (R 4.0.2)
#>     sessioninfo   1.1.1   2018-11-05 [3] CRAN (R 4.0.2)
#>     stringi       1.5.3   2020-09-09 [1] CRAN (R 4.0.2)
#>     stringr       1.4.0   2019-02-10 [1] CRAN (R 4.0.2)
#>  P  tibble        3.1.1   2021-04-18 [?] CRAN (R 4.0.2)
#>  P  tidyr         1.1.3   2021-03-03 [?] CRAN (R 4.0.2)
#>     tidyselect    1.1.0   2020-05-11 [1] CRAN (R 4.0.2)
#>  P  utf8          1.2.1   2021-03-12 [?] CRAN (R 4.0.2)
#>  P  vctrs         0.3.7   2021-03-29 [?] CRAN (R 4.0.2)
#>  P  withr         2.4.2   2021-04-18 [?] CRAN (R 4.0.2)
#>  P  workflows   * 0.2.2   2021-03-10 [?] CRAN (R 4.0.2)
#>  P  xfun          0.22    2021-03-11 [?] CRAN (R 4.0.2)
#>     xgboost     * 1.3.2.1 2021-01-18 [1] CRAN (R 4.0.2)
#>     yaml          2.2.1   2020-02-01 [1] CRAN (R 4.0.2)
#>
#> [1] /Users/santiago/code/ds-models-fraud/renv/library/R-4.0/x86_64-apple-darwin17.0
#> [2] /private/var/folders/8d/zxgx1qkx44n7_wp6crx3ycsh0000gn/T/Rtmp6h44Di/renv-system-library
#> [3] /Library/Frameworks/R.framework/Versions/4.0/Resources/library
#>
#>  P ── Loaded and on-disk path mismatch.
```

To fix it I had to change my code to:

``` r
mod <- boost_tree("classification") %>% 
  set_engine(
    "xgboost",
    params = list(
      eval_metric = "aucpr",
      objective = "binary:logistic" # <- MUST be present
    )
  )
```

The objective must be explicitly declared if `params` is used; otherwise `object$params$objective` is `NULL`. I'm not sure whether this is expected behavior, i.e. whether the default was intentionally dropped.
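
To make the `NULL` case concrete, here is a small base R sketch. It assumes the engine normally supplies a default objective (`"binary:logistic"` is used here purely as an illustrative default) and that a user-supplied `params` list replaces that default wholesale rather than being merged with it. The names `default_params` and `user_params` are hypothetical and not taken from parsnip.

``` r
# Assumed default the engine would otherwise supply (illustrative only).
default_params <- list(objective = "binary:logistic")

# What the user passed via params = list(...), with no objective.
user_params <- list(eval_metric = "aucpr")

# If the user list replaces the defaults outright, the objective is gone,
# and switch() fails on NULL just as it does on a function.
null_err <- tryCatch(
  switch(user_params$objective, "binary:logitraw" = 1),
  error = function(e) conditionMessage(e)
)

# Merging instead keeps the default while honoring the user's entries;
# values in user_params win when both lists define the same name.
merged <- utils::modifyList(default_params, user_params)
str(merged)
```

Whether merging is the right behavior is a design question for the package; the sketch only shows that dropping the default is enough to trigger the same `switch()` error.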

amazongodman commented 2 years ago

There are similar reports here as well.

https://github.com/tidymodels/butcher/issues/214

simonpcouch commented 2 years ago

Related to #774.

github-actions[bot] commented 2 years ago

This issue has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.