spsanderson / tidyAML

Auto ML for the tidyverse
http://www.spsanderson.com/tidyAML/

Make `create_splits()` function #6

Closed: spsanderson closed this issue 1 year ago

spsanderson commented 1 year ago

This function will create an rsample splits object for any modeling problem. You supply the name of the split type as a character string (for example `"initial_split"` or `"vfold_cv"`) along with a list of any needed arguments. Those arguments are passed to the matching rsample function, and the resulting splits object is returned.

spsanderson commented 1 year ago

Function:

create_splits <- function(.data, .split_type = "initial_split",
                          .split_args = NULL){

  # Normalize input ----
  # rsample function names are lowercase, so "VFOLD_CV" also works
  split_type <- tolower(as.character(.split_type))

  # Checks ----
  if (is.null(.split_args)){
    .split_args <- list()
  }

  if (!inherits(.split_args, "list")){
    rlang::abort(
      message = "'.split_args' must be a 'list' of arguments for the rsample function.",
      use_cli_format = TRUE
    )
  }

  if (!exists(split_type, envir = asNamespace("rsample"), mode = "function",
              inherits = FALSE)){
    rlang::abort(
      message = paste0("'", split_type, "' is not a function in the rsample package."),
      use_cli_format = TRUE
    )
  }

  # Manipulation ----
  # Look up the rsample function by name and call it with `data`
  # followed by any user-supplied arguments
  split_func <- utils::getFromNamespace(split_type, asNamespace("rsample"))
  splits_obj <- do.call(split_func, append(list(data = .data), .split_args))

  # Return ----
  return(list(splits = splits_obj, split_type = split_type))
}

Example:


> create_splits(mtcars, .split_type = "vfold_cv")
$splits
#  10-fold cross-validation 
# A tibble: 10 × 2
   splits         id    
   <list>         <chr> 
 1 <split [28/4]> Fold01
 2 <split [28/4]> Fold02
 3 <split [29/3]> Fold03
 4 <split [29/3]> Fold04
 5 <split [29/3]> Fold05
 6 <split [29/3]> Fold06
 7 <split [29/3]> Fold07
 8 <split [29/3]> Fold08
 9 <split [29/3]> Fold09
10 <split [29/3]> Fold10

$split_type
[1] "vfold_cv"
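The core of `create_splits()` is a name-based dispatch: `utils::getFromNamespace()` turns the split-type string into the rsample function, and `do.call()` calls it with `data` plus the contents of `.split_args`. The sketch below isolates that pattern so it runs without rsample installed. `call_by_name()` is a hypothetical helper, not part of tidyAML, and `utils::head()` merely stands in for an rsample split function.

```r
# Hypothetical helper (not part of tidyAML) that mirrors the dispatch
# pattern in create_splits(): resolve a function by name inside a
# package namespace, then call it with the data plus a list of extra
# arguments.
call_by_name <- function(.fun_name, .data, .args = list(), .pkg = "base") {
  # Same normalization create_splits() applies to `.split_type`
  fun_name <- tolower(as.character(.fun_name))

  # Fail early with a readable message if the name does not resolve
  if (!exists(fun_name, envir = asNamespace(.pkg), mode = "function",
              inherits = FALSE)) {
    stop(sprintf("'%s' is not a function in package '%s'.", fun_name, .pkg))
  }

  fn <- utils::getFromNamespace(fun_name, .pkg)

  # Data goes first, user-supplied arguments follow (c() behaves like append() here)
  do.call(fn, c(list(.data), .args))
}

# Usage with a function that ships with every R installation
first_rows <- call_by_name("HEAD", mtcars, list(n = 3), .pkg = "utils")
print(nrow(first_rows))  # prints [1] 3
```

Assuming rsample is installed, calling it with `.pkg = "rsample"` and a name such as `"vfold_cv"` would follow the same path as `create_splits()`, because rsample's split functions take `data` as their first argument.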