LudvigOlsen / cvms

R Package: Cross-validate one or multiple gaussian or binomial regression models at once. Perform repeated cross-validation. Returns results in a tibble for easy comparison, reporting and further analysis.

default hyperparameters #30

Closed ggrothendieck closed 1 year ago

ggrothendieck commented 1 year ago

If I write the model_fn as shown in the documentation, then no matter what I try it won't use the default hyperparameters and instead gives an error.

library(cvms)
library(dplyr)
library(groupdata2)

model_fn_loess <- function(train_data, formula, hyperparameters) {
  hyperparameters <- cvms::update_hyperparameters(span = 0.75, 
    hyperparameters = hyperparameters)
  loess(formula = formula, data = train_data, span = hyperparameters[["span"]])
}
predict_loess <- function(test_data, model, formula, hyperparameters, train_data) {
  predict(model, test_data)
}

fold_data <- fold(mtcars, k = 4) %>% arrange(.folds)
formulas <- c("mpg ~ .", "mpg ~ cyl + disp + hp")

# attempt 1
outRF <- cross_validate_fn(fold_data, formulas, "gaussian", 
  model_fn = model_fn_loess, predict_fn = predict_loess)
## Will cross-validate 2 models. This requires fitting 8 model instances.
## Error in value[[3L]](cond) : ---
## cross_validate_fn(): Error: Assertion failed. One of the following must apply:
##  * checkmate::check_list(hyperparameters): Must be of type 'list', not
##  * 'NULL'
##  * checkmate::check_data_frame(hyperparameters): Must be of type
##  * 'data.frame', not 'NULL'

# attempt 2
outRF <- cross_validate_fn(fold_data, formulas, "gaussian", 
  model_fn = model_fn_loess, predict_fn = predict_loess,
  hyperparameters = list())
## Error: Assertion failed. One of the following must apply:
##  * checkmate::check_data_frame(hyperparameters): Must be of type
##  * 'data.frame' (or 'NULL'), not 'list'
##  * checkmate::check_list(hyperparameters): Must have length >= 1, but
##  * has length 0

# attempt 3
outRF <- cross_validate_fn(fold_data, formulas, "gaussian", 
  model_fn = model_fn_loess, predict_fn = predict_loess,
  hyperparameters = data.frame())
## Error: Assertion failed. One of the following must apply:
##  * checkmate::check_data_frame(hyperparameters): Must have at least 1
##  * rows, but has 0 rows
##  * checkmate::check_list(hyperparameters): Must be of type 'list' (or
##  * 'NULL'), not 'data.frame'

Writing the model_fn like this does work, but I think it should not be necessary.

model_fn_loess <- function(train_data, formula, hyperparameters) {
  hyperparameters <- if (missing(hyperparameters)) list(span = 0.75)
  else cvms::update_hyperparameters(span = 0.75, 
    hyperparameters = hyperparameters)
  loess(formula = formula, data = train_data, span = hyperparameters[["span"]])
}
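
For comparison, here is a minimal base-R sketch of the behaviour being asked for: a model function whose defaults apply even when no hyperparameters are supplied. The helper `with_defaults()` is hypothetical and not part of cvms; it only illustrates the idea of treating `NULL` as "no overrides" and is not how the package implements its fix.

```r
# Hypothetical helper (NOT part of cvms): merge default values with
# user-supplied hyperparameters, treating NULL as "no overrides".
with_defaults <- function(hyperparameters = NULL, ...) {
  defaults <- list(...)
  if (is.null(hyperparameters)) {
    return(defaults)
  }
  # modifyList() overwrites entries in `defaults` with same-named
  # entries from the supplied hyperparameters.
  utils::modifyList(defaults, as.list(hyperparameters))
}

model_fn_loess <- function(train_data, formula, hyperparameters = NULL) {
  hyperparameters <- with_defaults(hyperparameters, span = 0.75)
  loess(formula = formula, data = train_data, span = hyperparameters[["span"]])
}

# Fit once with the default span and once with an explicit override.
fit_default <- model_fn_loess(mtcars, mpg ~ hp)
fit_custom <- model_fn_loess(mtcars, mpg ~ hp, list(span = 0.5))
```

With this shape, the caller never has to pass a placeholder list or data frame just to get the default `span`.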

LudvigOlsen commented 1 year ago

Thanks for this, @ggrothendieck.

I completely agree and have fixed this for an upcoming update!