Closed: spsanderson closed this issue 1 year ago.
Function:
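```r
create_splits <- function(.data, .split_type = "initial_split",
                          .split_args = NULL) {

  # Normalize the split type to match rsample's lowercase function names
  split_type <- tolower(as.character(.split_type))

  # Checks ----
  if (is.null(.split_args)) {
    .split_args <- list()
  }

  if (!inherits(.split_args, "list")) {
    rlang::abort(
      message = "'.split_args' must be a 'list' of arguments for the rsample function.",
      use_cli_format = TRUE
    )
  }

  # Fail early with a clear message if the split type is not an rsample function
  if (!exists(split_type, envir = asNamespace("rsample"), mode = "function",
              inherits = FALSE)) {
    rlang::abort(
      message = paste0("'", split_type, "' is not an rsample function."),
      use_cli_format = TRUE
    )
  }

  # Manipulation ----
  # Look up the rsample function by name, e.g. "vfold_cv" -> rsample::vfold_cv()
  split_func <- utils::getFromNamespace(split_type, asNamespace("rsample"))

  # Call it with the data plus any user-supplied arguments
  splits_obj <- do.call(split_func, append(list(data = .data), .split_args))

  # Return ----
  return(list(splits = splits_obj, split_type = split_type))
}
```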
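The core of the function is a name-based lookup followed by `do.call()`. The sketch below isolates that step, assuming the `rsample` package is installed; `split_type` and `split_args` are illustrative stand-ins for `.split_type` and `.split_args`. It checks that the lookup gives the same folds as calling `rsample::vfold_cv()` directly with the same seed.

```r
library(rsample)

# Illustrative inputs, mirroring create_splits(mtcars, "vfold_cv", list(v = 5))
split_type <- "vfold_cv"
split_args <- list(v = 5)

# Fetch the rsample function by name, then call it with data + arguments
split_func <- utils::getFromNamespace(split_type, asNamespace("rsample"))
set.seed(42)
via_lookup <- do.call(split_func, append(list(data = mtcars), split_args))

# The equivalent direct call, using the same seed
set.seed(42)
direct <- vfold_cv(data = mtcars, v = 5)

# Both routes should yield the same folds
identical(analysis(via_lookup$splits[[1]]), analysis(direct$splits[[1]]))
```

Because the function is looked up by name, any `rsample` splitter can be used without adding a branch for each one. The trade-off is that argument errors only surface when `do.call()` runs.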
Example:

```r
> create_splits(mtcars, .split_type = "vfold_cv")
$splits
# 10-fold cross-validation
# A tibble: 10 × 2
   splits         id
   <list>         <chr>
 1 <split [28/4]> Fold01
 2 <split [28/4]> Fold02
 3 <split [29/3]> Fold03
 4 <split [29/3]> Fold04
 5 <split [29/3]> Fold05
 6 <split [29/3]> Fold06
 7 <split [29/3]> Fold07
 8 <split [29/3]> Fold08
 9 <split [29/3]> Fold09
10 <split [29/3]> Fold10

$split_type
[1] "vfold_cv"
```
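This function creates a splits object for any modeling task. You supply the split type as a character string that names an `rsample` function (for example `"initial_split"` or `"vfold_cv"`), plus an optional list of arguments in `.split_args`. Those arguments are passed to the matching `rsample` function, and the resulting splits object is returned together with the split type.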
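For `.split_type = "initial_split"`, the returned `splits` element is an ordinary `rsample` `rsplit` object. This sketch assumes `rsample` is installed and that `.split_args = list(prop = 0.75)` is forwarded unchanged. It makes the equivalent direct call and shows how the result is typically consumed with `training()` and `testing()`.

```r
library(rsample)

set.seed(123)
# Same call create_splits(mtcars, "initial_split", list(prop = 0.75)) builds via do.call()
split_obj <- initial_split(mtcars, prop = 0.75)

# Extract the two partitions from the rsplit object
train_tbl <- training(split_obj)
test_tbl  <- testing(split_obj)

# 32: every row of mtcars lands in exactly one of the two sets
nrow(train_tbl) + nrow(test_tbl)
```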