topepo / caret

caret (Classification And Regression Training): an R package containing miscellaneous functions for training and plotting classification and regression models
http://topepo.github.io/caret/index.html

Recipe vs Formula vs X/Y Interface reproducibility for gbm. Potential bug? #1139

Open Jack-FG opened 4 years ago

Jack-FG commented 4 years ago

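I trained the same gbm model on the iris data set twice with each interface (recipe, formula, and x/y), calling set.seed(0) before every fit, to check whether each interface is reproducible. The formula and x/y fits come out equal apart from the timings, but all.equal() reports a discrepancy between the two models trained through the recipes interface.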

library(plyr)
library(tidyverse)
library(gbm)
library(caret)
library(recipes)

# recipe to be supplied
Recipe.Obj <- recipe(Sepal.Length ~ ., data = iris)

# train control object
TC.Obj <- trainControl("cv", savePredictions = "all", summaryFunction = defaultSummary, returnResamp = "all")

# settings shared by every train() call below
Model <- "gbm"
Recipe <- Recipe.Obj
TC <- TC.Obj
Training.Data.Set <- iris
Metric <- "RMSE"

# Using a recipe object

set.seed(0)
Model.Obj.1 <- train(Recipe,
                     method = Model,
                     data = Training.Data.Set,
                     trControl = TC,
                     metric = Metric,
                     verbose = FALSE,
                     tuneLength = 3
)

set.seed(0)
Model.Obj.2 <- train(Recipe,
                     method = Model,
                     data = Training.Data.Set,
                     trControl = TC,
                     metric = Metric,
                     verbose = FALSE,
                     tuneLength = 3
)

# does not return equal objects
all.equal(Model.Obj.1, Model.Obj.2)

# Using formula

set.seed(0)
Model.Obj.3 <- train(Sepal.Length ~ .,
                     method = Model,
                     data = Training.Data.Set,
                     trControl = TC,
                     metric = Metric,
                     verbose = FALSE,
                     tuneLength = 3
)

set.seed(0)
Model.Obj.4 <- train(Sepal.Length ~ .,
                     method = Model,
                     data = Training.Data.Set,
                     trControl = TC,
                     metric = Metric,
                     verbose = FALSE,
                     tuneLength = 3
)

# returns equal objects except for times
all.equal(Model.Obj.3, Model.Obj.4)

# Using x/y

set.seed(0)
Model.Obj.5 <- train(Training.Data.Set[,-1],Training.Data.Set[,1],
                     method = Model,
                     trControl = TC,
                     metric = Metric,
                     verbose = FALSE,
                     tuneLength = 3
)

set.seed(0)
Model.Obj.6 <- train(Training.Data.Set[,-1], Training.Data.Set[,1],
                     method = Model,
                     trControl = TC,
                     metric = Metric,
                     verbose = FALSE,
                     tuneLength = 3
)
# returns equal objects except for times
all.equal(Model.Obj.5, Model.Obj.6)
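To narrow down where the two recipe-based fits diverge, it may help to run all.equal() component by component instead of on the whole object. The sketch below is only an illustration: compare_components() is a hypothetical helper written for this issue (it is not part of caret or recipes), and fit_a / fit_b are made-up lists standing in for train objects so that the snippet runs with base R alone.

```r
# Hypothetical helper (not a caret function): apply all.equal() to each
# named component shared by two list-like objects and keep only the
# components that are reported as different.
compare_components <- function(a, b,
                               components = intersect(names(a), names(b))) {
  diffs <- lapply(components, function(nm) all.equal(a[[nm]], b[[nm]]))
  names(diffs) <- components
  # all.equal() returns TRUE on a match, otherwise a character vector
  # describing the differences, so isTRUE() separates the two cases.
  diffs[!vapply(diffs, isTRUE, logical(1))]
}

# Toy stand-ins for two fitted objects (values are placeholders, not results)
fit_a <- list(results = data.frame(RMSE = c(0.41, 0.39)),
              times   = list(everything = 1.2))
fit_b <- list(results = data.frame(RMSE = c(0.41, 0.39)),
              times   = list(everything = 1.5))

differing <- compare_components(fit_a, fit_b)
print(names(differing))

# With the objects from this issue, assuming the usual train components:
# compare_components(Model.Obj.1, Model.Obj.2,
#                    c("results", "bestTune", "resample", "pred"))
```

If results, resample and pred all match, the mismatch would sit in some other part of the object rather than in the fitted numbers, which would be worth stating in the report.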

Session Info:

sessionInfo()


R version 3.5.2 (2018-12-20)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 18362)

Matrix products: default

locale:
[1] LC_COLLATE=English_Australia.1252 LC_CTYPE=English_Australia.1252 LC_MONETARY=English_Australia.1252 LC_NUMERIC=C
[5] LC_TIME=English_Australia.1252

attached base packages:
[1] stats graphics grDevices utils datasets methods base

other attached packages:
[1] recipes_0.1.11 caret_6.0-86 lattice_0.20-38 gbm_2.1.5 forcats_0.4.0 stringr_1.4.0 dplyr_0.8.4 purrr_0.3.3 readr_1.3.1
[10] tidyr_1.0.2 tibble_2.1.3 ggplot2_3.2.1 tidyverse_1.3.0 plyr_1.8.4

loaded via a namespace (and not attached):
[1] Rcpp_1.0.1 lubridate_1.7.4 prettyunits_1.0.2 ps_1.3.0 class_7.3-14 assertthat_0.2.1
[7] packrat_0.5.0 ipred_0.9-8 foreach_1.4.4 R6_2.4.0 cellranger_1.1.0 backports_1.1.4
[13] stats4_3.5.2 reprex_0.3.0 httr_1.4.1 pillar_1.4.3 rlang_0.4.5 lazyeval_0.2.1
[19] readxl_1.3.1 data.table_1.11.8 rstudioapi_0.11 callr_3.4.3 rpart_4.1-13 Matrix_1.2-15
[25] splines_3.5.2 gower_0.2.0 munsell_0.5.0 broom_0.5.4 compiler_3.5.2 modelr_0.1.6
[31] pkgconfig_2.0.2 pkgbuild_1.0.6.9000 nnet_7.3-12 tidyselect_0.2.5 prodlim_2018.04.18 gridExtra_2.3
[37] codetools_0.2-15 fansi_0.4.0 crayon_1.3.4 dbplyr_1.4.2 withr_2.1.2 ModelMetrics_1.2.2.2
[43] MASS_7.3-51.5 grid_3.5.2 nlme_3.1-137 jsonlite_1.6.1 gtable_0.2.0 lifecycle_0.1.0
[49] DBI_1.0.0 magrittr_1.5 pROC_1.13.0 scales_1.0.0 cli_2.0.2 stringi_1.3.1
[55] reshape2_1.4.3 fs_1.3.1 timeDate_3043.102 xml2_1.2.2 generics_0.0.2 vctrs_0.2.3
[61] lava_1.6.5 iterators_1.0.10 tools_3.5.2 glue_1.4.0 hms_0.5.3 processx_3.4.1
[67] survival_2.43-3 colorspace_1.4-0 rvest_0.3.5 haven_2.2.0

Jack-FG commented 4 years ago

Would this be better asked in the recipes GitHub repository?