topepo / caret

caret (Classification And Regression Training) R package that contains misc functions for training and plotting classification and regression models
http://topepo.github.io/caret/index.html

Bug in `predict()` in `method = "bartMachine"` #1280

Open JiahanChen1 opened 2 years ago

JiahanChen1 commented 2 years ago

Hello,

I found a bug when calling predict() on a model trained with train() and method = "bartMachine": the predicted classes and the predicted probabilities are inconsistent.

A minimal reproducible example is shown below. (I create the tuning grid manually because there is a bug in the grid() function for method = "bartMachine"; see issue #1278.)

```r
library(caret)
set.seed(1000)

try.x = matrix(rnorm(20*50,0,1),nrow = 20,ncol = 50)
try.x = as.data.frame(try.x)

try.y = rbinom(20,1,0.5)
# change y to prevent invalid R variable name error
try.y = ifelse(try.y==1,"yes","no")
try.y = factor(try.y,levels = c("yes","no"))

grid = expand.grid(num_trees = seq(1,3)*50,
                   k = 2,
                   alpha = 0.95,
                   beta = 2,
                   nu = 3.0)

test_bart <- train(try.x,try.y,
                   method = "bartMachine",
                   tuneGrid = grid,
                   verbose = FALSE)

predict(test_bart, head(try.x), type = "prob")
predict(test_bart, head(try.x))
```

The results of the last two lines are:

```
> predict(test_bart, head(try.x), type = "prob")
predicting probabilities where "yes" is considered the target level...
        yes        no
1 0.4347876 0.5652124
2 0.5491113 0.4508887
3 0.6392297 0.3607703
4 0.6111202 0.3888798
5 0.2807636 0.7192364
6 0.3348303 0.6651697
> predict(test_bart, head(try.x))
[1] yes no  no  no  yes yes
Levels: yes no
```
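
To make the inconsistency explicit, here is a minimal sketch in base R that compares the class implied by the probabilities with the class returned by predict(). The numbers are copied from the output above; `class_from_prob` is a hypothetical helper written for this illustration and is not part of caret or bartMachine.

```r
# Values copied from the output shown above (first six rows of try.x).
prob <- data.frame(
  yes = c(0.4347876, 0.5491113, 0.6392297, 0.6111202, 0.2807636, 0.3348303),
  no  = c(0.5652124, 0.4508887, 0.3607703, 0.3888798, 0.7192364, 0.6651697)
)
predicted <- factor(c("yes", "no", "no", "no", "yes", "yes"),
                    levels = c("yes", "no"))

# Hypothetical helper (not a caret function): pick, for each row,
# the class whose column holds the highest probability.
class_from_prob <- function(p) {
  factor(colnames(p)[max.col(as.matrix(p), ties.method = "first")],
         levels = colnames(p))
}

implied <- class_from_prob(prob)

# Side-by-side comparison of the two predictions.
comparison <- data.frame(p_yes = prob$yes, implied, predicted,
                         agree = implied == predicted)
print(comparison)

# Number of rows where the two outputs disagree.
n_mismatch <- sum(implied != predicted)
print(n_mismatch)
```

In this sample all six rows disagree, which would be consistent with the two class labels being swapped in one of the two outputs. The check alone does not show which of the two outputs is the wrong one.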

Session Info

R version 4.1.2 (2021-11-01)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 19042)

Matrix products: default

locale:
[1] LC_COLLATE=Chinese (Simplified)_China.936 LC_CTYPE=Chinese (Simplified)_China.936
[3] LC_MONETARY=Chinese (Simplified)_China.936 LC_NUMERIC=C
[5] LC_TIME=Chinese (Simplified)_China.936

attached base packages:
[1] stats graphics grDevices utils datasets methods base

other attached packages:
[1] caret_6.0-91 lattice_0.20-45 ggplot2_3.3.5

loaded via a namespace (and not attached):
[1] Rcpp_1.0.7 lubridate_1.8.0 listenv_0.8.0 class_7.3-19 digest_0.6.28
[6] ipred_0.9-12 foreach_1.5.2 utf8_1.2.2 parallelly_1.31.0 R6_2.5.1
[11] plyr_1.8.6 hardhat_0.2.0 stats4_4.1.2 e1071_1.7-9 evaluate_0.15
[16] pillar_1.7.0 itertools_0.1-3 rlang_1.0.2 rstudioapi_0.13 data.table_1.14.2
[21] bartMachine_1.2.6 missForest_1.4 rpart_4.1.16 Matrix_1.3-4 rmarkdown_2.13
[26] splines_4.1.2 gower_1.0.0 stringr_1.4.0 munsell_0.5.0 proxy_0.4-26
[31] compiler_4.1.2 xfun_0.30 pkgconfig_2.0.3 globals_0.14.0 htmltools_0.5.2
[36] Metrics_0.1.4 nnet_7.3-17 tidyselect_1.1.2 tibble_3.1.6 prodlim_2019.11.13
[41] bartMachineJARs_1.1 codetools_0.2-18 randomForest_4.7-1 fansi_0.5.0 future_1.24.0
[46] crayon_1.5.1 dplyr_1.0.8 withr_2.5.0 MASS_7.3-55 recipes_0.2.0
[51] ModelMetrics_1.2.2.2 grid_4.1.2 nlme_3.1-153 gtable_0.3.0 lifecycle_1.0.1
[56] magrittr_2.0.1 pROC_1.18.0 scales_1.1.1 future.apply_1.8.1 cli_3.2.0
[61] stringi_1.7.6 reshape2_1.4.4 timeDate_3043.102 ellipsis_0.3.2 generics_0.1.2
[66] vctrs_0.3.8 lava_1.6.10 iterators_1.0.14 tools_4.1.2 glue_1.5.0
[71] purrr_0.3.4 parallel_4.1.2 fastmap_1.1.0 survival_3.3-1 yaml_2.2.1
[76] colorspace_2.0-2 rJava_1.0-6 knitr_1.38