tidymodels / rsample

Classes and functions to create and summarize resampling objects
https://rsample.tidymodels.org

stratified sampling does not #438

Closed afrogri37 closed 1 year ago

afrogri37 commented 1 year ago

The problem

I am trying to create cross-validation folds for an imbalanced dataset with a two-class outcome ("occur"), in which class "1" has 7 observations and class "0" has 10,000. I am using strata = occur in vfold_cv() so that class "1" is represented in both the analysis and assessment sets of every fold. However, this is not the case, so my model's metrics can't be calculated for every fold. Is this a bug in vfold_cv(), or is there another option in the resampling process that I can use to obtain stratified samples? Thank you!

Reproducible example

library(tidymodels)
library(ranger)
library(data.table)
set.seed(123)
data <- data.frame(occur = c(rep(1, 7), rep(0, 10000)),
                   var1 = c(seq(1, 7), seq(1:10000)),
                   var2 = c(seq(2, 8), seq(2:10001)))
data$occur <- as.factor(data$occur)

# Split into train and test sets

preds_split <- initial_split(data, strata = occur)
preds_train <- training(preds_split)
preds_test  <- testing(preds_split)

# Split in folds
folds <- vfold_cv(preds_train, v = 10, strata = occur, repeats = 1)

# Check the assessment set of one fold

folds$splits[[2]] %>% assessment()
#      occur var1 var1.1
#   1:     0    4      4
#   2:     0    6      6
#   3:     0   26     26
#   4:     0   29     29
#   5:     0   93     93
#  ---                  
# 747:     0 9973   9973
# 748:     0 9980   9980
# 749:     0 9981   9981
# 750:     0 9982   9982
# 751:     0 9996   9996

# Define the recipe on the training data (named rf_rec rather than reusing the name recipe)
rf_rec <- recipe(occur ~ ., data = preds_train)

# ranger model
rf_mod <- rand_forest(trees = 1000) %>%
  set_engine("ranger",
             importance = "impurity",
             replace = TRUE,
             oob.error = TRUE,
             keep.inbag = TRUE,
             num.threads = 50,
             probability = TRUE) %>%
  set_mode("classification")

# Workflow
rf_wf <-
  workflow() %>%
  add_recipe(rf_rec) %>%
  add_model(rf_mod)

# Fit model
rf_fit <-
  rf_wf %>%
  fit_resamples(folds,
                metrics = metric_set(accuracy, kap, precision, f_meas,
                                     sens, spec, j_index),
                control = control_resamples(save_pred = TRUE))

# warning: While computing binary `spec()`, no true negatives were detected (i.e. `true_negative + false_positive = 0`). 
#             Specificity is undefined in this case, and `NA` will be returned.
#             Note that 1 predicted negatives(s) actually occured for the problematic event level, '0'., No control observations were detected in `truth` with control level '1'.

# Metrics
metrics <- collect_metrics(rf_fit) 
metrics
# # A tibble: 7 × 6
#   .metric   .estimator      mean     n  std_err .config             
#   <chr>     <chr>          <dbl> <int>    <dbl> <chr>               
# 1 accuracy  binary      0.998       10 0.000478 Preprocessor1_Model1
# 2 f_meas    binary      0.999       10 0.000240 Preprocessor1_Model1
# 3 j_index   binary     -0.000446     3 0.000446 Preprocessor1_Model1
# 4 kap       binary     -0.000254     7 0.000254 Preprocessor1_Model1
# 5 precision binary      0.999       10 0.000453 Preprocessor1_Model1
# 6 sens      binary      0.999       10 0.000295 Preprocessor1_Model1
# 7 spec      binary      0            3 0        Preprocessor1_Model1

Created on 2023-07-07 with reprex v2.0.2
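The specificity warning above comes down to a zero denominator. Below is a minimal base-R sketch using a hypothetical helper, spec_by_hand() (not yardstick's implementation), assuming "0" is the event level and "1" the control level, as the warning text indicates. Specificity is TN / (TN + FP), so a fold whose assessment set contains no "1" rows gives 0 / 0.

```r
# Hypothetical helper, not yardstick's code: specificity computed by hand
# with "1" as the control (negative) level, matching the warning above.
spec_by_hand <- function(truth, estimate, control = "1") {
  tn <- sum(truth == control & estimate == control)  # true negatives
  fp <- sum(truth == control & estimate != control)  # false positives
  tn / (tn + fp)  # 0 / 0 is NaN when the fold has no control rows
}

# A fold whose assessment set has no "1" rows: the denominator is zero
spec_by_hand(truth = c("0", "0", "0"), estimate = c("0", "0", "0"))
# [1] NaN

# A fold with one "1" row that the model misses: specificity is 0
spec_by_hand(truth = c("0", "0", "1"), estimate = c("0", "0", "0"))
# [1] 0
```

yardstick returns NA with a warning in the first case rather than NaN. That is consistent with n being only 3 for spec and j_index in the collected metrics above.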

Session info

``` r
sessioninfo::session_info()
─ Session info ───────────────────────────────────────────────────────────────
 setting  value
 version  R version 4.3.0 (2023-04-21)
 os       Linux Mint 20
 system   x86_64, linux-gnu
 ui       X11
 language (EN)
 collate  en_US.UTF-8
 ctype    en_US.UTF-8
 tz       Europe/Berlin
 date     2023-07-07

─ Packages ───────────────────────────────────────────────────────────────────
 package      * version    date       lib source
 backports      1.2.1      2020-12-09 [1] CRAN (R 4.0.5)
 broom        * 1.0.5      2023-06-09 [1] CRAN (R 4.3.0)
 class          7.3-21     2023-01-23 [3] CRAN (R 4.2.2)
 cli            3.6.1      2023-03-23 [1] CRAN (R 4.3.0)
 codetools      0.2-19     2023-02-01 [3] CRAN (R 4.2.2)
 colorspace     2.1-0      2023-01-23 [1] CRAN (R 4.3.0)
 data.table   * 1.14.8     2023-02-17 [1] CRAN (R 4.3.0)
 dials        * 1.2.0      2023-04-03 [1] CRAN (R 4.3.0)
 DiceDesign     1.9        2021-02-13 [1] CRAN (R 4.3.0)
 digest         0.6.31     2022-12-11 [1] CRAN (R 4.3.0)
 dplyr        * 1.1.2      2023-04-20 [1] CRAN (R 4.3.0)
 ellipsis       0.3.2      2021-04-29 [1] CRAN (R 4.0.5)
 fansi          1.0.4      2023-01-22 [1] CRAN (R 4.3.0)
 foreach        1.5.2      2022-02-02 [1] CRAN (R 4.3.0)
 furrr          0.3.1      2022-08-15 [1] CRAN (R 4.3.0)
 future         1.32.0     2023-03-07 [1] CRAN (R 4.3.0)
 generics       0.1.3      2022-07-05 [1] CRAN (R 4.3.0)
 ggplot2      * 3.4.2      2023-04-03 [1] CRAN (R 4.3.0)
 globals        0.16.2     2022-11-21 [1] CRAN (R 4.3.0)
 glue           1.6.2      2022-02-24 [1] CRAN (R 4.3.0)
 gower          0.2.2      2020-06-23 [1] CRAN (R 4.0.5)
 GPfit          1.0-8      2019-02-08 [1] CRAN (R 4.3.0)
 gtable         0.3.3      2023-03-21 [1] CRAN (R 4.3.0)
 hardhat        1.3.0      2023-03-30 [1] CRAN (R 4.3.0)
 infer        * 1.0.4      2022-12-02 [1] CRAN (R 4.3.0)
 ipred          0.9-14     2023-03-09 [1] CRAN (R 4.3.0)
 iterators      1.0.14     2022-02-05 [1] CRAN (R 4.3.0)
 lattice        0.21-8     2023-04-05 [3] CRAN (R 4.3.0)
 lava           1.6.9      2021-03-11 [1] CRAN (R 4.0.5)
 lhs            1.1.6      2022-12-17 [1] CRAN (R 4.3.0)
 lifecycle      1.0.3      2022-10-07 [1] CRAN (R 4.3.0)
 listenv        0.9.0      2022-12-16 [1] CRAN (R 4.3.0)
 lubridate      1.9.2      2023-02-10 [1] CRAN (R 4.3.0)
 magrittr       2.0.3      2022-03-30 [1] CRAN (R 4.3.0)
 MASS           7.3-59     2023-04-21 [3] CRAN (R 4.3.0)
 Matrix         1.5-4.1    2023-05-18 [3] CRAN (R 4.3.0)
 modeldata    * 1.1.0      2023-01-25 [1] CRAN (R 4.3.0)
 munsell        0.5.0      2018-06-12 [1] CRAN (R 4.0.5)
 nnet           7.3-18     2022-09-28 [3] CRAN (R 4.2.1)
 parallelly     1.36.0     2023-05-26 [1] CRAN (R 4.3.0)
 parsnip      * 1.1.0      2023-04-12 [1] CRAN (R 4.3.0)
 pillar         1.9.0      2023-03-22 [1] CRAN (R 4.3.0)
 pkgconfig      2.0.3      2019-09-22 [1] CRAN (R 4.0.5)
 prodlim        2019.11.13 2019-11-17 [1] CRAN (R 4.0.5)
 purrr        * 1.0.1      2023-01-10 [1] CRAN (R 4.3.0)
 R6             2.5.1      2021-08-19 [1] CRAN (R 4.3.0)
 ranger       * 0.12.1     2020-01-10 [1] CRAN (R 4.0.5)
 Rcpp           1.0.10     2023-01-22 [1] CRAN (R 4.3.0)
 recipes      * 1.0.6      2023-04-25 [1] CRAN (R 4.3.0)
 rlang          1.1.1      2023-04-28 [1] CRAN (R 4.3.0)
 rpart          4.1.19     2022-10-21 [3] CRAN (R 4.2.1)
 rsample      * 1.1.1      2022-12-07 [1] CRAN (R 4.3.0)
 rstudioapi     0.14       2022-08-22 [1] CRAN (R 4.3.0)
 scales       * 1.2.1      2022-08-20 [1] CRAN (R 4.3.0)
 sessioninfo    1.1.1      2018-11-05 [1] CRAN (R 4.0.5)
 survival       3.5-5      2023-03-12 [3] CRAN (R 4.2.3)
 tibble       * 3.2.1      2023-03-20 [1] CRAN (R 4.3.0)
 tidymodels   * 1.1.0      2023-05-01 [1] CRAN (R 4.3.0)
 tidyr        * 1.3.0      2023-01-24 [1] CRAN (R 4.3.0)
 tidyselect     1.2.0      2022-10-10 [1] CRAN (R 4.3.0)
 timechange     0.2.0      2023-01-11 [1] CRAN (R 4.3.0)
 timeDate       3043.102   2018-02-21 [1] CRAN (R 4.0.5)
 tune         * 1.1.1      2023-04-11 [1] CRAN (R 4.3.0)
 utf8           1.2.3      2023-01-31 [1] CRAN (R 4.3.0)
 vctrs          0.6.3      2023-06-14 [1] CRAN (R 4.3.0)
 withr          2.5.0      2022-03-03 [1] CRAN (R 4.3.0)
 workflows    * 1.1.3      2023-02-22 [1] CRAN (R 4.3.0)
 workflowsets * 1.0.1      2023-04-06 [1] CRAN (R 4.3.0)
 yardstick    * 1.2.0      2023-04-21 [1] CRAN (R 4.3.0)

 [1] /usr/local/lib/R/site-library
 [2] /usr/lib/R/site-library
 [3] /usr/lib/R/library
```
mikemahoney218 commented 1 year ago

Hi @afrogri37!

There are two separate issues here. First, your strata are getting pooled together because you have so few "1" observations. If you look at ?vfold_cv, you can see there's an argument called pool, which is documented as:

A proportion of data used to determine if a particular group is too small and should be pooled into another group. We do not recommend decreasing this argument below its default of 0.1 because of the dangers of stratifying groups that are too small.

So, because class "1" makes up less than 10% of all observations, it isn't treated as its own stratum. You can see this by running make_strata() on your outcome variable, for instance:

make_strata(data$occur) |> unique()
#> [1] 0
#> Levels: 0
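The pooling decision rests on a simple proportion check. Here is a minimal base-R sketch of that arithmetic for the reprex data. It mirrors the threshold comparison described in ?vfold_cv, not rsample's internal code, and the names props and pool are only local variables for illustration.

```r
# Outcome from the reprex: 7 "1"s and 10,000 "0"s
occur <- factor(c(rep(1, 7), rep(0, 10000)))

# Proportion of each class in the full data
props <- prop.table(table(occur))
props[["1"]]  # 7 / 10007, roughly 0.0007

# Classes whose share falls below the default pool threshold of 0.1
pool <- 0.1
names(props)[as.vector(props < pool)]
# [1] "1"
```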

The last sentence of that quote is the important part: it's often a bad idea to reduce this value. However, if we did, we could see that one "1" observation gets assigned to each assessment set until they run out:

library(tidymodels)
set.seed(123)
data <- data.frame(occur = c(rep(1, 7), rep(0, 10000)),
                   var1 = c(seq(1, 7), seq(1:10000)),
                   var2 = c(seq(2, 8), seq(2:10001)))
data$occur <- as.factor(data$occur)

preds_split <- initial_split(data, strata = occur)
preds_train <- training(preds_split)
preds_test  <- testing(preds_split)

folds <- vfold_cv(preds_train, v = 10, strata = occur, pool = 0)
#> Warning: Stratifying groups that make up 0% of the data may be statistically risky.
#> • Consider increasing `pool` to at least 0.1
lapply(folds$splits, \(x) assessment(x)$occur |> table())
#> [[1]]
#> 
#>   0   1 
#> 750   1 
#> 
#> [[2]]
#> 
#>   0   1 
#> 750   1 
#> 
#> [[3]]
#> 
#>   0   1 
#> 750   1 
#> 
#> [[4]]
#> 
#>   0   1 
#> 750   1 
#> 
#> [[5]]
#> 
#>   0   1 
#> 750   1 
#> 
#> [[6]]
#> 
#>   0   1 
#> 750   1 
#> 
#> [[7]]
#> 
#>   0   1 
#> 750   0 
#> 
#> [[8]]
#> 
#>   0   1 
#> 750   0 
#> 
#> [[9]]
#> 
#>   0   1 
#> 750   0 
#> 
#> [[10]]
#> 
#>   0   1 
#> 749   0

Created on 2023-07-07 with reprex v2.0.2

That's where the second issue comes in. V-fold CV splits your data into equal-sized folds and doesn't otherwise resample or subsample your data. If you're creating 10 folds and only have 7 observations with a 1, then at most 7 folds can have a 1 in their assessment set. In my example above, even fewer do (6), because initial_split() put one of the seven into the test set.
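The "at most 7 folds" bound is a pigeonhole argument: each minority row lands in exactly one assessment set, so no more folds than minority rows can contain one. Below is a minimal base-R sketch using a hypothetical helper, stratified_fold_ids(). It uses a simple round-robin assignment within each class, not rsample's actual algorithm. The class counts (6 "1"s and 7,499 "0"s) are taken from the fold tables above.

```r
# Hypothetical helper, not rsample's algorithm: within each class,
# shuffle the rows and deal them out to folds 1..v in turn.
stratified_fold_ids <- function(y, v) {
  ids <- integer(length(y))
  for (lvl in unique(as.character(y))) {
    rows <- which(y == lvl)
    rows <- rows[sample.int(length(rows))]          # shuffle within class
    ids[rows] <- rep_len(seq_len(v), length(rows))  # round-robin fold ids
  }
  ids
}

set.seed(123)
y <- factor(c(rep(1, 6), rep(0, 7499)))  # class counts from the fold tables above
fold <- stratified_fold_ids(y, v = 10)

# Count of each class per assessment fold
per_fold <- table(fold = fold, occur = y)

# Folds with at least one "1": can never exceed min(v, number of "1"s)
sum(per_fold[, "1"] > 0)
# [1] 6
```

Because the 7,499 "0" rows are dealt out the same way, folds 1 to 9 get 750 of them and fold 10 gets 749, the same sizes as in the tables above.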

Hope that makes sense!

afrogri37 commented 1 year ago

Hi mikemahoney218! This is actually very helpful, thank you!

hfrick commented 1 year ago

@afrogri37 I'm going to close this issue since this isn't a bug in vfold_cv().

For such imbalanced data, you may want to look at other techniques. Some options:

afrogri37 commented 1 year ago

Thank you, I'm considering them for sure!

github-actions[bot] commented 1 year ago

This issue has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.