tidymodels / rsample

Classes and functions to create and summarize resampling objects
https://rsample.tidymodels.org

Feature request: incrementally larger training and test sets for 'learning curves' #166

Closed RMHogervorst closed 1 year ago

RMHogervorst commented 4 years ago

In the machine learning course by Andrew Ng on Coursera, Andrew talks about creating learning curves for your machine learning problem to identify if your model is underfitting or overfitting. Basically the approach is:

I think this fits quite nicely in the 'rsample' framework and I hacked something together that returns incrementally larger data slices while never mixing training and validation set:

incr_ames <- incremental_set(ames, 10, min_data_size = 25)
incr_ames
## # A tibble: 10 x 2
##    splits             id         
##    <list>             <chr>      
##  1 <split [25/25]>    Increment01
##  2 <split [266/103]>  Increment02
##  3 <split [507/181]>  Increment03
##  4 <split [748/259]>  Increment04
##  5 <split [989/337]>  Increment05
##  6 <split [1.2K/415]> Increment06
##  7 <split [1.5K/493]> Increment07
##  8 <split [1.7K/571]> Increment08
##  9 <split [2K/649]>   Increment09
## 10 <split [2.2K/727]> Increment10

You can then use ggplot2 to plot the performance and identify under- or overfitting.

(Image: example_learning_curve)

Is this something of broad enough interest to add to rsample? I'm happy to submit a PR.

RMHogervorst commented 3 years ago

There is a first version here: https://github.com/RMHogervorst/templates_ml/blob/master/R/learning_curve.R. I hope it still works, because I haven't touched it in a while.

juliasilge commented 3 years ago

Thank you so much for bringing this back up! 🙌

One question I have about this approach is that I feel like I don't typically see the validation set start out really small and get bigger. I have seen approaches like mlr's where you have one validation set and you grow the training set but always validate against one reasonably sized set. I think this is how caret::learning_curve_dat() works too, if I understand correctly.

Do any of you all have examples of this kind of approach (with the validation/assessment set growing) being recommended or used? Thoughts on pros/cons, beyond the obvious?

RMHogervorst commented 3 years ago

Hi Julia, your question makes me wonder if I just didn't code it right. I looked at scikit-learn's approach and it talks about training set sizes, not validation set sizes. Looking at Octave code samples for Andrew Ng's machine learning course, I also see plots that describe increasing training size, not validation size. So the 'state of the art' elsewhere seems to be to increase only the training size.

But maybe zooming out we could come to the same conclusion: the validation set is a proxy for how well the model generalizes to new, never-before-seen data. So would we gain any more knowledge from using incrementally larger validation sets? I don't think so. Your suggestion of a reasonably sized validation set seems good to me!
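A minimal base R sketch of the setup discussed here: one validation set of fixed size, with training sets that grow. The helper name `growing_train_indices` and its arguments are hypothetical and not part of rsample, and nesting the training subsets (each larger set contains the smaller ones) is an assumption made for illustration.

```r
# Hypothetical helper (not part of rsample): hold out one fixed validation
# set and return nested training index sets of increasing size.
growing_train_indices <- function(n, val_prop = 0.2,
                                  props = c(0.3, 0.5, 0.7, 0.9, 1.0)) {
  shuffled <- sample.int(n)
  n_val <- round(n * val_prop)

  # The validation rows never change, whatever the training size.
  val_id <- head(shuffled, n_val)
  train_pool <- tail(shuffled, n - n_val)

  # Each training set is the first k rows of the same shuffled pool,
  # so smaller sets are subsets of larger ones.
  train_sizes <- pmax(1, round(length(train_pool) * props))
  train_ids <- lapply(train_sizes, function(k) head(train_pool, k))
  names(train_ids) <- paste0("prop_", props)

  list(validation = val_id, training = train_ids)
}

set.seed(13)
idx <- growing_train_indices(n = 1000)
lengths(idx$training)   # training sizes grow
length(idx$validation)  # validation size stays fixed
```

Because every model is scored against the same rows, differences along the curve reflect the training size rather than a changing validation set.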

juliasilge commented 3 years ago

I was looking at the mlr and scikit-learn implementations again today and I think this would be most like what is implemented there:

library(tidyverse)
library(rsample)
data(wa_churn, package = "modeldata")

set.seed(13)
folds <- vfold_cv(wa_churn, v = 5)
folds
#> #  5-fold cross-validation 
#> # A tibble: 5 x 2
#>   splits              id   
#>   <list>              <chr>
#> 1 <split [5634/1409]> Fold1
#> 2 <split [5634/1409]> Fold2
#> 3 <split [5634/1409]> Fold3
#> 4 <split [5635/1408]> Fold4
#> 5 <split [5635/1408]> Fold5

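remove_random <- function(split, prop) {
   # Keep a random `prop` fraction of the analysis row indices
   if (prop >= 1) {
      return(split$in_id)
   }
   l <- length(split$in_id)
   p <- round(l * (1 - prop))   # number of rows to drop
   # setdiff() avoids the empty result that x[-integer(0)] gives when p == 0
   split$in_id[setdiff(seq_len(l), sample.int(l, p))]
}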

folds_parsed <- folds %>%
   crossing(prop = c(0.3, 0.5, 0.7, 0.9, 1.0)) %>%
   mutate(analysis = map2(splits, prop, remove_random),
          assessment = map(splits, complement))

folds_parsed
#> # A tibble: 25 x 5
#>    splits              id     prop analysis      assessment   
#>    <list>              <chr> <dbl> <list>        <list>       
#>  1 <split [5634/1409]> Fold1   0.3 <int [1,690]> <int [1,409]>
#>  2 <split [5634/1409]> Fold1   0.5 <int [2,817]> <int [1,409]>
#>  3 <split [5634/1409]> Fold1   0.7 <int [3,944]> <int [1,409]>
#>  4 <split [5634/1409]> Fold1   0.9 <int [5,071]> <int [1,409]>
#>  5 <split [5634/1409]> Fold1   1   <int [5,634]> <int [1,409]>
#>  6 <split [5634/1409]> Fold2   0.3 <int [1,690]> <int [1,409]>
#>  7 <split [5634/1409]> Fold2   0.5 <int [2,817]> <int [1,409]>
#>  8 <split [5634/1409]> Fold2   0.7 <int [3,944]> <int [1,409]>
#>  9 <split [5634/1409]> Fold2   0.9 <int [5,071]> <int [1,409]>
#> 10 <split [5634/1409]> Fold2   1   <int [5,634]> <int [1,409]>
#> # … with 15 more rows

folds_parsed %>%
   mutate(splits = map2(analysis, 
                        assessment, 
                        ~make_splits(list(analysis = .x, assessment = .y), 
                                     wa_churn))) %>%
   select(prop, splits) %>%
   nest(learning_splits = c(splits)) %>%
   mutate(learning_splits = map(learning_splits, manual_rset, paste0("LearningFold", 1:5)))
#> # A tibble: 5 x 2
#>    prop learning_splits      
#>   <dbl> <list>               
#> 1   0.3 <manual_rset [5 × 2]>
#> 2   0.5 <manual_rset [5 × 2]>
#> 3   0.7 <manual_rset [5 × 2]>
#> 4   0.9 <manual_rset [5 × 2]>
#> 5   1   <manual_rset [5 × 2]>

Created on 2021-07-19 by the reprex package (v2.0.0)

You should be able to use fit_resamples() or another function from tune via purrr::map() across those rset objects. Notice that the intermediate objects that have the crossing of prop and splits are not an rset because you don't want to aggregate over all those splits (you only want to aggregate over the same value of prop). You'll need to use some purrr::map() here to do any evaluating, but this does get you the right analysis/assessment sets for a learning curve with cross-validation.

I'm not entirely convinced that we should have code to create this kind of object in rsample itself, since folks can do this with make_splits() and manual_rset(). Let's leave this example code here and gather more feedback.
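For readers who have not used these two functions, here is a small sketch of them in isolation, assuming rsample is installed. The index ranges and the ids `"Small"` and `"Large"` are arbitrary choices for illustration on the built-in `mtcars` data; they are not taken from the example above.

```r
library(rsample)

# Two hand-picked analysis/assessment index pairs that share one
# assessment set while the analysis set grows.
split_small <- make_splits(list(analysis = 1:10, assessment = 25:32), mtcars)
split_large <- make_splits(list(analysis = 1:24, assessment = 25:32), mtcars)

# Bundle the splits into an rset with custom ids.
growing <- manual_rset(list(split_small, split_large), c("Small", "Large"))

nrow(analysis(growing$splits[[1]]))    # rows in the smaller analysis set
nrow(assessment(growing$splits[[2]]))  # rows in the shared assessment set
```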

juliasilge commented 3 years ago

For anyone who comes by in the meantime and is looking for how to actually fit across those manual rset objects:

library(tidyverse)
library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#>   method                   from   
#>   required_pkgs.model_spec parsnip
data(wa_churn, package = "modeldata")

set.seed(13)
folds <- vfold_cv(wa_churn, v = 5)
folds
#> #  5-fold cross-validation 
#> # A tibble: 5 x 2
#>   splits              id   
#>   <list>              <chr>
#> 1 <split [5634/1409]> Fold1
#> 2 <split [5634/1409]> Fold2
#> 3 <split [5634/1409]> Fold3
#> 4 <split [5635/1408]> Fold4
#> 5 <split [5635/1408]> Fold5

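remove_random <- function(split, prop) {
   # Keep a random `prop` fraction of the analysis row indices
   if (prop >= 1) {
      return(split$in_id)
   }
   l <- length(split$in_id)
   p <- round(l * (1 - prop))   # number of rows to drop
   # setdiff() avoids the empty result that x[-integer(0)] gives when p == 0
   split$in_id[setdiff(seq_len(l), sample.int(l, p))]
}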

learning_splits <- folds %>%
   crossing(prop = c(0.3, 0.5, 0.7, 0.9, 1.0)) %>%
   mutate(analysis = map2(splits, prop, remove_random),
          assessment = map(splits, complement),
          splits = map2(analysis, 
                        assessment, 
                        ~make_splits(list(analysis = .x, assessment = .y), 
                                     wa_churn))) %>%
   select(prop, splits) %>%
   nest(learning_splits = c(splits)) %>%
   mutate(learning_splits = map(learning_splits, manual_rset, paste0("LearningFold", 1:5)))

churn_form <- churn ~ monthly_charges + tenure + contract
churn_spec <- rand_forest() %>% set_mode("classification")
wf <- workflow(churn_form, churn_spec)

doParallel::registerDoParallel()
learning_res <- 
   learning_splits %>%
   mutate(res = map(learning_splits, ~fit_resamples(wf, .)),
          metrics = map(res, collect_metrics)) %>%
   unnest(metrics) 

learning_res %>%
   ggplot(aes(prop, mean, color = .metric)) +
   geom_ribbon(aes(ymin = mean - std_err,
                    ymax = mean + std_err), alpha = 0.3, color = NA) +
   geom_line(alpha = 0.8) +
   geom_point(size = 2) +
   facet_wrap(~.metric, ncol = 1, scales = "free_y") +
   theme(legend.position = "none") +
   labs(y = NULL)

Created on 2021-07-19 by the reprex package (v2.0.0)

I should point out that these plots only show the metrics measured via cross-validation on the held-out (assessment) data. We do make it a bit hard to predict against the training data (on purpose) but I'm sure it's possible to wrangle that through too.
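To illustrate what a curve with both training and held-out metrics looks like, here is a rough sketch in base R on simulated data. It deliberately steps outside the tidymodels workflow above: the linear model, the simulated data, and the `rmse` helper are all assumptions made for illustration, not a way to extract training-set predictions from `fit_resamples()`.

```r
# Sketch only: base R and simulated data, not the tidymodels workflow above.
set.seed(13)
n <- 500
dat <- data.frame(x = runif(n))
dat$y <- 2 * dat$x + rnorm(n, sd = 0.5)

# One fixed held-out set; the remaining rows are the pool for training.
val_id <- sample.int(n, 100)
val <- dat[val_id, ]
pool <- dat[-val_id, ]

# Hypothetical helper: root mean squared error.
rmse <- function(truth, estimate) sqrt(mean((truth - estimate)^2))

props <- c(0.1, 0.3, 0.5, 0.7, 1.0)
curve <- do.call(rbind, lapply(props, function(p) {
  train <- pool[seq_len(round(nrow(pool) * p)), ]
  fit <- lm(y ~ x, data = train)
  data.frame(
    prop = p,
    n_train = nrow(train),
    train_rmse = rmse(train$y, predict(fit, newdata = train)),
    val_rmse = rmse(val$y, predict(fit, newdata = val))
  )
}))
curve
```

Plotting `train_rmse` and `val_rmse` against `n_train` gives the two lines usually shown in a learning curve; a large, persistent gap between them suggests overfitting, while two high, close lines suggest underfitting.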

RMHogervorst commented 3 years ago

This seems like an elegant solution! I can def. live with this.

hfrick commented 1 year ago

Thank you all for the discussion!

Given that this issue has not gathered a huge amount of feedback in the past 2 years that would suggest this is a pressing need, I think it's okay to take the purrr::map() approach that Julia outlined above. I'm gonna go ahead and close this issue since we're not going to implement this soon, by the looks of it.

github-actions[bot] commented 1 year ago

This issue has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.