bgreenwell / pdp

A general framework for constructing partial dependence (i.e., marginal effect) plots from various types of machine learning models in R.
http://bgreenwell.github.io/pdp

partial() function won't work with XGBoost's "binary:logitraw" objective function #109

Closed tmbluth closed 2 years ago

tmbluth commented 4 years ago

Partial dependence plots should work on continuous predictions, whether the output is unbounded or a probability between 0 and 1. When I train an XGBoost model using objective="binary:logitraw", partial() returns an error. This model does return predicted probabilities in other use cases, but something in partial() is preventing that here.

I cannot post company code here, but I can provide a shell of what the relevant part of the code looks like:

# Build the training DMatrix with labels, case weights, and a per-row base margin
DMatrix_train <- xgb.DMatrix(as.matrix(train_set[, inputs]),
                             label = train_set$target,
                             weight = train_set$weight)
setinfo(DMatrix_train, "base_margin", train_set$base_margin)

# Fit a boosted tree model whose output is the raw score before the logistic transformation
xgb_model <- xgb.train(
  data = DMatrix_train,
  nrounds = 500,
  maximize = FALSE,
  params = list(
    booster = "gbtree",
    objective = "binary:logitraw"
  )
)

# This call fails with the error shown below
partial(xgb_model, pred.var = "X1",
        plot = TRUE, plot.engine = "ggplot2",
        train = train_set[, inputs])
Error in super_type.xgb.Booster(object) : 
For classification, switch to an objective function that returns the predicted probabilities.

> sessionInfo()
R version 4.0.2 (2020-06-22)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows Server x64 (build 14393)

Matrix products: default

locale:
[1] LC_COLLATE=English_United States.1252  LC_CTYPE=English_United States.1252   
[3] LC_MONETARY=English_United States.1252 LC_NUMERIC=C                          
[5] LC_TIME=English_United States.1252    

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] Matrix_1.2-18    pdp_0.7.0        odbc_1.2.2       DBI_1.1.0        gbm_2.1.5       
 [6] plotly_4.9.2.1   xgboost_0.90.0.2 forcats_0.5.0    stringr_1.4.0    dplyr_1.0.0     
[11] purrr_0.3.4      readr_1.3.1      tidyr_1.1.0      tibble_3.0.1     ggplot2_3.3.1   
[16] tidyverse_1.3.0 

loaded via a namespace (and not attached):
 [1] httr_1.4.1        bit64_0.9-7       jsonlite_1.6.1    viridisLite_0.3.0
 [5] splines_4.0.2     modelr_0.1.8      assertthat_0.2.1  blob_1.2.1       
 [9] cellranger_1.1.0  yaml_2.2.1        pillar_1.4.4      backports_1.1.7  
[13] lattice_0.20-41   glue_1.4.1        digest_0.6.25     rvest_0.3.5      
[17] colorspace_1.4-1  htmltools_0.4.0   plyr_1.8.6        pkgconfig_2.0.3  
[21] broom_0.5.6       haven_2.3.1       scales_1.1.1      generics_0.0.2   
[25] farver_2.0.3      ellipsis_0.3.1    withr_2.2.0       lazyeval_0.2.2   
[29] cli_2.0.2         survival_3.1-12   magrittr_1.5      crayon_1.3.4     
[33] readxl_1.3.1      evaluate_0.14     fs_1.4.1          fansi_0.4.1      
[37] nlme_3.1-148      xml2_1.3.2        tools_4.0.2       data.table_1.12.8
[41] hms_0.5.3         lifecycle_0.2.0   munsell_0.5.0     reprex_0.3.0     
[45] compiler_4.0.2    tinytex_0.23      rlang_0.4.6       grid_4.0.2       
[49] rstudioapi_0.11   htmlwidgets_1.5.1 crosstalk_1.1.0.1 labeling_0.3     
[53] rmarkdown_2.2     gtable_0.3.0      R6_2.4.1          gridExtra_2.3    
[57] lubridate_1.7.9   knitr_1.28        keyring_1.1.0     bit_1.1-15.2     
[61] utf8_1.1.4        stringi_1.4.6     Rcpp_1.0.4.6      vctrs_0.3.1      
[65] dbplyr_1.4.4      tidyselect_1.1.0  xfun_0.14

bgreenwell commented 4 years ago

Hi @tmbluth, thanks for pointing out the issue (binary:logitraw was not an option when I originally added XGBoost support). It should be an easy fix, but it might be a while before I get around to it. Until then, you have two workarounds. The easiest is to specify type = "regression" in the call to partial(), which tricks it into working:

data(spam, package = "kernlab")

X <- data.matrix(subset(spam, select = -type))
y <- ifelse(spam$type == "spam", 1, 0)

bst <- xgboost(data = X, label = y, max.depth = 3, eta = 0.1, nrounds = 100, 
               objective = "binary:logitraw")

pdp::partial(bst, pred.var = "charExclamation", train = X, plot = TRUE)  # error

pdp::partial(bst, pred.var = "charExclamation", train = X, plot = TRUE,  # success
             type = "regression")

The other option, which is always more flexible, is to provide your own prediction wrapper via the pred.fun argument. Examples and details are given in the docs and corresponding R Journal article: https://journal.r-project.org/archive/2017/RJ-2017-016/index.html.
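
For reference, here is a minimal sketch of the pred.fun route, reusing the spam example from this comment. It rests on a few assumptions: that predict() on a "binary:logitraw" booster returns raw log-odds (so plogis() maps them to probabilities), that the kernlab, xgboost, and pdp packages are installed, and that the older xgboost() interface shown in this thread is in use. The name pred_prob is hypothetical and not part of pdp.

```r
library(xgboost)
library(pdp)

data(spam, package = "kernlab")

X <- data.matrix(subset(spam, select = -type))
y <- ifelse(spam$type == "spam", 1, 0)

bst <- xgboost(data = X, label = y, max.depth = 3, eta = 0.1, nrounds = 100,
               objective = "binary:logitraw", verbose = 0)

# Hypothetical wrapper (not part of pdp): average predicted probability of spam.
# Assumes predict() returns raw log-odds for a "binary:logitraw" model, so
# plogis() (the inverse logit) converts them to probabilities before averaging.
pred_prob <- function(object, newdata) {
  mean(plogis(predict(object, newdata = data.matrix(newdata))))
}

# type = "regression" is passed as a precaution in case automatic type
# detection still runs when pred.fun is supplied; it may be unnecessary.
pd <- partial(bst, pred.var = "charExclamation", pred.fun = pred_prob,
              type = "regression", train = X)
head(pd)
```

Because the wrapper averages on the probability scale, the resulting curve is bounded between 0 and 1, whereas the type = "regression" workaround on its own averages the raw scores. According to the pdp documentation, a pred.fun that returns the full vector of predictions instead of their mean should yield ICE curves rather than a single averaged curve.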

tmbluth commented 4 years ago

That worked, thank you! Looking forward to the official fix.

bgreenwell commented 4 years ago

Related to this issue: https://github.com/bgreenwell/pdp/issues/99.

And also this issue: https://github.com/bgreenwell/pdp/issues/68.