topepo / caret

caret (Classification And Regression Training) is an R package that contains miscellaneous functions for training and plotting classification and regression models
http://topepo.github.io/caret/index.html

custom summary in trainControl #719

Open franperezlopez opened 7 years ago

franperezlopez commented 7 years ago

Minimal, reproducible example:

Working with custom summaries, I found that getTrainPerf() does not return all of the new columns defined in the custom summary function. However, those fields are present in results and resample. When I run the training several times, a different field is missing each time. Can you spot any error in my code?

Thanks

Minimal dataset:

library(caret)
library(dplyr)

training <- SLC14_1(30) %>% bind_cols(weight = sample(1:4, 30, TRUE))
trainX <- training[, -21]   # drop the outcome y (column 21)
trainY <- training$y

Minimal, runnable code:

library(recipes)

rec_reg <- recipe(y ~ ., data = training) %>%
  add_role(weight, new_role = "performance var") %>%
  step_center(all_predictors()) %>%
  step_scale(all_predictors())

defaultSummaryExtended <- function(data, lev = NULL, model = NULL) {
  require(reshape2)
  require(ModelMetrics)

  # one RMSE per weight level, spread into columns named wRMSE_<weight>
  weights_summary <- data.frame(obs = data[, "obs"],
                                pred = data[, "pred"],
                                weight = data[, "weight"]) %>%
    group_by(weight) %>%
    summarise(rmse = rmse(obs, pred), lbl = paste0("wRMSE_", first(weight))) %>%
    dcast(rmse ~ lbl, value.var = "rmse") %>%
    select(-rmse) %>%
    summarise_all(sum, na.rm = TRUE) %>%
    as.list()

  c(defaultSummary(data, lev, model), unlist(weights_summary))
}

rctrl1 <- trainControl(method = "cv", number = 3, returnResamp = "all", summaryFunction = defaultSummaryExtended)

test_enet_rec <- train(recipe = rec_reg,
                      data = training,
                      method = "enet", 
                      trControl = rctrl1)

getTrainPerf(test_enet_rec) # runs, but does not return all 4 wRMSE columns (a different one is missing on each run)
  TrainRMSE TrainRsquared TrainMAE TrainwRMSE_2 TrainwRMSE_3 TrainwRMSE_4 method
1  24.69747     0.1235343 19.86194     31.69428     17.24725     12.82436   enet
test_enet_rec$perfNames # I think this field is not set correctly
"RMSE"     "Rsquared" "MAE"      "wRMSE_2"  "wRMSE_3"  "wRMSE_4"
names(test_enet_rec$resample)
 [1] "lambda"   "fraction" "RMSE"     "Rsquared" "MAE"      "wRMSE_1"  "wRMSE_2"  "wRMSE_3" 
 [9] "wRMSE_4"  "Resample"
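One way to check whether a summary function like the one above always returns the same fields is to run it on several holdout-like subsets and compare the names it returns. The sketch below is a hypothetical base-R helper (check_summary_names is not part of caret); it assumes the summary takes a data frame with obs, pred and weight columns, as in the code above, and splits the rows round-robin into k folds.

```r
# Hypothetical diagnostic (not part of caret): split rows into k folds, run
# the summary on each fold, and return the distinct sets of metric names.
# More than one set means the summary does not always return the same fields.
check_summary_names <- function(summary_fun, data, k = 3) {
  folds <- split(seq_len(nrow(data)), rep_len(seq_len(k), nrow(data)))
  name_sets <- lapply(folds, function(rows) {
    names(summary_fun(data[rows, , drop = FALSE]))
  })
  unique(unname(name_sets))
}
```

A result with more than one element (for example, when some fold has no rows for one weight level) points to exactly the kind of unstable output described in this issue.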

Session Info:

>sessionInfo()
R version 3.4.1 (2017-06-30)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows >= 8 x64 (build 9200)

Matrix products: default

locale:
[1] LC_COLLATE=English_United Kingdom.1252  LC_CTYPE=English_United Kingdom.1252   
[3] LC_MONETARY=English_United Kingdom.1252 LC_NUMERIC=C                           
[5] LC_TIME=English_United Kingdom.1252    

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] reshape2_1.4.2     ModelMetrics_1.1.0 bindrcpp_0.2       elasticnet_1.1     lars_1.2          
 [6] recipes_0.1.0      caret_6.0-77       ggplot2_2.2.1      lattice_0.20-35    dplyr_0.7.2       
[11] plyr_1.8.4        

loaded via a namespace (and not attached):
 [1] Rcpp_0.12.12      ddalpha_1.2.1     compiler_3.4.1    DEoptimR_1.0-8    gower_0.1.2      
 [6] bindr_0.1         class_7.3-14      iterators_1.0.8   tools_3.4.1       rpart_4.1-11     
[11] ipred_0.9-6       lubridate_1.6.0   tibble_1.3.3      nlme_3.1-131      gtable_0.2.0     
[16] pkgconfig_2.0.1   rlang_0.1.2       Matrix_1.2-10     foreach_1.4.3     RcppRoll_0.2.2   
[21] prodlim_1.6.1     knitr_1.17        withr_2.0.0       stringr_1.2.0     tidyselect_0.1.1 
[26] stats4_3.4.1      nnet_7.3-12       CVST_0.2-1        grid_3.4.1        robustbase_0.92-7
[31] glue_1.1.1        R6_2.2.2          survival_2.41-3   lava_1.5          purrr_0.2.3      
[36] kernlab_0.9-25    magrittr_1.5      DRR_0.0.2         splines_3.4.1     scales_0.4.1     
[41] codetools_0.2-15  MASS_7.3-47       assertthat_0.2.0  dimRed_0.1.0      timeDate_3012.100
[46] colorspace_1.3-2  stringi_1.1.5     lazyeval_0.2.0    munsell_0.4.3    


topepo commented 7 years ago

I don't see this issue with the current devel version of caret. Try it and see whether you get the same results. Note that the name of the first argument of the recipe method has (unfortunately) changed, from recipe to x, for S3 method consistency.

> library(caret)
Loading required package: lattice
Loading required package: ggplot2
> library(recipes)
Loading required package: dplyr

Attaching package: ‘dplyr’

The following objects are masked from ‘package:stats’:

    filter, lag

The following objects are masked from ‘package:base’:

    intersect, setdiff, setequal, union

Attaching package: ‘recipes’

The following object is masked from ‘package:stats’:

    step

> 
> training <- SLC14_1(30) %>% bind_cols(weight=sample(1:4,30,T))
> trainX <- training[, -21]
> trainY <- training$y
> 
> rec_reg <- recipe(y ~ ., data = training) %>%
+   add_role(weight, new_role = "performance var") %>%
+   step_center(all_predictors()) %>%
+   step_scale(all_predictors()) 
Warning message:
Changing role(s) for weight 
> 
> defaultSummaryExtended <- function (data, lev = NULL, model = NULL) 
+ {
+   require(reshape2)
+   require(ModelMetrics)
+   
+   weights_summary <- data.frame(obs=data[,"obs"], 
+                                 pred=data[,"pred"], 
+                                 weight=data[,"weight"]) %>%
+     group_by(weight) %>% 
+     summarise(rmse = rmse(obs,pred), lbl=paste0("wRMSE_",first(weight))) %>% 
+     dcast(rmse ~ lbl, value.var = "rmse") %>% 
+     select (-rmse) %>% 
+     summarise_all(sum, na.rm=T) %>%
+     as.list()
+   
+   c(defaultSummary(data, lev, model), unlist(weights_summary))
+ }
> 
> rctrl1 <- trainControl(method = "cv", number = 3, returnResamp = "all", summaryFunction = defaultSummaryExtended)
> 
> # note the first argument name has changed =[
> test_enet_rec <- train(x = rec_reg,
+                        data = training,
+                        method = "enet", 
+                        trControl = rctrl1)
Loading required namespace: elasticnet
Loading required package: reshape2
Loading required package: ModelMetrics

Attaching package: ‘ModelMetrics’

The following objects are masked from ‘package:caret’:

    confusionMatrix, precision, recall, sensitivity, specificity

Loading required package: lars
Loaded lars 1.2

> 
> getTrainPerf(test_enet_rec)
  TrainRMSE TrainRsquared TrainMAE TrainwRMSE_1 TrainwRMSE_2 TrainwRMSE_3 TrainwRMSE_4
1  15.42666     0.1919126 11.79471     15.36306     15.79701      13.4199     11.72029
  method
1   enet
> 
> library(sessioninfo)
> session_info()
─ Session info ──────────────────────────────────────────────────────────────────────────
 setting  value                       
 version  R version 3.3.3 (2017-03-06)
 os       macOS Sierra 10.12.6        
 system   x86_64, darwin13.4.0        
 ui       RStudio                     
 language (EN)                        
 collate  en_US.UTF-8                 
 tz       America/New_York            
 date     2017-08-21                  

─ Packages ──────────────────────────────────────────────────────────────────────────────
 package      * version  date       source                        
 assertthat     0.2.0    2017-04-11 CRAN (R 3.3.2)                
 bindr          0.1      2016-11-13 CRAN (R 3.3.2)                
 bindrcpp     * 0.2      2017-06-17 cran (@0.2)                   
 caret        * 6.0-77   2017-08-21 local (@6.0-77)               
 class          7.3-14   2015-08-30 CRAN (R 3.3.3)                
 clisymbols     1.2.0    2017-05-21 CRAN (R 3.3.2)                
 codetools      0.2-15   2016-10-05 CRAN (R 3.3.3)                
 colorspace     1.3-2    2016-12-14 CRAN (R 3.3.2)                
 CVST           0.2-1    2013-12-10 CRAN (R 3.3.0)                
 ddalpha        1.2.1    2016-10-10 CRAN (R 3.3.0)                
 DEoptimR       1.0-8    2016-11-19 CRAN (R 3.3.2)                
 dimRed         0.1.0    2017-05-04 CRAN (R 3.3.2)                
 dplyr        * 0.7.2    2017-07-20 cran (@0.7.2)                 
 DRR            0.0.2    2016-09-15 CRAN (R 3.3.0)                
 elasticnet   * 1.1      2012-06-28 CRAN (R 3.3.0)                
 foreach        1.4.3    2015-10-13 CRAN (R 3.3.0)                
 ggplot2      * 2.2.1    2016-12-30 CRAN (R 3.3.2)                
 glue           1.1.1    2017-06-21 CRAN (R 3.3.2)                
 gower          0.1.2    2017-02-23 CRAN (R 3.3.2)                
 gtable         0.2.0    2016-02-26 CRAN (R 3.3.0)                
 ipred          0.9-6    2017-03-01 cran (@0.9-6)                 
 iterators      1.0.8    2015-10-13 CRAN (R 3.3.0)                
 kernlab        0.9-25   2016-10-03 CRAN (R 3.3.0)                
 lars         * 1.2      2013-04-24 CRAN (R 3.3.0)                
 lattice      * 0.20-35  2017-03-25 CRAN (R 3.3.3)                
 lava           1.5      2017-03-16 cran (@1.5)                   
 lazyeval       0.2.0    2016-06-12 CRAN (R 3.3.0)                
 lubridate      1.6.0    2016-09-13 CRAN (R 3.3.0)                
 magrittr       1.5      2014-11-22 CRAN (R 3.3.0)                
 MASS           7.3-47   2017-04-21 CRAN (R 3.3.3)                
 Matrix         1.2-8    2017-01-20 CRAN (R 3.3.3)                
 ModelMetrics * 1.1.0    2016-08-26 CRAN (R 3.3.0)                
 munsell        0.4.3    2016-02-13 CRAN (R 3.3.0)                
 nlme           3.1-131  2017-02-06 CRAN (R 3.3.3)                
 nnet           7.3-12   2016-02-02 CRAN (R 3.3.3)                
 pkgconfig      2.0.1    2017-03-21 cran (@2.0.1)                 
 plyr           1.8.4    2016-06-08 CRAN (R 3.3.0)                
 prodlim        1.6.1    2017-03-06 cran (@1.6.1)                 
 purrr          0.2.3    2017-08-02 cran (@0.2.3)                 
 R6             2.2.2    2017-06-17 cran (@2.2.2)                 
 Rcpp           0.12.12  2017-07-15 cran (@0.12.12)               
 RcppRoll       0.2.2    2015-04-05 CRAN (R 3.3.0)                
 recipes      * 0.1.0    2017-08-16 local (topepo/recipes@25a05ef)
 reshape2     * 1.4.2    2016-10-22 CRAN (R 3.3.3)                
 rlang          0.1.2    2017-08-09 cran (@0.1.2)                 
 robustbase     0.92-7   2016-12-09 CRAN (R 3.3.2)                
 rpart          4.1-11   2017-04-21 CRAN (R 3.3.3)                
 scales         0.4.1    2016-11-09 CRAN (R 3.3.2)                
 sessioninfo  * 1.0.0    2017-06-21 CRAN (R 3.3.2)                
 stringi        1.1.5    2017-04-07 CRAN (R 3.3.2)                
 stringr        1.2.0    2017-02-18 CRAN (R 3.3.2)                
 survival       2.40-1   2016-10-30 CRAN (R 3.3.3)                
 tibble         1.3.3    2017-05-28 CRAN (R 3.3.2)                
 tidyselect     0.1.1    2017-07-24 CRAN (R 3.3.2)                
 timeDate       3012.100 2015-01-23 cran (@3012.10)               
 withr          2.0.0    2017-07-28 CRAN (R 3.3.2)   

franperezlopez commented 7 years ago

I updated my installation to the master branch, but I still get the same issue.

I've found more hints about the issue. It can be reproduced without recipes, and it is related to the summary function not always returning the same fields. I realised that sometimes a resample contains no rows with weight X, so the wRMSE_X field is not returned. If wRMSE_X is returned as NA or NaN, the column is always present in the summaries; the problem only appears when the field is missing from some of the resampled summaries.
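Following this diagnosis, here is a minimal sketch of a summary that always returns the same fields. It assumes the weights are the integers 1 to 4, as produced by sample(1:4, 30, TRUE). weighted_rmse_summary and defaultSummaryFixed are hypothetical names; only caret::defaultSummary comes from caret. A level with no rows in a resample is reported as NA instead of being dropped, which is the case described above as behaving correctly.

```r
# Hypothetical helper: RMSE per weight level that always returns the same
# names, using NA for levels that have no rows in this resample.
weighted_rmse_summary <- function(data, weight_levels = 1:4) {
  out <- vapply(weight_levels, function(w) {
    idx <- data$weight == w
    if (!any(idx)) return(NA_real_)   # level absent: keep the slot, value NA
    sqrt(mean((data$obs[idx] - data$pred[idx])^2))
  }, numeric(1))
  names(out) <- paste0("wRMSE_", weight_levels)
  out
}

# Hypothetical summary for trainControl(summaryFunction = ...): caret's
# default metrics plus a fixed set of wRMSE_* columns.
defaultSummaryFixed <- function(data, lev = NULL, model = NULL) {
  c(caret::defaultSummary(data, lev, model), weighted_rmse_summary(data))
}
```

It could be passed as trainControl(method = "cv", number = 3, summaryFunction = defaultSummaryFixed). Folds that lack a weight level may still trigger the "missing values in resampled performance measures" warning, but the set of columns stays the same on every call.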

This is my updated code

library(caret)
library(dplyr)

training <- SLC14_1(30) %>% bind_cols(weight = sample(1:4, 30, TRUE))
trainX <- training[, -21]
trainY <- training$y

defaultSummaryExtended <- function (data, lev = NULL, model = NULL) 
{
  # simulate a summary that randomly omits some fields
  weights_summary <- 
    switch(sample(1:4,1),
           list(wRMSE_1=sample(0:9,1)),
           list(wRMSE_1=sample(0:9,1), wRMSE_2=sample(0:9,1)),
           list(wRMSE_1=sample(0:9,1), wRMSE_2=sample(0:9,1), wRMSE_3=sample(0:9,1)),
           list(wRMSE_1=sample(0:9,1), wRMSE_2=sample(0:9,1), wRMSE_3=sample(0:9,1), wRMSE_4=sample(0:9,1)))

  cat("wRMSE_1:",weights_summary$wRMSE_1,
      ", wRMSE_2:", weights_summary$wRMSE_2, 
      ", wRMSE_3:", weights_summary$wRMSE_3,
      " ,wRMSE_4:", weights_summary$wRMSE_4,"\n")
  c(defaultSummary(data, lev, model), unlist(weights_summary))
}

rctrl1 <- trainControl(method = "cv", number = 3, returnResamp = "final", summaryFunction = defaultSummaryExtended)

test_enet_rec <- train(x = trainX,
                       y = trainY,
                       method = "enet",
                       trControl = rctrl1,
                       tuneGrid = expand.grid(fraction = 0.05, lambda = 0.1))

I couldn't make the bug appear on every run, but if you run the training above many times, you will eventually get something like this:

> test_enet_rec <- train(x = trainX,
+                        y = trainY,
+                        method = "enet", 
+                        trControl = rctrl1,
+                        tuneGrid = expand.grid(fraction = 0.05, lambda = 0.1))
wRMSE_1: 5 , wRMSE_2: 4 , wRMSE_3:  ,wRMSE_4: 
wRMSE_1: 2 , wRMSE_2: , wRMSE_3:  ,wRMSE_4: 
wRMSE_1: 5 , wRMSE_2: 5 , wRMSE_3: 2  ,wRMSE_4: 
wRMSE_1: 2 , wRMSE_2: 5 , wRMSE_3: 1  ,wRMSE_4: 8 
Warning message:
In nominalTrainWorkflow(x = x, y = y, wts = weights, info = trainInfo,  :
  There were missing values in resampled performance measures.
> test_enet_rec$resample
      RMSE   Rsquared      MAE wRMSE_1 Resample wRMSE_2 wRMSE_3 wRMSE_4
1 13.16186 0.07013545 11.65727       2    Fold1      NA      NA      NA
2 17.88982 0.47395347 14.36008       5    Fold2       5       2      NA
3 15.28916 0.06187687 11.39220       2    Fold3       5       1       8
> getTrainPerf(test_enet_rec)
  TrainRMSE TrainRsquared TrainMAE TrainwRMSE_1 TrainwRMSE_2 method
1  15.44695     0.2019886 12.46985            3            5   enet
> test_enet_rec$perfNames
[1] "RMSE"     "Rsquared" "MAE"      "wRMSE_1"  "wRMSE_2" 
> test_enet_rec$results
  fraction lambda     RMSE  Rsquared      MAE wRMSE_1 wRMSE_2 wRMSE_3 wRMSE_4  RMSESD RsquaredSD
1     0.05    0.1 15.44695 0.2019886 12.46985       3       5     1.5       8 2.36793  0.2355647
     MAESD wRMSE_1SD wRMSE_2SD wRMSE_3SD wRMSE_4SD
1 1.642346  1.732051         0 0.7071068        NA
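The output above fits the diagnosis earlier in this comment. There are four printed lines for three folds, and perfNames (wRMSE_1, wRMSE_2) matches the field set of the first printed line, which suggests one extra call before resampling whose names are then reused; getTrainPerf() shows exactly those perfNames columns. This reading is inferred from the output, not checked against caret's source. The sketch below is a hypothetical base-R illustration (bind_fill is not caret's code) of how stacking per-fold summaries with differing names, adding columns in first-seen order, reproduces both the NA pattern and the position of Resample in test_enet_rec$resample. Values are rounded from the output above; Rsquared and MAE are omitted for brevity.

```r
# Hypothetical illustration (not caret's internal code): row-bind per-fold
# summaries whose names differ, creating columns in the order first seen
# and filling NA where a fold did not return a field.
bind_fill <- function(rows) {
  cols <- unique(unlist(lapply(rows, names)))   # union of names, first-seen order
  out <- lapply(cols, function(col) {
    unlist(lapply(rows, function(r) if (is.null(r[[col]])) NA else r[[col]]))
  })
  names(out) <- cols
  as.data.frame(out, stringsAsFactors = FALSE)
}

# Per-fold summaries shaped like the run above (Resample appended last)
folds <- list(
  list(RMSE = 13.16, wRMSE_1 = 2, Resample = "Fold1"),
  list(RMSE = 17.89, wRMSE_1 = 5, wRMSE_2 = 5, wRMSE_3 = 2, Resample = "Fold2"),
  list(RMSE = 15.29, wRMSE_1 = 2, wRMSE_2 = 5, wRMSE_3 = 1, wRMSE_4 = 8, Resample = "Fold3")
)
bind_fill(folds)
```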

When the summary always returns the same fields, the Resample column usually comes last in resample, and perfNames lists all of the summary fields. In the example above, however, the Resample column sits oddly between the other summary fields; I think this is an effect of the structure being filled progressively as results come in. Regarding the session info, I noticed you are running R 3.3 while my installation is 3.4. Maybe this is the culprit, but I can't downgrade my environment. So I think I will use the results field, since it returns the same information as getTrainPerf() but looks more reliable.
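Taking the same information from results, as suggested above, could look like the sketch below. train_perf_from_results is a hypothetical helper; it relies only on the results, bestTune and method fields of a train object. It picks the row of results that matches bestTune, drops the tuning parameters and the *SD columns, and prefixes the remaining columns with Train, mimicking the layout of getTrainPerf().

```r
# Hypothetical helper: rebuild a getTrainPerf()-style row from `results`,
# keeping every summary column, including ones missing from perfNames.
train_perf_from_results <- function(results, best_tune, method = NA_character_) {
  # select the row of `results` matching the chosen tuning parameters
  keep <- Reduce(`&`, lapply(names(best_tune),
                             function(p) results[[p]] == best_tune[[p]]))
  row <- results[keep, , drop = FALSE]
  # drop tuning parameters and the *SD columns, keep the mean metrics
  metric_cols <- setdiff(names(row), names(best_tune))
  metric_cols <- metric_cols[!grepl("SD$", metric_cols)]
  out <- row[, metric_cols, drop = FALSE]
  names(out) <- paste0("Train", names(out))
  out$method <- method
  rownames(out) <- NULL
  out
}
```

With the fit above, this would be called as train_perf_from_results(test_enet_rec$results, test_enet_rec$bestTune, test_enet_rec$method). It assumes no metric name of its own ends in "SD".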

Thank you for your support and for your patience in reading my issues.