topepo / caret

caret (Classification And Regression Training) is an R package containing miscellaneous functions for training and plotting classification and regression models
http://topepo.github.io/caret/index.html

Changed default grid specification for xgbTree #615

Closed RobertKomara closed 7 years ago

RobertKomara commented 7 years ago

Hi all, the default grid for the xgbTree model in the current dev version 6.0-75 is specified differently from the active CRAN release 6.0-73. In version 6.0-73, min_child_weight and gamma are the only hyperparameters held constant, while in the dev version only the max_depth hyperparameter is varied, from 1 to 3. Is this a planned change or an unintended effect?
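One way to check this directly is to build the default grid the same way train() does and compare it between the two versions. A sketch using caret's getModelInfo (the call to the model's internal grid function, including its len argument, is an assumption about caret's internals and may vary between versions):

```r
library(caret)

# Pull the model definition that train() uses internally for xgbTree.
info <- getModelInfo("xgbTree", regex = FALSE)[[1]]

# The tuning parameters caret knows about for this model.
print(info$parameters$parameter)

# Build the default grid as train() would (len = 3 mirrors the default
# tuneLength); run this under 6.0-73 and 6.0-75 and diff the results.
set.seed(1)
dat <- twoClassSim(100)
default_grid <- info$grid(x = dat[, 1:5], y = dat[["Class"]], len = 3)
print(default_grid)
```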

Caret xgbTree Test

Running CRAN caret 6.0-73

library(caret)
## Warning: package 'caret' was built under R version 3.3.3

## Loading required package: lattice

## Loading required package: ggplot2
set.seed(1)
dat <- twoClassSim(100)
X <- dat[,1:5]
y <- dat[["Class"]]

model_class <- train(
  X, y, method='xgbTree',
  metric='ROC',
  trControl=trainControl(
    method="cv", 
    number=5,
    classProbs=TRUE, 
    summaryFunction=twoClassSummary)
)
## Loading required package: xgboost

## Loading required package: plyr
print(model_class)
## eXtreme Gradient Boosting 
## 
## 100 samples
##   5 predictor
##   2 classes: 'Class1', 'Class2' 
## 
## No pre-processing
## Resampling: Cross-Validated (5 fold) 
## Summary of sample sizes: 80, 80, 81, 80, 79 
## Resampling results across tuning parameters:
## 
##   eta  max_depth  colsample_bytree  subsample  nrounds  ROC      
##   0.3  1          0.6               0.50        50      0.8242805
##   0.3  1          0.6               0.50       100      0.8045003
##   0.3  1          0.6               0.50       150      0.7812140
##   0.3  1          0.6               0.75        50      0.8425955
##   0.3  1          0.6               0.75       100      0.8142857
##   0.3  1          0.6               0.75       150      0.8030874
##   0.3  1          0.6               1.00        50      0.8479069
##   0.3  1          0.6               1.00       100      0.8358451
##   0.3  1          0.6               1.00       150      0.8219519
##   0.3  1          0.8               0.50        50      0.8185243
##   0.3  1          0.8               0.50       100      0.8071167
##   0.3  1          0.8               0.50       150      0.7985871
##   0.3  1          0.8               0.75        50      0.8560963
##   0.3  1          0.8               0.75       100      0.8390895
##   0.3  1          0.8               0.75       150      0.8167452
##   0.3  1          0.8               1.00        50      0.8281266
##   0.3  1          0.8               1.00       100      0.8308739
##   0.3  1          0.8               1.00       150      0.8188121
##   0.3  2          0.6               0.50        50      0.8387755
##   0.3  2          0.6               0.50       100      0.8088959
##   0.3  2          0.6               0.50       150      0.7983778
##   0.3  2          0.6               0.75        50      0.8101518
##   0.3  2          0.6               0.75       100      0.7917844
##   0.3  2          0.6               0.75       150      0.7882784
##   0.3  2          0.6               1.00        50      0.8414181
##   0.3  2          0.6               1.00       100      0.8275249
##   0.3  2          0.6               1.00       150      0.8231293
##   0.3  2          0.8               0.50        50      0.8316065
##   0.3  2          0.8               0.50       100      0.8063318
##   0.3  2          0.8               0.50       150      0.7815803
##   0.3  2          0.8               0.75        50      0.8087389
##   0.3  2          0.8               0.75       100      0.8024594
##   0.3  2          0.8               0.75       150      0.8023025
##   0.3  2          0.8               1.00        50      0.8226060
##   0.3  2          0.8               1.00       100      0.8286761
##   0.3  2          0.8               1.00       150      0.8253794
##   0.3  3          0.6               0.50        50      0.7900576
##   0.3  3          0.6               0.50       100      0.7731031
##   0.3  3          0.6               0.50       150      0.7699110
##   0.3  3          0.6               0.75        50      0.8150183
##   0.3  3          0.6               0.75       100      0.7938252
##   0.3  3          0.6               0.75       150      0.7848770
##   0.3  3          0.6               1.00        50      0.8164312
##   0.3  3          0.6               1.00       100      0.8116693
##   0.3  3          0.6               1.00       150      0.8052852
##   0.3  3          0.8               0.50        50      0.8193616
##   0.3  3          0.8               0.50       100      0.8098378
##   0.3  3          0.8               0.50       150      0.8040816
##   0.3  3          0.8               0.75        50      0.8170068
##   0.3  3          0.8               0.75       100      0.8184197
##   0.3  3          0.8               0.75       150      0.8046572
##   0.3  3          0.8               1.00        50      0.8100471
##   0.3  3          0.8               1.00       100      0.8255887
##   0.3  3          0.8               1.00       150      0.8400837
##   0.4  1          0.6               0.50        50      0.8020931
##   0.4  1          0.6               0.50       100      0.7915751
##   0.4  1          0.6               0.50       150      0.7383046
##   0.4  1          0.6               0.75        50      0.8138409
##   0.4  1          0.6               0.75       100      0.8055992
##   0.4  1          0.6               0.75       150      0.8024594
##   0.4  1          0.6               1.00        50      0.8347462
##   0.4  1          0.6               1.00       100      0.8199110
##   0.4  1          0.6               1.00       150      0.8177132
##   0.4  1          0.8               0.50        50      0.8203558
##   0.4  1          0.8               0.50       100      0.7848770
##   0.4  1          0.8               0.50       150      0.7832548
##   0.4  1          0.8               0.75        50      0.8227106
##   0.4  1          0.8               0.75       100      0.8123496
##   0.4  1          0.8               0.75       150      0.8081109
##   0.4  1          0.8               1.00        50      0.8418891
##   0.4  1          0.8               1.00       100      0.8322083
##   0.4  1          0.8               1.00       150      0.8234171
##   0.4  2          0.6               0.50        50      0.8373626
##   0.4  2          0.6               0.50       100      0.8287284
##   0.4  2          0.6               0.50       150      0.8087389
##   0.4  2          0.6               0.75        50      0.8193616
##   0.4  2          0.6               0.75       100      0.7883307
##   0.4  2          0.6               0.75       150      0.7833595
##   0.4  2          0.6               1.00        50      0.8145997
##   0.4  2          0.6               1.00       100      0.8107274
##   0.4  2          0.6               1.00       150      0.7999477
##   0.4  2          0.8               0.50        50      0.8242805
##   0.4  2          0.8               0.50       100      0.7828362
##   0.4  2          0.8               0.50       150      0.7985348
##   0.4  2          0.8               0.75        50      0.8182104
##   0.4  2          0.8               0.75       100      0.7960230
##   0.4  2          0.8               0.75       150      0.7897959
##   0.4  2          0.8               1.00        50      0.8442700
##   0.4  2          0.8               1.00       100      0.8209838
##   0.4  2          0.8               1.00       150      0.7978022
##   0.4  3          0.6               0.50        50      0.8087912
##   0.4  3          0.6               0.50       100      0.7755102
##   0.4  3          0.6               0.50       150      0.7758765
##   0.4  3          0.6               0.75        50      0.8237572
##   0.4  3          0.6               0.75       100      0.7975929
##   0.4  3          0.6               0.75       150      0.7891680
##   0.4  3          0.6               1.00        50      0.8068027
##   0.4  3          0.6               1.00       100      0.7957091
##   0.4  3          0.6               1.00       150      0.7724751
##   0.4  3          0.8               0.50        50      0.8177394
##   0.4  3          0.8               0.50       100      0.8113030
##   0.4  3          0.8               0.50       150      0.7859759
##   0.4  3          0.8               0.75        50      0.7978022
##   0.4  3          0.8               0.75       100      0.7801675
##   0.4  3          0.8               0.75       150      0.7616431
##   0.4  3          0.8               1.00        50      0.8140241
##   0.4  3          0.8               1.00       100      0.7995290
##   0.4  3          0.8               1.00       150      0.8015175
##   Sens       Spec     
##   0.8175824  0.6714286
##   0.8021978  0.6428571
##   0.7582418  0.5285714
##   0.8791209  0.6428571
##   0.8483516  0.6142857
##   0.8186813  0.6428571
##   0.8186813  0.6142857
##   0.8043956  0.6428571
##   0.7890110  0.6142857
##   0.8175824  0.6714286
##   0.8032967  0.6428571
##   0.7725275  0.6142857
##   0.8175824  0.6714286
##   0.7868132  0.6714286
##   0.8032967  0.6142857
##   0.8186813  0.5857143
##   0.8043956  0.6428571
##   0.8043956  0.6428571
##   0.8516484  0.5809524
##   0.8351648  0.6142857
##   0.8054945  0.6142857
##   0.8494505  0.6428571
##   0.8637363  0.6714286
##   0.8494505  0.6428571
##   0.8208791  0.6714286
##   0.8340659  0.6428571
##   0.8340659  0.6428571
##   0.8186813  0.6428571
##   0.8186813  0.5571429
##   0.8186813  0.5571429
##   0.8186813  0.6714286
##   0.8340659  0.5857143
##   0.8494505  0.6142857
##   0.8197802  0.6714286
##   0.8186813  0.6714286
##   0.8186813  0.6428571
##   0.8197802  0.5285714
##   0.8197802  0.5285714
##   0.8043956  0.5857143
##   0.8186813  0.6142857
##   0.7901099  0.5571429
##   0.7890110  0.5571429
##   0.8329670  0.6714286
##   0.8329670  0.7000000
##   0.8186813  0.6428571
##   0.8780220  0.5857143
##   0.8637363  0.6142857
##   0.8351648  0.5857143
##   0.8032967  0.7285714
##   0.8329670  0.6714286
##   0.8329670  0.6714286
##   0.8483516  0.6714286
##   0.8329670  0.6428571
##   0.8329670  0.6714286
##   0.8186813  0.6428571
##   0.8186813  0.6428571
##   0.7879121  0.5285714
##   0.8186813  0.6142857
##   0.7879121  0.6428571
##   0.8186813  0.6142857
##   0.7890110  0.5571429
##   0.7890110  0.6142857
##   0.7736264  0.6142857
##   0.7890110  0.5857143
##   0.8186813  0.5857143
##   0.7890110  0.6142857
##   0.8032967  0.6428571
##   0.8032967  0.6142857
##   0.8032967  0.5571429
##   0.8043956  0.6428571
##   0.8043956  0.6428571
##   0.8043956  0.6428571
##   0.8186813  0.6714286
##   0.8637363  0.6428571
##   0.8032967  0.6428571
##   0.8494505  0.5857143
##   0.8329670  0.5857143
##   0.8186813  0.5857143
##   0.8175824  0.6142857
##   0.8340659  0.6142857
##   0.8494505  0.6142857
##   0.8340659  0.5571429
##   0.7725275  0.5571429
##   0.7736264  0.5857143
##   0.8186813  0.6142857
##   0.8186813  0.5571429
##   0.8032967  0.5571429
##   0.8186813  0.6142857
##   0.8032967  0.6428571
##   0.8032967  0.6428571
##   0.7879121  0.6714286
##   0.8032967  0.5571429
##   0.8186813  0.5285714
##   0.8032967  0.7000000
##   0.8494505  0.6142857
##   0.8637363  0.5857143
##   0.8043956  0.6714286
##   0.8054945  0.6428571
##   0.8054945  0.6428571
##   0.7901099  0.6428571
##   0.8197802  0.6428571
##   0.8043956  0.5571429
##   0.8197802  0.5857143
##   0.8197802  0.5857143
##   0.8648352  0.6142857
##   0.8186813  0.7285714
##   0.8175824  0.6714286
##   0.8175824  0.6714286
## 
## Tuning parameter 'gamma' was held constant at a value of 0
## 
## Tuning parameter 'min_child_weight' was held constant at a value of 1
## ROC was used to select the optimal model using  the largest value.
## The final values used for the model were nrounds = 50, max_depth = 1,
##  eta = 0.3, gamma = 0, colsample_bytree = 0.8, min_child_weight = 1
##  and subsample = 0.75.

Session Info:

sessionInfo()
## R version 3.3.2 (2016-10-31)
## Platform: x86_64-w64-mingw32/x64 (64-bit)
## Running under: Windows 10 x64 (build 10586)
## 
## locale:
## [1] LC_COLLATE=English_United Kingdom.1252 
## [2] LC_CTYPE=English_United Kingdom.1252   
## [3] LC_MONETARY=English_United Kingdom.1252
## [4] LC_NUMERIC=C                           
## [5] LC_TIME=English_United Kingdom.1252    
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
## [1] plyr_1.8.4           xgboost_0.4-4        caret_6.0-73        
## [4] ggplot2_2.1.0        lattice_0.20-34      RevoUtilsMath_10.0.0
## 
## loaded via a namespace (and not attached):
##  [1] Rcpp_0.12.7        compiler_3.3.2     formatR_1.4       
##  [4] nloptr_1.0.4       iterators_1.0.8    tools_3.3.2       
##  [7] digest_0.6.10      lme4_1.1-12        evaluate_0.10     
## [10] nlme_3.1-128       gtable_0.2.0       mgcv_1.8-15       
## [13] Matrix_1.2-7.1     foreach_1.4.3      yaml_2.1.13       
## [16] parallel_3.3.2     SparseM_1.72       stringr_1.1.0     
## [19] knitr_1.14         RevoUtils_10.0.2   MatrixModels_0.4-1
## [22] stats4_3.3.2       rprojroot_1.1      grid_3.3.2        
## [25] nnet_7.3-12        data.table_1.9.6   rmarkdown_1.3     
## [28] minqa_1.2.4        reshape2_1.4.2     car_2.1-3         
## [31] magrittr_1.5       backports_1.0.4    scales_0.4.0      
## [34] codetools_0.2-15   ModelMetrics_1.1.0 htmltools_0.3.5   
## [37] MASS_7.3-45        splines_3.3.2      pbkrtest_0.4-6    
## [40] colorspace_1.2-7   quantreg_5.29      stringi_1.1.2     
## [43] munsell_0.4.3      chron_2.3-47

Running dev caret 6.0-75

Here, only max_depth is varied, from 1 to 3:

library(caret)

set.seed(1)
dat <- twoClassSim(100)
X <- dat[,1:5]
y <- dat[["Class"]]

model_class <- train(
  X, y, method='xgbTree',
  metric='ROC',
  trControl=trainControl(
    method="cv", 
    number=5,
    classProbs=TRUE, 
    summaryFunction=twoClassSummary)
)

print(model_class)
## eXtreme Gradient Boosting 
## 
## 100 samples
##   5 predictor
##   2 classes: 'Class1', 'Class2' 
## 
## No pre-processing
## Resampling: Cross-Validated (5 fold) 
## Summary of sample sizes: 80, 80, 81, 80, 79 
## Resampling results across tuning parameters:
## 
##   max_depth  ROC        Sens       Spec     
##   1          0.7954474  0.8186813  0.5857143
##   2          0.7889587  0.8186813  0.6142857
##   3          0.7984825  0.7736264  0.6380952
## 
## Tuning parameter 'nrounds' was held constant at a value of 50
## Tuning parameter 'colsample_bytree' was held constant at a value of
##  0.6
## Tuning parameter 'min_child_weight' was held constant at a value of
##  1
## Tuning parameter 'subsample' was held constant at a value of 0.5
## ROC was used to select the optimal model using  the largest value.
## The final values used for the model were nrounds = 50, max_depth = 3,
##  eta = 0.3, gamma = 0, colsample_bytree = 0.6, min_child_weight = 1
##  and subsample = 0.5.
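
Until the defaults agree between versions, passing an explicit tuneGrid makes the search identical regardless of which caret is installed. A sketch that reproduces the 6.0-73 grid shown above (the seven column names are the ones xgbTree requires; the values are taken from the 6.0-73 output):

```r
library(caret)

set.seed(1)
dat <- twoClassSim(100)
X <- dat[, 1:5]
y <- dat[["Class"]]

# All seven xgbTree tuning parameters must appear as grid columns;
# vary the ones of interest and hold the rest at single values.
grid <- expand.grid(
  nrounds = c(50, 100, 150),
  max_depth = 1:3,
  eta = c(0.3, 0.4),
  gamma = 0,
  colsample_bytree = c(0.6, 0.8),
  min_child_weight = 1,
  subsample = c(0.5, 0.75, 1)
)

model_class <- train(
  X, y, method = 'xgbTree',
  metric = 'ROC',
  tuneGrid = grid,
  trControl = trainControl(
    method = "cv",
    number = 5,
    classProbs = TRUE,
    summaryFunction = twoClassSummary)
)
```

This grid has 3 × 3 × 2 × 1 × 2 × 1 × 3 = 108 rows, matching the 108 resampling results printed under 6.0-73.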

Session Info:

sessionInfo()
## R version 3.3.2 (2016-10-31)
## Platform: x86_64-w64-mingw32/x64 (64-bit)
## Running under: Windows 10 x64 (build 10586)
## 
## locale:
## [1] LC_COLLATE=English_United Kingdom.1252 
## [2] LC_CTYPE=English_United Kingdom.1252   
## [3] LC_MONETARY=English_United Kingdom.1252
## [4] LC_NUMERIC=C                           
## [5] LC_TIME=English_United Kingdom.1252    
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
## [1] caret_6.0-75         plyr_1.8.4           xgboost_0.4-4       
## [4] ggplot2_2.1.0        lattice_0.20-34      RevoUtilsMath_10.0.0
## 
## loaded via a namespace (and not attached):
##  [1] Rcpp_0.12.7        git2r_0.15.0       compiler_3.3.2    
##  [4] formatR_1.4        nloptr_1.0.4       iterators_1.0.8   
##  [7] tools_3.3.2        digest_0.6.10      lme4_1.1-12       
## [10] memoise_1.0.0      evaluate_0.10      nlme_3.1-128      
## [13] gtable_0.2.0       mgcv_1.8-15        Matrix_1.2-7.1    
## [16] foreach_1.4.3      curl_2.2           yaml_2.1.13       
## [19] parallel_3.3.2     SparseM_1.72       httr_1.2.1        
## [22] withr_1.0.2        stringr_1.1.0      knitr_1.14        
## [25] devtools_1.12.0    RevoUtils_10.0.2   MatrixModels_0.4-1
## [28] stats4_3.3.2       rprojroot_1.1      grid_3.3.2        
## [31] nnet_7.3-12        data.table_1.9.6   R6_2.2.0          
## [34] rmarkdown_1.3      minqa_1.2.4        reshape2_1.4.2    
## [37] car_2.1-3          magrittr_1.5       backports_1.0.4   
## [40] scales_0.4.0       codetools_0.2-15   ModelMetrics_1.1.0
## [43] htmltools_0.3.5    MASS_7.3-45        splines_3.3.2     
## [46] pbkrtest_0.4-6     colorspace_1.2-7   quantreg_5.29     
## [49] stringi_1.1.2      munsell_0.4.3      chron_2.3-47
topepo commented 7 years ago

That is a bug introduced here. I'll fix it now.

topepo commented 7 years ago

That should work for you but let me know if there are still issues.

I haven't run the regression test suite in a while and the bug would have been caught there.

Thanks

RobertKomara commented 7 years ago

I can confirm that fixes it. Thank you!