topepo / caret

caret (Classification And Regression Training): an R package containing miscellaneous functions for training and plotting classification and regression models
http://topepo.github.io/caret/index.html

Downsampling during cv not returning downsampled folds. #1313

Open nk-at-essentia opened 1 year ago

nk-at-essentia commented 1 year ago

I have a class imbalance issue with my data and, following the guide here, I'm trying to downsample during cross-validation. I'd like to be able to inspect individual folds after running the model, but when I pull an individual fold from the training set, the classes are not balanced as I expected.

    library(caret)

    set.seed(2969)
    imbal_train <- twoClassSim(10000, intercept = -20, linearVars = 20)

    tr_ctrl <- trainControl(method = "cv",
                            number = 5,
                            classProbs = TRUE,
                            p = 0.5,
                            summaryFunction = twoClassSummary,
                            sampling = "down")

    testModel <- train(Class ~ .,
                       data = imbal_train,
                       method = "rf",
                       metric = "ROC",
                       trControl = tr_ctrl)

    # Pull the rows of the first training fold and check the class balance
    fold1 <- imbal_train[testModel$control$index$Fold1, ]
    table(fold1$Class)

Am I misunderstanding how this works, and does the downsampling occur after the fold is created? If so, how would I retrieve the indices of the downsampled rows?
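To make the question concrete, here is a minimal sketch of what I would expect a downsampled fold to look like, assuming the downsampling is applied to each fold's training rows after the fold is created (that is my assumption, not something confirmed here). It calls `createFolds()` and `downSample()` by hand instead of going through `train()`, so the rows it draws will not necessarily match whatever `train()` sampled internally. The names `folds`, `fold1_down`, and `predictors` are just for illustration.

```r
library(caret)

set.seed(2969)
imbal_train <- twoClassSim(10000, intercept = -20, linearVars = 20)

# Training-row indices for 5-fold CV, one integer vector per fold
folds <- createFolds(imbal_train$Class, k = 5, returnTrain = TRUE)

# Rows of the first training fold, before any subsampling
fold1 <- imbal_train[folds[[1]], ]
table(fold1$Class)

# Downsample that fold by hand: every class is reduced to the
# size of the least frequent class
predictors <- fold1[, names(fold1) != "Class"]
fold1_down <- downSample(x = predictors, y = fold1$Class, yname = "Class")
table(fold1_down$Class)
```

The second table is the balanced view I was hoping to recover from the fitted model for each fold.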

R Session Info:
`R version 4.2.1 (2022-06-23 ucrt)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 19044)

Matrix products: default

locale: [1] LC_COLLATE=English_United States.utf8 LC_CTYPE=English_United States.utf8
[3] LC_MONETARY=English_United States.utf8 LC_NUMERIC=C
[5] LC_TIME=English_United States.utf8

attached base packages: [1] stats graphics grDevices utils datasets methods base

other attached packages: [1] caret_6.0-93 lattice_0.20-45 ggplot2_3.3.6

loaded via a namespace (and not attached): [1] tidyselect_1.1.2 purrr_0.3.4 reshape2_1.4.4 listenv_0.8.0 splines_4.2.1
[6] colorspace_2.0-3 vctrs_0.4.1 generics_0.1.3 stats4_4.2.1 utf8_1.2.2
[11] survival_3.3-1 prodlim_2019.11.13 rlang_1.0.4 ModelMetrics_1.2.2.2 pillar_1.8.0
[16] glue_1.6.2 withr_2.5.0 DBI_1.1.3 foreach_1.5.2 lifecycle_1.0.1
[21] plyr_1.8.7 lava_1.6.10 stringr_1.4.0 timeDate_4021.104 munsell_0.5.0
[26] gtable_0.3.0 future_1.27.0 recipes_1.0.1 codetools_0.2-18 parallel_4.2.1
[31] class_7.3-20 fansi_1.0.3 Rcpp_1.0.9 scales_1.2.0 ipred_0.9-13
[36] parallelly_1.32.1 digest_0.6.29 stringi_1.7.8 dplyr_1.0.9 grid_4.2.1
[41] hardhat_1.2.0 cli_3.3.0 tools_4.2.1 magrittr_2.0.3 tibble_3.1.8
[46] randomForest_4.7-1.1 future.apply_1.9.0 pkgconfig_2.0.3 MASS_7.3-57 Matrix_1.4-1
[51] data.table_1.14.2 pROC_1.18.0 lubridate_1.8.0 gower_1.0.0 assertthat_0.2.1
[56] rstudioapi_0.13 iterators_1.0.14 R6_2.5.1 globals_0.16.0 rpart_4.1.16
[61] nnet_7.3-17 nlme_3.1-157 compiler_4.2.1 `