Closed RMHogervorst closed 1 year ago
There is a first version here: https://github.com/RMHogervorst/templates_ml/blob/master/R/learning_curve.R. I hope it still works; I haven't touched it in a while.
Thank you so much for bringing this back up! 🙌
One question I have about this approach is that I don't typically see the validation set start out really small and get bigger. I have seen approaches like mlr's, where you grow the training set but always validate against one reasonably sized validation set. I think this is how caret::learning_curve_dat() works too, if I understand correctly.
Do any of you all have examples of this kind of approach (with the validation/assessment set growing) being recommended or used? Thoughts on pros/cons, beyond the obvious?
Hi Julia, your question makes me wonder if I just didn't code it right. I looked at scikit-learn's approach and it talks about training set sizes, not validation set sizes. Looking at the Octave code samples for Andrew Ng's machine learning course, I also see plots that describe increasing training set size, not validation set size. So the 'state of the art' in other places seems to be to increase only the training set size.
But maybe zooming out we could come to the same conclusion: the validation set is a proxy for how well the model generalizes to new, never-before-seen data. So would we gain any more knowledge from using incrementally larger validation sets? I don't think so. Your suggestion of a reasonably sized validation set seems good to me!
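To make the design concrete, here is a minimal base R sketch of a learning curve with one fixed validation set and a growing training set. It is an illustration only: the mtcars data, the lm() model, the training sizes, and the helper name rmse are all assumptions chosen for the example, not anything from rsample, mlr, or scikit-learn.

```r
set.seed(13)

# Hold out one fixed validation set; it never changes size.
val_idx <- sample(nrow(mtcars), 10)
val <- mtcars[val_idx, ]

# Shuffle the remaining rows so nested training subsets are random.
pool <- mtcars[-val_idx, ]
pool <- pool[sample(nrow(pool)), ]

# Hypothetical helper: root mean squared error.
rmse <- function(truth, estimate) sqrt(mean((truth - estimate)^2))

# Illustrative training set sizes; the last one uses the whole pool.
sizes <- c(8, 12, 16, nrow(pool))

curve <- data.frame(
  n_train = sizes,
  val_rmse = vapply(sizes, function(k) {
    train <- pool[seq_len(k), ]              # each larger set contains the smaller ones
    fit <- lm(mpg ~ wt + hp, data = train)
    rmse(val$mpg, predict(fit, val))         # always scored on the same validation rows
  }, numeric(1))
)

curve
```

Only n_train varies between rows of curve, so any change in val_rmse reflects the amount of training data rather than a changing yardstick.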
I was looking at the mlr and scikit-learn implementations again today and I think this would be most like what is implemented there:
library(tidyverse)
library(rsample)
data(wa_churn, package = "modeldata")
set.seed(13)
folds <- vfold_cv(wa_churn, v = 5)
folds
#> # 5-fold cross-validation
#> # A tibble: 5 x 2
#> splits id
#> <list> <chr>
#> 1 <split [5634/1409]> Fold1
#> 2 <split [5634/1409]> Fold2
#> 3 <split [5634/1409]> Fold3
#> 4 <split [5635/1408]> Fold4
#> 5 <split [5635/1408]> Fold5
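# Keep a random fraction `prop` of the analysis rows of one split
remove_random <- function(split, prop) {
  l <- length(split$in_id)
  p <- round(l * (1 - prop))  # number of rows to drop
  # Nothing to drop; this also avoids x[-integer(0)], which returns an empty vector
  if (prop >= 1 || p == 0) {
    return(split$in_id)
  }
  split$in_id[-sample(seq_len(l), p)]
}
folds_parsed <- folds %>%
  crossing(prop = c(0.3, 0.5, 0.7, 0.9, 1.0)) %>%
  mutate(analysis = map2(splits, prop, remove_random),
         assessment = map(splits, complement))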
folds_parsed
#> # A tibble: 25 x 5
#> splits id prop analysis assessment
#> <list> <chr> <dbl> <list> <list>
#> 1 <split [5634/1409]> Fold1 0.3 <int [1,690]> <int [1,409]>
#> 2 <split [5634/1409]> Fold1 0.5 <int [2,817]> <int [1,409]>
#> 3 <split [5634/1409]> Fold1 0.7 <int [3,944]> <int [1,409]>
#> 4 <split [5634/1409]> Fold1 0.9 <int [5,071]> <int [1,409]>
#> 5 <split [5634/1409]> Fold1 1 <int [5,634]> <int [1,409]>
#> 6 <split [5634/1409]> Fold2 0.3 <int [1,690]> <int [1,409]>
#> 7 <split [5634/1409]> Fold2 0.5 <int [2,817]> <int [1,409]>
#> 8 <split [5634/1409]> Fold2 0.7 <int [3,944]> <int [1,409]>
#> 9 <split [5634/1409]> Fold2 0.9 <int [5,071]> <int [1,409]>
#> 10 <split [5634/1409]> Fold2 1 <int [5,634]> <int [1,409]>
#> # … with 15 more rows
folds_parsed %>%
  mutate(splits = map2(analysis,
                       assessment,
                       ~make_splits(list(analysis = .x, assessment = .y),
                                    wa_churn))) %>%
  select(prop, splits) %>%
  nest(learning_splits = c(splits)) %>%
  mutate(learning_splits = map(learning_splits, manual_rset, paste0("LearningFold", 1:5)))
#> # A tibble: 5 x 2
#> prop learning_splits
#> <dbl> <list>
#> 1 0.3 <manual_rset [5 × 2]>
#> 2 0.5 <manual_rset [5 × 2]>
#> 3 0.7 <manual_rset [5 × 2]>
#> 4 0.9 <manual_rset [5 × 2]>
#> 5 1 <manual_rset [5 × 2]>
Created on 2021-07-19 by the reprex package (v2.0.0)
You should be able to use fit_resamples() or another function from tune via purrr::map() across those rset objects. Notice that the intermediate objects that have the crossing of prop and splits are not an rset because you don't want to aggregate over all those splits (you only want to aggregate over the same value of prop). You'll need to use some purrr::map() here to do any evaluating, but this does get you the right analysis/assessment sets for a learning curve with cross-validation.
I'm not entirely convinced that we should have code to create this kind of object in rsample itself, since folks can do this with make_splits() and manual_rset(). Let's leave this example code here and gather more feedback.
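For reference, the two building blocks mentioned here can be tried on a toy example. This is a minimal sketch, assuming a version of rsample that exports manual_rset(); the mtcars data, the row indices, and the id "Split1" are arbitrary choices for illustration.

```r
library(rsample)

# Build one split by hand from explicit row indices (illustrative values).
split <- make_splits(
  list(analysis = 1:20, assessment = 21:32),
  mtcars
)

# Wrap a list of hand-made splits, plus matching ids, into an rset-like object.
rs <- manual_rset(list(split), "Split1")

nrow(analysis(split))    # rows available for fitting
nrow(assessment(split))  # rows held out for evaluation
rs
```

Because the analysis and assessment indices are passed explicitly, the training portion can be shrunk freely while the assessment rows stay untouched.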
For anyone who comes by in the meantime and is looking for how to actually fit across those manual rset objects:
library(tidyverse)
library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#> method from
#> required_pkgs.model_spec parsnip
data(wa_churn, package = "modeldata")
set.seed(13)
folds <- vfold_cv(wa_churn, v = 5)
folds
#> # 5-fold cross-validation
#> # A tibble: 5 x 2
#> splits id
#> <list> <chr>
#> 1 <split [5634/1409]> Fold1
#> 2 <split [5634/1409]> Fold2
#> 3 <split [5634/1409]> Fold3
#> 4 <split [5635/1408]> Fold4
#> 5 <split [5635/1408]> Fold5
# Keep a random fraction `prop` of the analysis rows of one split
remove_random <- function(split, prop) {
  l <- length(split$in_id)
  p <- round(l * (1 - prop))  # number of rows to drop
  # Nothing to drop; this also avoids x[-integer(0)], which returns an empty vector
  if (prop >= 1 || p == 0) {
    return(split$in_id)
  }
  split$in_id[-sample(seq_len(l), p)]
}
learning_splits <- folds %>%
  crossing(prop = c(0.3, 0.5, 0.7, 0.9, 1.0)) %>%
  mutate(analysis = map2(splits, prop, remove_random),
         assessment = map(splits, complement),
         splits = map2(analysis,
                       assessment,
                       ~make_splits(list(analysis = .x, assessment = .y),
                                    wa_churn))) %>%
  select(prop, splits) %>%
  nest(learning_splits = c(splits)) %>%
  mutate(learning_splits = map(learning_splits, manual_rset, paste0("LearningFold", 1:5)))
churn_form <- churn ~ monthly_charges + tenure + contract
churn_spec <- rand_forest() %>% set_mode("classification")
wf <- workflow(churn_form, churn_spec)
doParallel::registerDoParallel()
learning_res <-
  learning_splits %>%
  mutate(res = map(learning_splits, ~fit_resamples(wf, .)),
         metrics = map(res, collect_metrics)) %>%
  unnest(metrics)
learning_res %>%
  ggplot(aes(prop, mean, color = .metric)) +
  geom_ribbon(aes(ymin = mean - std_err,
                  ymax = mean + std_err), alpha = 0.3, color = NA) +
  geom_line(alpha = 0.8) +
  geom_point(size = 2) +
  facet_wrap(~.metric, ncol = 1, scales = "free_y") +
  theme(legend.position = "none") +
  labs(y = NULL)
Created on 2021-07-19 by the reprex package (v2.0.0)
I should point out that these plots only show the metrics measured via cross-validation on held-out/assessment data. We do make it a bit hard to predict against the training data (on purpose) but I'm sure it's possible to wrangle that through too.
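One way to get both curves is to bypass tune and score each fit by hand. The sketch below is not the tidymodels workflow above: it assumes a plain lm() on mtcars, and the helpers rmse and fit_one are hypothetical names introduced only for this example. It uses rsample just for vfold_cv(), analysis(), and assessment().

```r
library(rsample)

set.seed(13)
folds <- vfold_cv(mtcars, v = 5)
props <- c(0.5, 0.75, 1)

# Hypothetical helper: root mean squared error.
rmse <- function(truth, estimate) sqrt(mean((truth - estimate)^2))

# Hypothetical helper: fit on a random fraction of the analysis rows and
# score on both the rows used for fitting and the full assessment set.
fit_one <- function(split, prop) {
  train <- analysis(split)
  keep <- sample(nrow(train), round(nrow(train) * prop))
  train <- train[keep, , drop = FALSE]
  test <- assessment(split)
  fit <- lm(mpg ~ wt + hp, data = train)
  data.frame(
    prop = prop,
    set = c("training", "assessment"),
    rmse = c(rmse(train$mpg, predict(fit, train)),
             rmse(test$mpg, predict(fit, test)))
  )
}

# Every fold is evaluated at every proportion.
grid <- expand.grid(fold = seq_along(folds$splits), prop = props)
per_fold <- do.call(
  rbind,
  Map(function(i, p) fit_one(folds$splits[[i]], p), grid$fold, grid$prop)
)

# Average over folds within each proportion, separately for each set.
curve <- aggregate(rmse ~ prop + set, data = per_fold, FUN = mean)
curve
```

Plotting rmse against prop with one line per set gives the usual pair of learning curves; a persistent gap between the two lines is the classic sign of overfitting.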
This seems like an elegant solution! I can def. live with this.
Thank you all for the discussion!
Given that this issue has not gathered a huge amount of feedback in the past 2 years that would suggest this is a pressing need, I think it's okay to take the purrr::map() approach that Julia outlined above. I'm gonna go ahead and close this issue since we're not going to implement this soon, by the looks of it.
In the machine learning course by Andrew Ng on Coursera, Andrew talks about creating learning curves for your machine learning problem to identify if your model is underfitting or overfitting. Basically the approach is:
I think this fits quite nicely in the 'rsample' framework and I hacked something together that returns incrementally larger data slices while never mixing training and validation set:
You can then use ggplot2 to plot the performance and identify underfitting or overfitting.
Is this something of broad enough interest to add to rsample? I'm happy to submit a PR.