tidymodels / rsample

Classes and functions to create and summarize resampling objects
https://rsample.tidymodels.org

group-preserving sampling #24

Closed ClaytonJY closed 6 years ago

ClaytonJY commented 6 years ago

I don't know what to call this, but I'll try to explain my use case.

I've got a data set I want to split up for cross-validation (assume V-fold). The observations have a grouping variable, and I want to ensure that each group is kept together when resampling, never split between the analysis and assessment sets. It's almost the opposite of strata.

As an example, we could use mtcars and the cyl variable. There are 3 unique values (4, 6, 8), so a 3-fold split of this type should produce one fold where the assessment set is only cyl = 4, another where it is only cyl = 6, and a third where it is only cyl = 8.
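The property being asked for can be stated as a simple check: no group value may appear on both sides of a split. A minimal sketch, assuming splits are represented as plain vectors of row indices; `groups_leak()` is a hypothetical helper, not part of rsample.

```r
# Hypothetical check (not an rsample function): does any group value
# appear in both the analysis rows and the assessment rows?
groups_leak <- function(groups, in_id, out_id) {
  length(intersect(groups[in_id], groups[out_id])) > 0
}

all_rows <- seq_len(nrow(mtcars))

# group-preserving split: every cyl == 4 car is held out together
grouped_out <- which(mtcars$cyl == 4)
groups_leak(mtcars$cyl, setdiff(all_rows, grouped_out), grouped_out)
#> [1] FALSE

# row-based split: the first 11 rows mix 4, 6 and 8 cylinder cars
naive_out <- 1:11
groups_leak(mtcars$cyl, setdiff(all_rows, naive_out), naive_out)
#> [1] TRUE
```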

To do that now I have to k-fold on just distinct(mtcars, cyl), and then do something hacky to "expand" those folds.

Would it be possible to combine nested_cv with #23 to achieve this? If not, and this is worthy of inclusion, I'd be happy to help code it up.

Here's my hack:

library(dplyr)
#> 
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#> 
#>     filter, lag
#> The following objects are masked from 'package:base':
#> 
#>     intersect, setdiff, setequal, union
library(purrr)
library(rsample)
#> Loading required package: broom

# suppose we want to keep cylinder-groups together
# we'll vfold those instead of the whole thing
initial_fold <- mtcars %>%
  distinct(cyl) %>%
  vfold_cv(v = 3)

# take a look
initial_fold %>%
  pull(splits) %>%
  map(assessment)
#> $`1`
#>   cyl
#> 2   4
#> 
#> $`2`
#>   cyl
#> 3   8
#> 
#> $`3`
#>   cyl
#> 1   6

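# take an existing rsplit object (made on the distinct group values) and
# expand it back to the rows of the original data it was subsampled from
expand_split <- function(split, orig) {

  # the grouping column the initial split was made on
  # (split$data reaches into the rsplit internals)
  vars <- colnames(split$data)

  # rows of the original data whose group landed in analysis / assessment;
  # pull() extracts a single column, so this handles one grouping variable only
  rows_in  <- which(pull(orig, vars) %in% pull(analysis(split), vars))
  rows_out <- which(pull(orig, vars) %in% pull(assessment(split), vars))

  # rsplit() is not exported from rsample, hence the :::
  rsample:::rsplit(orig, sample(rows_in), sample(rows_out))  # can't forget to shuffle
}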

# now apply that to each split
expanded_fold <- initial_fold %>%
  mutate(splits = map(splits, expand_split, mtcars))

# take a look
expanded_fold %>%
  pull(splits) %>%
  map(assessment)
#> $`1`
#>                 mpg cyl  disp  hp drat    wt  qsec vs am gear carb
#> Merc 240D      24.4   4 146.7  62 3.69 3.190 20.00  1  0    4    2
#> Fiat 128       32.4   4  78.7  66 4.08 2.200 19.47  1  1    4    1
#> Datsun 710     22.8   4 108.0  93 3.85 2.320 18.61  1  1    4    1
#> Volvo 142E     21.4   4 121.0 109 4.11 2.780 18.60  1  1    4    2
#> Toyota Corona  21.5   4 120.1  97 3.70 2.465 20.01  1  0    3    1
#> Merc 230       22.8   4 140.8  95 3.92 3.150 22.90  1  0    4    2
#> Fiat X1-9      27.3   4  79.0  66 4.08 1.935 18.90  1  1    4    1
#> Porsche 914-2  26.0   4 120.3  91 4.43 2.140 16.70  0  1    5    2
#> Honda Civic    30.4   4  75.7  52 4.93 1.615 18.52  1  1    4    2
#> Toyota Corolla 33.9   4  71.1  65 4.22 1.835 19.90  1  1    4    1
#> Lotus Europa   30.4   4  95.1 113 3.77 1.513 16.90  1  1    5    2
#> 
#> $`2`
#>                      mpg cyl  disp  hp drat    wt  qsec vs am gear carb
#> Merc 450SL          17.3   8 275.8 180 3.07 3.730 17.60  0  0    3    3
#> Duster 360          14.3   8 360.0 245 3.21 3.570 15.84  0  0    3    4
#> Maserati Bora       15.0   8 301.0 335 3.54 3.570 14.60  0  1    5    8
#> Cadillac Fleetwood  10.4   8 472.0 205 2.93 5.250 17.98  0  0    3    4
#> Hornet Sportabout   18.7   8 360.0 175 3.15 3.440 17.02  0  0    3    2
#> Ford Pantera L      15.8   8 351.0 264 4.22 3.170 14.50  0  1    5    4
#> Chrysler Imperial   14.7   8 440.0 230 3.23 5.345 17.42  0  0    3    4
#> Pontiac Firebird    19.2   8 400.0 175 3.08 3.845 17.05  0  0    3    2
#> Camaro Z28          13.3   8 350.0 245 3.73 3.840 15.41  0  0    3    4
#> Lincoln Continental 10.4   8 460.0 215 3.00 5.424 17.82  0  0    3    4
#> Merc 450SLC         15.2   8 275.8 180 3.07 3.780 18.00  0  0    3    3
#> Dodge Challenger    15.5   8 318.0 150 2.76 3.520 16.87  0  0    3    2
#> Merc 450SE          16.4   8 275.8 180 3.07 4.070 17.40  0  0    3    3
#> AMC Javelin         15.2   8 304.0 150 3.15 3.435 17.30  0  0    3    2
#> 
#> $`3`
#>                 mpg cyl  disp  hp drat    wt  qsec vs am gear carb
#> Valiant        18.1   6 225.0 105 2.76 3.460 20.22  1  0    3    1
#> Merc 280       19.2   6 167.6 123 3.92 3.440 18.30  1  0    4    4
#> Hornet 4 Drive 21.4   6 258.0 110 3.08 3.215 19.44  1  0    3    1
#> Ferrari Dino   19.7   6 145.0 175 3.62 2.770 15.50  0  1    5    6
#> Mazda RX4 Wag  21.0   6 160.0 110 3.90 2.875 17.02  0  1    4    4
#> Mazda RX4      21.0   6 160.0 110 3.90 2.620 16.46  0  1    4    4
#> Merc 280C      17.8   6 167.6 123 3.92 3.440 18.90  1  0    4    4

# other side
expanded_fold %>%
  pull(splits) %>%
  map(analysis)
#> $`1`
#>                      mpg cyl  disp  hp drat    wt  qsec vs am gear carb
#> Hornet 4 Drive      21.4   6 258.0 110 3.08 3.215 19.44  1  0    3    1
#> Ford Pantera L      15.8   8 351.0 264 4.22 3.170 14.50  0  1    5    4
#> Ferrari Dino        19.7   6 145.0 175 3.62 2.770 15.50  0  1    5    6
#> Mazda RX4           21.0   6 160.0 110 3.90 2.620 16.46  0  1    4    4
#> AMC Javelin         15.2   8 304.0 150 3.15 3.435 17.30  0  0    3    2
#> Hornet Sportabout   18.7   8 360.0 175 3.15 3.440 17.02  0  0    3    2
#> Valiant             18.1   6 225.0 105 2.76 3.460 20.22  1  0    3    1
#> Pontiac Firebird    19.2   8 400.0 175 3.08 3.845 17.05  0  0    3    2
#> Merc 450SLC         15.2   8 275.8 180 3.07 3.780 18.00  0  0    3    3
#> Cadillac Fleetwood  10.4   8 472.0 205 2.93 5.250 17.98  0  0    3    4
#> Chrysler Imperial   14.7   8 440.0 230 3.23 5.345 17.42  0  0    3    4
#> Maserati Bora       15.0   8 301.0 335 3.54 3.570 14.60  0  1    5    8
#> Camaro Z28          13.3   8 350.0 245 3.73 3.840 15.41  0  0    3    4
#> Dodge Challenger    15.5   8 318.0 150 2.76 3.520 16.87  0  0    3    2
#> Mazda RX4 Wag       21.0   6 160.0 110 3.90 2.875 17.02  0  1    4    4
#> Lincoln Continental 10.4   8 460.0 215 3.00 5.424 17.82  0  0    3    4
#> Merc 450SL          17.3   8 275.8 180 3.07 3.730 17.60  0  0    3    3
#> Merc 450SE          16.4   8 275.8 180 3.07 4.070 17.40  0  0    3    3
#> Duster 360          14.3   8 360.0 245 3.21 3.570 15.84  0  0    3    4
#> Merc 280            19.2   6 167.6 123 3.92 3.440 18.30  1  0    4    4
#> Merc 280C           17.8   6 167.6 123 3.92 3.440 18.90  1  0    4    4
#> 
#> $`2`
#>                 mpg cyl  disp  hp drat    wt  qsec vs am gear carb
#> Merc 280C      17.8   6 167.6 123 3.92 3.440 18.90  1  0    4    4
#> Merc 230       22.8   4 140.8  95 3.92 3.150 22.90  1  0    4    2
#> Ferrari Dino   19.7   6 145.0 175 3.62 2.770 15.50  0  1    5    6
#> Lotus Europa   30.4   4  95.1 113 3.77 1.513 16.90  1  1    5    2
#> Fiat 128       32.4   4  78.7  66 4.08 2.200 19.47  1  1    4    1
#> Mazda RX4      21.0   6 160.0 110 3.90 2.620 16.46  0  1    4    4
#> Hornet 4 Drive 21.4   6 258.0 110 3.08 3.215 19.44  1  0    3    1
#> Volvo 142E     21.4   4 121.0 109 4.11 2.780 18.60  1  1    4    2
#> Mazda RX4 Wag  21.0   6 160.0 110 3.90 2.875 17.02  0  1    4    4
#> Toyota Corolla 33.9   4  71.1  65 4.22 1.835 19.90  1  1    4    1
#> Valiant        18.1   6 225.0 105 2.76 3.460 20.22  1  0    3    1
#> Datsun 710     22.8   4 108.0  93 3.85 2.320 18.61  1  1    4    1
#> Toyota Corona  21.5   4 120.1  97 3.70 2.465 20.01  1  0    3    1
#> Merc 280       19.2   6 167.6 123 3.92 3.440 18.30  1  0    4    4
#> Porsche 914-2  26.0   4 120.3  91 4.43 2.140 16.70  0  1    5    2
#> Merc 240D      24.4   4 146.7  62 3.69 3.190 20.00  1  0    4    2
#> Honda Civic    30.4   4  75.7  52 4.93 1.615 18.52  1  1    4    2
#> Fiat X1-9      27.3   4  79.0  66 4.08 1.935 18.90  1  1    4    1
#> 
#> $`3`
#>                      mpg cyl  disp  hp drat    wt  qsec vs am gear carb
#> Cadillac Fleetwood  10.4   8 472.0 205 2.93 5.250 17.98  0  0    3    4
#> Pontiac Firebird    19.2   8 400.0 175 3.08 3.845 17.05  0  0    3    2
#> Camaro Z28          13.3   8 350.0 245 3.73 3.840 15.41  0  0    3    4
#> Fiat X1-9           27.3   4  79.0  66 4.08 1.935 18.90  1  1    4    1
#> Dodge Challenger    15.5   8 318.0 150 2.76 3.520 16.87  0  0    3    2
#> Lotus Europa        30.4   4  95.1 113 3.77 1.513 16.90  1  1    5    2
#> Merc 240D           24.4   4 146.7  62 3.69 3.190 20.00  1  0    4    2
#> Toyota Corolla      33.9   4  71.1  65 4.22 1.835 19.90  1  1    4    1
#> Maserati Bora       15.0   8 301.0 335 3.54 3.570 14.60  0  1    5    8
#> Lincoln Continental 10.4   8 460.0 215 3.00 5.424 17.82  0  0    3    4
#> Porsche 914-2       26.0   4 120.3  91 4.43 2.140 16.70  0  1    5    2
#> Toyota Corona       21.5   4 120.1  97 3.70 2.465 20.01  1  0    3    1
#> Datsun 710          22.8   4 108.0  93 3.85 2.320 18.61  1  1    4    1
#> Ford Pantera L      15.8   8 351.0 264 4.22 3.170 14.50  0  1    5    4
#> Fiat 128            32.4   4  78.7  66 4.08 2.200 19.47  1  1    4    1
#> AMC Javelin         15.2   8 304.0 150 3.15 3.435 17.30  0  0    3    2
#> Honda Civic         30.4   4  75.7  52 4.93 1.615 18.52  1  1    4    2
#> Merc 230            22.8   4 140.8  95 3.92 3.150 22.90  1  0    4    2
#> Duster 360          14.3   8 360.0 245 3.21 3.570 15.84  0  0    3    4
#> Merc 450SLC         15.2   8 275.8 180 3.07 3.780 18.00  0  0    3    3
#> Merc 450SE          16.4   8 275.8 180 3.07 4.070 17.40  0  0    3    3
#> Volvo 142E          21.4   4 121.0 109 4.11 2.780 18.60  1  1    4    2
#> Merc 450SL          17.3   8 275.8 180 3.07 3.730 17.60  0  0    3    3
#> Chrysler Imperial   14.7   8 440.0 230 3.23 5.345 17.42  0  0    3    4
#> Hornet Sportabout   18.7   8 360.0 175 3.15 3.440 17.02  0  0    3    2

Also available in this gist.

I tried to make a multi-variable version, but getting "indices of rows in tibble x that match tibble y" is a lot harder than I expected; dplyr pushes all of that down into Rcpp-land for the *_join functions :(
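One way to get "indices of rows in x that match y" across several columns without the join machinery is to collapse each row's key columns into a single string and use `%in%`. This is a minimal base R sketch; `match_rows()` is a hypothetical helper, and it assumes the separator `"\r"` never occurs inside the values and that numeric keys are integer-valued (so their string form is stable).

```r
# Hypothetical helper: indices of rows in `x` whose values in `vars`
# match some row of `y` (a multi-column analogue of %in%).
match_rows <- function(x, y, vars) {
  # one string key per row, e.g. "4\r1" for cyl = 4, vs = 1
  key <- function(df) do.call(paste, c(unname(as.list(df[vars])), sep = "\r"))
  which(key(x) %in% key(y))
}

wanted <- data.frame(cyl = c(4, 6), vs = c(1, 0))
idx <- match_rows(mtcars, wanted, c("cyl", "vs"))
length(idx)
#> [1] 13
```

Inside expand_split(), rows_in could then be computed as `match_rows(orig, analysis(split), vars)`, which no longer relies on pull() returning a single column.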

topepo commented 6 years ago

Yes, this seems to be the same as #23. caret has a similar function called groupKFold.

Quoting the original post: "To do that now I have to k-fold on just distinct(mtcars, cyl), and then do something hacky to 'expand' those folds."

Yep! I'll get the unique values of the variable used for splitting and perform V-fold CV on those values, then translate that to rows of the original data. Doing it this way means that, if you have a large number of values in the splitting variable, you don't have to have a separate split for each value.
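The approach described above (V-fold CV on the unique values, then translated back to rows) can be sketched in base R. This is an illustration under assumptions, not rsample's implementation: `group_vfold_rows()` is a hypothetical name, the grouping column is passed as a string, and each fold is returned as a plain list of analysis/assessment row indices rather than an rsplit object.

```r
# Hypothetical helper, not rsample's implementation: V-fold CV over the
# unique values of a grouping column, translated back to row indices.
group_vfold_rows <- function(data, group, v) {
  groups <- unique(data[[group]])
  n_groups <- length(groups)
  stopifnot(v >= 2, v <= n_groups)
  # spread the unique group values as evenly as possible over v folds, at random
  fold_of_group <- rep_len(seq_len(v), n_groups)[sample.int(n_groups)]
  # look up each row's group to get its fold label
  fold_of_row <- fold_of_group[match(data[[group]], groups)]
  # one analysis/assessment pair of row indices per fold
  lapply(seq_len(v), function(k) {
    list(analysis   = which(fold_of_row != k),
         assessment = which(fold_of_row == k))
  })
}

set.seed(1)
folds <- group_vfold_rows(mtcars, "cyl", v = 3)
# each assessment set holds exactly one cylinder group
sapply(folds, function(f) length(unique(mtcars$cyl[f$assessment])))
#> [1] 1 1 1
```

When there are more unique values than folds (for example v = 5 over 20 subjects), several groups share each assessment set, so there is no need for one split per value.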

I'll close this. If you disagree with the equivalence of the two issues, please reopen.

topepo commented 6 years ago

Give the function in devel (group_vfold_cv) a try and see if it does what you need.

ClaytonJY commented 6 years ago

@topepo that does it!

Bonus question: do the longer-term plans for this package include using tidyeval, so that quoting the group variable is optional?
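To illustrate what "quoting is optional" means, here is a minimal base R sketch using `substitute()`. It is not tidyeval/rlang and not how rsample handles its `group` argument; `resolve_group()` is a hypothetical helper that accepts either a bare column name or a string.

```r
# Hypothetical helper: accept a column given either bare (cyl) or quoted ("cyl").
resolve_group <- function(group) {
  expr <- substitute(group)   # capture the argument without evaluating it
  if (is.character(expr)) {
    expr                      # a string literal, e.g. "cyl"
  } else if (is.symbol(expr)) {
    as.character(expr)        # a bare name, e.g. cyl
  } else {
    stop("`group` must be a column name, bare or quoted")
  }
}

resolve_group(cyl)
#> [1] "cyl"
resolve_group("cyl")
#> [1] "cyl"
```

The catch is that a variable holding a column name is misread: with `g <- "cyl"`, `resolve_group(g)` returns `"g"`. Handling that ambiguity consistently is the kind of problem tidyeval was designed to solve.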

ClaytonJY commented 6 years ago

Extra bonus: would it make sense to allow multiple grouping variables, e.g. group_vfold_cv(mtcars, c("cyl", "vs"), 4)? That would let users avoid fusing several group-defining variables into a new single group variable.
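For reference, the fusing workaround mentioned above takes one call in base R: `interaction()` builds a factor with one level per combination of the grouping columns. This sketch assumes that combinations which never occur should be dropped; the column name `cyl_vs` is only illustrative.

```r
# Fuse several grouping columns into one factor: one level per observed
# combination. drop = TRUE discards combinations that never occur.
fused <- interaction(mtcars[c("cyl", "vs")], drop = TRUE)
nlevels(fused)
#> [1] 5

# the fused column can then serve as a single grouping variable
mtcars_grouped <- mtcars
mtcars_grouped$cyl_vs <- fused
```

The 8-cylinder, vs = 1 combination never occurs in mtcars, which is why 5 of the 6 possible cyl/vs combinations remain.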

github-actions[bot] commented 3 years ago

This issue has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex https://reprex.tidyverse.org) and link to this issue.