I found a bug when calling predict() on a model trained with train() and method = "bartMachine": the predicted classes and the predicted probabilities are inconsistent.
A minimal reproducible example is shown below.
(I build the tuning grid manually because the grid() function for method = "bartMachine" has a bug; see issue #1278.)
library(caret)
set.seed(1000)
try.x = matrix(rnorm(20*50,0,1),nrow = 20,ncol = 50)
try.x = as.data.frame(try.x)
try.y = rbinom(20,1,0.5)
# recode y as "yes"/"no" so that the class levels are valid R variable names
try.y = ifelse(try.y==1,"yes","no")
try.y = factor(try.y,levels = c("yes","no"))
grid = expand.grid(num_trees = seq(1,3)*50,
                   k = 2,
                   alpha = 0.95,
                   beta = 2,
                   nu = 3.0)
test_bart <- train(try.x, try.y,
                   method = "bartMachine",
                   tuneGrid = grid,
                   verbose = FALSE)
predict(test_bart, head(try.x), type = "prob")
predict(test_bart, head(try.x))
The results of the last two lines are:
> predict(test_bart, head(try.x), type = "prob")
predicting probabilities where "yes" is considered the target level...
        yes        no
1 0.4347876 0.5652124
2 0.5491113 0.4508887
3 0.6392297 0.3607703
4 0.6111202 0.3888798
5 0.2807636 0.7192364
6 0.3348303 0.6651697
> predict(test_bart, head(try.x))
[1] yes no no no yes yes
Levels: yes no
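The mismatch can be checked directly from the printed output, without refitting the model. The sketch below is only an illustration: it copies the six probabilities and classes shown above into base R objects (the names prob, pred_class and implied are introduced here and are not part of caret) and compares each predicted class with the class that has the larger probability.

```r
# Probabilities copied from the predict(..., type = "prob") output above
prob <- data.frame(
  yes = c(0.4347876, 0.5491113, 0.6392297, 0.6111202, 0.2807636, 0.3348303),
  no  = c(0.5652124, 0.4508887, 0.3607703, 0.3888798, 0.7192364, 0.6651697)
)

# Classes copied from the predict(test_bart, head(try.x)) output above
pred_class <- factor(c("yes", "no", "no", "no", "yes", "yes"),
                     levels = c("yes", "no"))

# Class implied by the probabilities: the column with the largest value per row
implied <- factor(colnames(prob)[max.col(as.matrix(prob), ties.method = "first")],
                  levels = levels(pred_class))

# Side-by-side comparison; agree is TRUE where the two outputs match
comparison <- data.frame(prob, implied, pred_class,
                         agree = implied == pred_class)
print(comparison)
```

For these six rows, every predicted class is the opposite of the class with the higher probability, which suggests that the labels are swapped in one of the two outputs.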
Session Info
R version 4.1.2 (2021-11-01)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 19042)
Matrix products: default
locale:
[1] LC_COLLATE=Chinese (Simplified)_China.936 LC_CTYPE=Chinese (Simplified)_China.936
[3] LC_MONETARY=Chinese (Simplified)_China.936 LC_NUMERIC=C
[5] LC_TIME=Chinese (Simplified)_China.936
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] caret_6.0-91 lattice_0.20-45 ggplot2_3.3.5
loaded via a namespace (and not attached):
[1] Rcpp_1.0.7 lubridate_1.8.0 listenv_0.8.0 class_7.3-19 digest_0.6.28
[6] ipred_0.9-12 foreach_1.5.2 utf8_1.2.2 parallelly_1.31.0 R6_2.5.1
[11] plyr_1.8.6 hardhat_0.2.0 stats4_4.1.2 e1071_1.7-9 evaluate_0.15
[16] pillar_1.7.0 itertools_0.1-3 rlang_1.0.2 rstudioapi_0.13 data.table_1.14.2
[21] bartMachine_1.2.6 missForest_1.4 rpart_4.1.16 Matrix_1.3-4 rmarkdown_2.13
[26] splines_4.1.2 gower_1.0.0 stringr_1.4.0 munsell_0.5.0 proxy_0.4-26
[31] compiler_4.1.2 xfun_0.30 pkgconfig_2.0.3 globals_0.14.0 htmltools_0.5.2
[36] Metrics_0.1.4 nnet_7.3-17 tidyselect_1.1.2 tibble_3.1.6 prodlim_2019.11.13
[41] bartMachineJARs_1.1 codetools_0.2-18 randomForest_4.7-1 fansi_0.5.0 future_1.24.0
[46] crayon_1.5.1 dplyr_1.0.8 withr_2.5.0 MASS_7.3-55 recipes_0.2.0
[51] ModelMetrics_1.2.2.2 grid_4.1.2 nlme_3.1-153 gtable_0.3.0 lifecycle_1.0.1
[56] magrittr_2.0.1 pROC_1.18.0 scales_1.1.1 future.apply_1.8.1 cli_3.2.0
[61] stringi_1.7.6 reshape2_1.4.4 timeDate_3043.102 ellipsis_0.3.2 generics_0.1.2
[66] vctrs_0.3.8 lava_1.6.10 iterators_1.0.14 tools_4.1.2 glue_1.5.0
[71] purrr_0.3.4 parallel_4.1.2 fastmap_1.1.0 survival_3.3-1 yaml_2.2.1
[76] colorspace_2.0-2 rJava_1.0-6 knitr_1.38