zachmayer / caretEnsemble

caret models all the way down :turtle:

caretStack reorders x and y so that y is in ascending order, but does not reorder wts #234

Open the-tourist- opened 6 years ago

the-tourist- commented 6 years ago

In the following example I create a simple custom caret model so that I can view the x, y and wts values being passed to the model. The easiest way to inspect them is to add browser() inside the custom model's fit function, but I use print statements instead because they illustrate the problems this can cause.

In the example below, my weights increase from 0.01 to 1 in steps of 0.01. Because x and y are generated independently of row order, these weights should have no systematic effect on the prediction. But because y gets sorted before being passed to the model while wts does not, the weights no longer line up with the x and y rows. Even more pernicious, in the case below the misaligned weights give the larger y values higher weight, which strongly distorts the apparent weighted mean of the series.
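The effect can be reproduced without caret at all. The sketch below is a hypothetical illustration, not caretEnsemble code: it sorts y while leaving the weights in row order (the behaviour described above), and then applies the same permutation to the weights, which is what a fix would need to do. Because the weights increase with row number, pairing them with an ascending y can only push the weighted mean upward.

```r
# Hypothetical illustration of the misalignment; does not use caret or caretEnsemble.
set.seed(1)
y <- rnorm(100)
w <- seq(0.01, 1, length.out = 100)  # weights increase with row number

# Weighted mean with each weight still attached to its own row
aligned <- sum(y * w) / sum(w)

# Sort y ascending but leave w untouched (the reported behaviour)
ord <- order(y)
y_sorted <- y[ord]
misaligned <- sum(y_sorted * w) / sum(w)

# Applying the same permutation to the weights restores the original pairing
fixed <- sum(y_sorted * w[ord]) / sum(w[ord])

print(sprintf("Unweighted mean y:         %0.2f", mean(y)))
print(sprintf("Aligned weighted mean y:   %0.2f", aligned))
print(sprintf("Misaligned weighted mean:  %0.2f", misaligned))
print(sprintf("Reordered weights mean y:  %0.2f", fixed))
```

By the rearrangement inequality, pairing an ascending y with ascending weights maximises the weighted sum, so the misaligned mean is always at least as large as both the aligned weighted mean and the unweighted mean.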

Minimal, reproducible example:

Minimal dataset:

set.seed(1)

df <- data.frame(x = rnorm(100), w = seq(0.01, 1, length.out = 100))
df$y <- df$x * 0.1 + rnorm(100) * 0.9

head(df)
tail(df)

Minimal, runnable code:

library(caret)
library(caretEnsemble)

# Mean Custom Caret Method
CaretMean <- list (
  library = c("dplyr"),
  type = "Regression",
  parameters = data.frame(parameter = c("None"),
                          class = c("character"),
                          label = c("None")),
  grid = function(x, y, len = NULL, search = "grid") { data.frame( None = "" ) },
  fit = function(x, y, wts, param, lev, last, weights = NA, classProbs = NA, ...) {
    RetVal <- list()
    if (is.null(wts))
      wts <- rep(1, length(y))
    # Both x and y have been reordered so that y is ascending, but wts has not been reordered.
    # The weights therefore no longer correspond to the correct x and y rows. In this example the
    # weights also increase with row number, so after the sort the weighted mean of y is much
    # higher than the unweighted mean.
    print(sprintf("Unweighted Mean y: %0.2f", mean(y)))
    print(sprintf("Weighted Mean y: %0.2f", sum(y * wts) / sum(wts)))
    # browser()
    class(RetVal) <- "CaretMean"
    return(RetVal)
  },
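  predict = function(modelFit, newdata, preProc = NULL, submodels = NULL) {
    # Average the base-model predictions row by row; as.matrix() makes this work
    # whether newdata arrives as a matrix or a data frame
    rowMeans(as.matrix(newdata))
  },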
  prob = NULL,
  tags = c("Simple"),
  label = "Mean"
)

models <- caretList(
  y ~ x, data = df, weights = df$w,
  trControl = trainControl(method = "cv", savePredictions = "final", allowParallel = FALSE),
  methodList = c("glm", "gbm", "svmRadialCost", "knn")
)

ensemble <- caretStack(
  models, method = CaretMean, weights = df$w,
  trControl = trainControl(method = "cv", savePredictions = "final", allowParallel = FALSE)
)
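Instead of comparing printed means, one could check directly whether each (y, wts) pair that reaches fit actually occurs as a row of the original data. This is a minimal sketch; pairs_aligned is a hypothetical helper (not part of caret or caretEnsemble) and it assumes the y values are unique so that match() can locate each row.

```r
# Hypothetical helper (not part of caret/caretEnsemble): returns TRUE if every
# (y, wts) pair received by a fit function matches a row of the original data.
# Assumes y values are unique, so match() identifies each original row.
pairs_aligned <- function(y, wts, y_ref, w_ref) {
  idx <- match(y, y_ref)              # original row of each received y value
  if (anyNA(idx)) return(FALSE)       # some y values do not come from y_ref
  isTRUE(all.equal(wts, w_ref[idx]))  # did the weights travel with their rows?
}

# Usage on synthetic data
set.seed(1)
y_ref <- rnorm(100)
w_ref <- seq(0.01, 1, length.out = 100)
ord <- order(y_ref)
pairs_aligned(y_ref[ord], w_ref[ord], y_ref, w_ref)  # TRUE: weights reordered with y
pairs_aligned(y_ref[ord], w_ref, y_ref, w_ref)       # FALSE: the reported behaviour
```

Inside the custom model, a call such as pairs_aligned(y, wts, df$y, df$w) in place of the print statements would flag the problem explicitly, assuming the y values passed to fit are the observed outcomes from df.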

Session Info:

> sessionInfo()
R version 3.4.1 (2017-06-30)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows Server >= 2012 x64 (build 9200)

Matrix products: default

locale:
[1] LC_COLLATE=English_United States.1252  LC_CTYPE=English_United States.1252    LC_MONETARY=English_United States.1252
[4] LC_NUMERIC=C                           LC_TIME=English_United States.1252    

attached base packages:
 [1] grid      parallel  splines   stats4    stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] caretEnsemble_2.0.0           caret_6.0-77                  randomForest_4.6-12           data.table_1.10.4            
 [5] weights_0.85                  mice_2.30                     gdata_2.18.0                  flexclust_1.3-4              
 [9] modeltools_0.2-21             magrittr_1.5                  ROI_0.2-6                     PortfolioAnalytics_1.0.3636  
[13] PerformanceAnalytics_1.4.3541 xts_0.10-0                    zoo_1.8-0                     xgboost_0.6-4                
[17] lubridate_1.6.0               bindrcpp_0.2                  GenSA_1.1.6                   optimx_2013.8.7              
[21] doParallel_1.0.10             iterators_1.0.8               glmnet_2.0-13                 foreach_1.4.3                
[25] Matrix_1.2-10                 tidyr_0.7.1                   dplyr_0.7.3                   plyr_1.8.4                   
[29] scales_0.5.0                  car_2.1-5                     MASS_7.3-47                   DBI_0.7                      
[33] rsqlserver_1.0                rClr_0.7-4                    VGAM_1.0-4                    Hmisc_4.0-3                  
[37] ggplot2_2.2.1                 Formula_1.2-2                 survival_2.41-3               lattice_0.20-35              
[41] RODBC_1.3-15                 

loaded via a namespace (and not attached):
 [1] backports_1.1.0           lazyeval_0.2.0            svUnit_0.7-12             BB_2014.10-1              digest_0.6.12            
 [6] htmltools_0.3.6           checkmate_1.8.3           memoise_1.1.0             cluster_2.0.6             recipes_0.1.0            
[11] gower_0.1.2               dimRed_0.1.0              colorspace_1.3-2          lme4_1.1-13               Rglpk_0.6-3              
[16] bindr_0.1                 glue_1.1.1                DRR_0.0.2                 registry_0.3              gtable_0.2.0             
[21] ipred_0.9-6               MatrixModels_0.4-1        kernlab_0.9-25            ddalpha_1.2.1             DEoptimR_1.0-8           
[26] SparseM_1.77              setRNG_2013.9-1           Rcpp_0.12.12              htmlTable_1.9             foreign_0.8-69           
[31] lava_1.5                  prodlim_1.6.1             htmlwidgets_0.9           httr_1.3.1                ROI.plugin.quadprog_0.2-5
[36] RColorBrewer_1.1-2        acepack_1.4.1             pkgconfig_2.0.1           nnet_7.3-12               rlang_0.1.2              
[41] reshape2_1.4.2            munsell_0.4.3             tools_3.4.1               ranger_0.8.0              devtools_1.13.3          
[46] ROI.plugin.glpk_0.2-5     stringr_1.2.0             ModelMetrics_1.1.0        knitr_1.17                robustbase_0.92-7        
[51] purrr_0.2.3               pbapply_1.3-3             nlme_3.1-131              quantreg_5.33             slam_0.1-40              
[56] RcppRoll_0.2.2            compiler_3.4.1            pbkrtest_0.4-7            curl_2.6                  e1071_1.6-8              
[61] tibble_1.3.4              stringi_1.1.5             superpc_1.09              nloptr_1.0.4              gbm_2.1.3                
[66] ucminf_1.1-4              R6_2.2.2                  latticeExtra_0.6-28       gridExtra_2.3             codetools_0.2-15         
[71] gtools_3.5.0              assertthat_0.2.0          CVST_0.2-1                Rvmmin_2017-7.18          optextras_2016-8.8       
[76] withr_2.0.0               mgcv_1.8-17               Rcgmin_2013-2.21          quadprog_1.5-5            dfoptim_2016.7-1         
[81] rpart_4.1-11              timeDate_3012.100         class_7.3-14              minqa_1.2.4               git2r_0.18.0             
[86] numDeriv_2016.8-1         base64enc_0.1-3 
zachmayer commented 6 years ago

Yikes! If you'd like to submit a PR to fix this, I'm happy to accept it; otherwise I'll get to it when I have some free time.

the-tourist- commented 6 years ago

I've added a PR to fix the issue.