topepo / caret

caret (Classification And Regression Training): an R package that contains miscellaneous functions for training and plotting classification and regression models
http://topepo.github.io/caret/index.html

bug in createFolds where it fails when y is numeric but all values of y are the same #1357

Open mikeblazanin opened 6 months ago

mikeblazanin commented 6 months ago

Minimal, reproducible example:

custommethod <- list(
  library = NULL, type = "Regression", prob = NULL,
  fit = function(x, y, wts, param, lev = NULL, last, weights, classprobs, ...) {
    stats::runmed(x = y, k = param$k)
  },
  parameters = data.frame(parameter = "k", class = "numeric", label = "k"),
  grid = function(x, y, len, search) {
    data.frame(k = seq(from = 1, by = 2, length.out = len))
  },
  predict = function(modelFit, newdata, preProc = NULL, submodels = NULL) {
    rep(NA, length(newdata))
  })

library(caret)
set.seed(1)
data <- data.frame(x = 1:100, y = rep(0.05, 100))
train(x = data.frame(x = data$x), y = data$y, method = custommethod, trControl = trainControl(method = "cv"))

Error in cut.default(y, breaks, include.lowest = TRUE) : 
  invalid number of intervals

The issue appears to come from this section of createFolds:

if(is.numeric(y)) {
  cuts <- floor(length(y)/k)
  if(cuts < 2) cuts <- 2
  if(cuts > 5) cuts <- 5
  breaks <- unique(quantile(y, probs = seq(0, 1, length = cuts)))
  y <- cut(y, breaks, include.lowest = TRUE)
}
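
The failing step can be isolated from caret entirely. The following is a minimal base-R sketch that runs the quoted lines on a constant y. The value k = 10 is an assumption (it matches the default number of folds), and the tryCatch() wrapper is added only so the script keeps running after the error.

```r
# Base R only; mirrors the quoted createFolds lines on a constant y.
y <- rep(0.05, 100)
k <- 10  # assumed number of folds

cuts <- floor(length(y) / k)
if (cuts < 2) cuts <- 2
if (cuts > 5) cuts <- 5

# Every quantile of a constant vector is the same value,
# so unique() leaves a single break.
breaks <- unique(quantile(y, probs = seq(0, 1, length.out = cuts)))
print(breaks)

# cut() now receives a length-one 'breaks' and signals an error,
# which tryCatch() captures as a condition object.
result <- tryCatch(
  cut(y, breaks, include.lowest = TRUE),
  error = function(e) e
)
print(inherits(result, "error"))
```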

When y is numeric but has no variation, every quantile is the same, so after unique() breaks is the single value of y. However, cut does not interpret this single value as the point where a cut should be made, but as the number of intervals to create. When the y value is 2 or greater, this will likely lead to an unexpected result where the number of intervals is taken from the y value (truncated to an integer) rather than from cuts. When the y value is less than 2, as in the example above, cut returns the "invalid number of intervals" error.
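
How cut() treats a length-one breaks argument can be checked directly in base R. The sketch below also includes stratify_numeric(), a hypothetical helper that is not part of caret and is not necessarily how the maintainers would fix this. It shows one possible guard, assuming that falling back to a single stratum is acceptable when fewer than two distinct breaks remain.

```r
# A single number >= 2 is read as a number of intervals, not a cut point.
n_int <- nlevels(cut(rep(7, 100), 7, include.lowest = TRUE))       # 7 intervals
n_frac <- nlevels(cut(rep(3.7, 100), 3.7, include.lowest = TRUE))  # truncated to 3

# A single number < 2 is rejected with an error.
small <- tryCatch(
  cut(rep(0.05, 100), 0.05, include.lowest = TRUE),
  error = function(e) e
)

# Hypothetical helper (not caret code): same binning logic as the quoted
# section, but constant y is placed in one stratum instead of calling cut().
stratify_numeric <- function(y, k) {
  cuts <- min(max(floor(length(y) / k), 2), 5)
  breaks <- unique(quantile(y, probs = seq(0, 1, length.out = cuts)))
  if (length(breaks) < 2) {
    # No variation in y: nothing to stratify on.
    return(factor(rep("all", length(y))))
  }
  cut(y, breaks, include.lowest = TRUE)
}

constant_strata <- stratify_numeric(rep(0.05, 100), k = 10)
varying_strata <- stratify_numeric(1:100, k = 10)
print(table(constant_strata))
print(table(varying_strata))
```

With this guard, a constant y yields a one-level factor, while a y with variation is binned by quantiles exactly as before.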

Session Info:

> sessionInfo()
R version 4.3.2 (2023-10-31 ucrt)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 11 x64 (build 22621)

Matrix products: default

locale:
[1] LC_COLLATE=English_United States.utf8  LC_CTYPE=English_United States.utf8   
[3] LC_MONETARY=English_United States.utf8 LC_NUMERIC=C                          
[5] LC_TIME=English_United States.utf8    

time zone: America/New_York
tzcode source: internal

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] caret_6.0-94   lattice_0.21-9 ggplot2_3.4.4 

loaded via a namespace (and not attached):
 [1] future_1.33.1        utf8_1.2.4           generics_0.1.3       class_7.3-22        
 [5] stringi_1.8.3        pROC_1.18.5          listenv_0.9.1        digest_0.6.34       
 [9] magrittr_2.0.3       timechange_0.3.0     evaluate_0.23        grid_4.3.2          
[13] iterators_1.0.14     fastmap_1.1.1        foreach_1.5.2        plyr_1.8.9          
[17] Matrix_1.6-1.1       ModelMetrics_1.2.2.2 nnet_7.3-19          survival_3.5-7      
[21] purrr_1.0.2          fansi_1.0.6          scales_1.3.0         codetools_0.2-19    
[25] lava_1.7.3           cli_3.6.2            rlang_1.1.3          hardhat_1.3.1       
[29] parallelly_1.37.1    future.apply_1.11.1  munsell_0.5.0        splines_4.3.2       
[33] withr_3.0.0          yaml_2.3.8           prodlim_2023.08.28   parallel_4.3.2      
[37] tools_4.3.2          reshape2_1.4.4       dplyr_1.1.4          colorspace_2.1-0    
[41] recipes_1.0.10       globals_0.16.2       vctrs_0.6.5          R6_2.5.1            
[45] rpart_4.1.21         stats4_4.3.2         lubridate_1.9.3      lifecycle_1.0.4     
[49] stringr_1.5.1        MASS_7.3-60          pkgconfig_2.0.3      pillar_1.9.0        
[53] gtable_0.3.4         glue_1.7.0           data.table_1.15.0    Rcpp_1.0.12         
[57] xfun_0.41            tibble_3.2.1         tidyselect_1.2.0     rstudioapi_0.15.0   
[61] knitr_1.45           htmltools_0.5.7      nlme_3.1-163         rmarkdown_2.25      
[65] ipred_0.9-14         timeDate_4032.109    gower_1.0.1          compiler_4.3.2