topepo / caret

caret (Classification And Regression Training) is an R package that contains miscellaneous functions for training and plotting classification and regression models
http://topepo.github.io/caret/index.html

createFolds is very slow when y is a character with many values #467

Open · thvasilo opened this issue 8 years ago

thvasilo commented 8 years ago

I often use the createFolds and createDataPartition functions to create samples of my data stratified by subject id, which I store as a character variable in my dataframe.

However, this leads to very slow creation of the partitions. Take, for example, the following code sample:

Minimal dataset:

library(caret)
set.seed(1)
dat <- twoClassSim(1e6)

# 1e4 subject ids with 100 rows each
ids <- rep(1:1e4, each = 100)
dat[, "id"] <- as.character(ids)

Minimal, runnable code:

system.time(dbl_folds <- createFolds(dat$TwoFactor1, returnTrain = TRUE))
>    user  system elapsed 
>  0.497   0.107   0.604 

system.time(id_folds <- createFolds(dat$id, returnTrain = TRUE))
>    user  system elapsed 
> 262.897  87.961 351.007

Session Info:

> sessionInfo()
R version 3.3.1 (2016-06-21)
Platform: x86_64-apple-darwin13.4.0 (64-bit)
Running under: OS X 10.11.5 (El Capitan)

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

attached base packages:
[1] parallel  stats     graphics  grDevices utils     datasets  methods  
[8] base     

other attached packages:
[1] caret_6.0-70     lattice_0.20-33  dplyr_0.5.0      tidyr_0.5.1     
[5] MASS_7.3-45      scales_0.4.0     data.table_1.9.6 ggplot2_2.1.0   

loaded via a namespace (and not attached):
 [1] Rcpp_0.12.5        nloptr_1.0.4       plyr_1.8.4        
 [4] iterators_1.0.8    tools_3.3.1        digest_0.6.9      
 [7] lme4_1.1-12        packrat_0.4.7-1    tibble_1.1        
[10] gtable_0.2.0       nlme_3.1-128       mgcv_1.8-12       
[13] Matrix_1.2-6       foreach_1.4.3      DBI_0.4-1         
[16] SparseM_1.7        stringr_1.0.0      MatrixModels_0.4-1
[19] stats4_3.3.1       grid_3.3.1         nnet_7.3-12       
[22] R6_2.1.2           survival_2.39-4    minqa_1.2.4       
[25] reshape2_1.4.1     car_2.1-2          magrittr_1.5      
[28] codetools_0.2-14   splines_3.3.1      assertthat_0.1    
[31] pbkrtest_0.4-6     colorspace_1.2-6   quantreg_5.26     
[34] labeling_0.3       stringi_1.1.1      lazyeval_0.2.0    
[37] munsell_0.4.3      chron_2.3-47 

Any idea why there is such a huge difference in the time it takes to create the folds? Do I have any alternatives for creating samples stratified by ID using caret?

topepo commented 8 years ago

It is basically running a for loop over the strata/class variable levels.
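
For illustration, a rough, untested sketch of a loop-free stratified assignment using base R's ave() (the helper name fast_strat_folds is made up, and this is not caret's code): deal fold numbers within each stratum in a single pass, then split the row indices.

# Untested sketch, not caret's implementation: ave() applies the
# fold-dealing function within each stratum, avoiding an explicit
# per-level scan over y.
fast_strat_folds <- function(y, k = 10) {
  fold <- ave(seq_along(y), y,
              FUN = function(i) sample(rep_len(1:k, length(i))))
  split(seq_along(y), fold)
}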

Typically, you would want to use the outcome that you are modeling as the input into the data partitioning functions. If you have an ID variable, using this will create the modeling/holdout partitions within that ID variable and you may not want that (since the same ID will be inside both partitions).

There have been a few requests (including one the other day -- issue #465) to have a function that would do the fold generation such that the data corresponding to some clustering/ID/grouping variable are always in the same fold. I'm working on that now and it may be more of what you want. I'll try to keep away from for loops this time =]
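
The group-wise idea might look roughly like this untested sketch (group_folds is a hypothetical name, not the eventual API): assign each unique id to a fold, then place every row into its id's fold.

# Untested sketch of the group-wise idea: whole ids are dealt into
# folds, so rows sharing an id can never straddle two folds.
group_folds <- function(group, k = 10) {
  ids <- unique(group)
  fold_of_id <- setNames(sample(rep_len(1:k, length(ids))), ids)
  split(seq_along(group), fold_of_id[as.character(group)])
}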

thvasilo commented 8 years ago

Yes, in my case what I'm trying to achieve is that the same subjects appear in both the train and the test set of each fold, so the partitions should indeed be created within the ID variable.

A small test that I use to ensure this holds after the split is this one:

library(purrr)
library(data.table)

setDT(sample)
train_folds <- createFolds(sample$id, k = 10, returnTrain = TRUE)

map(train_folds, ~ assertthat::assert_that(uniqueN(sample[.,id]) == uniqueN(sample[!.,id])))

Note that this only tests for the number of unique ids; I should probably change it to check for exactly matching ids instead.
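
Something like this untested sketch (reusing the same sample data.table as above) would be the stricter check:

map(train_folds, ~ assertthat::assert_that(
  # compare the sets of ids, not just their counts
  setequal(sample[., id], sample[!., id])
))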

So this goes in the opposite direction of #465, where we want different ids in the train and test set.

topepo commented 8 years ago
> library(purrr)
> library(data.table)
> 
> setDT(sample)
Error in setDT(sample) : 
  Can not convert 'sample' to data.table by reference because binding is locked. It is very likely that 'sample' resides within a package (or an environment) that is locked to prevent modifying its variable bindings. Try copying the object to your current environment, ex: var <- copy(var) and then using setDT again.
> sessionInfo()
R version 3.2.3 (2015-12-10)
Platform: x86_64-redhat-linux-gnu (64-bit)
Running under: Red Hat Enterprise Linux Server release 6.7 (Santiago)

locale:
 [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C              
 [3] LC_TIME=en_US.UTF-8        LC_COLLATE=en_US.UTF-8    
 [5] LC_MONETARY=en_US.UTF-8    LC_MESSAGES=en_US.UTF-8   
 [7] LC_PAPER=en_US.UTF-8       LC_NAME=C                 
 [9] LC_ADDRESS=C               LC_TELEPHONE=C            
[11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C       

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] data.table_1.9.6     purrr_0.2.2          XLConnect_0.2-12    
[4] XLConnectJars_0.2-12

loaded via a namespace (and not attached):
[1] htmlwidgets_0.6 magrittr_1.5    htmltools_0.3   tools_3.2.3    
[5] Rcpp_0.12.3     DT_0.1          digest_0.6.8    chron_2.3-47   
[9] rJava_0.9-8 
thvasilo commented 8 years ago

Yeah, the above was not a complete code sample (sample is supposed to be a data.table).

Updated code sample:

library(purrr)
library(data.table)
library(caret)
set.seed(1)

dat <- twoClassSim(1e6)
ids <- rep(1:1e4, each = 100)
dat[,"id"] <- as.character(ids)
setDT(dat)

id_folds <- createFolds(dat$id, returnTrain = TRUE)

map(id_folds, ~ assertthat::assert_that(uniqueN(dat[.,id]) == uniqueN(dat[!.,id])))

Note that depending on the number of records per subject and the train/test split percentage, the test could pass or fail; i.e., this is not a foolproof method. That is why I would like something that does some sanity checking, so that even subjects with far fewer records end up split correctly (or at least trigger a warning).

E.g., if a subject has only 3 records and we try to do a 90/10 split, I think it will fail.
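
A rough, untested sketch of the kind of check I mean (check_split is a made-up helper, not a caret function):

# Untested sketch: warn if any id is entirely absent from either side
# of one train/test split; train_idx holds the training row indices.
check_split <- function(dt, train_idx) {
  missing_train <- setdiff(unique(dt$id), unique(dt$id[train_idx]))
  missing_test  <- setdiff(unique(dt$id), unique(dt$id[-train_idx]))
  if (length(missing_train) > 0 || length(missing_test) > 0)
    warning(length(missing_train), " id(s) absent from train; ",
            length(missing_test), " id(s) absent from test")
  invisible(TRUE)
}

map(id_folds, ~ check_split(dat, .))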

yitang commented 8 years ago

I am facing the same performance problem. I am thinking of using parallelism to speed it up, but I am not sure whether that is doable or how to make it reproducible...
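
For what it's worth, an untested sketch of one way reproducibility can survive parallelism (assuming the per-stratum work can be split by level; the fork-based mclapply does not run in parallel on Windows, and mc.cores = 4 is arbitrary):

# Untested sketch: with the L'Ecuyer-CMRG generator and a fixed seed,
# mclapply's worker RNG streams are reproducible for the same setup.
library(parallel)
RNGkind("L'Ecuyer-CMRG")
set.seed(1)
by_id <- split(seq_along(dat$id), dat$id)  # row indices per stratum
shuffled <- mclapply(by_id, sample, mc.cores = 4)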