tidymodels / hardhat

Construct Modeling Packages
https://hardhat.tidymodels.org

Following R's default dummy variables creation #137

Closed mladenjovanovic closed 4 years ago

mladenjovanovic commented 4 years ago

How can I force 'normal' dummy variable creation using mold()/forge()? I mean expanding a factor with K levels into K-1 indicator columns, rather than the complete K-column expansion.

For example:

hardhat::mold(Sepal.Width ~ Species, iris)

vs.

model.matrix(Sepal.Width ~ Species, iris)

DavisVaughan commented 4 years ago

You need to tweak the blueprint to contain an intercept, which mold() doesn't do by default for two reasons:

library(hardhat)

bp <- default_formula_blueprint(intercept = TRUE)

mold(Sepal.Width ~ Species, iris, blueprint = bp)$predictors
#> # A tibble: 150 x 3
#>    `(Intercept)` Speciesversicolor Speciesvirginica
#>            <dbl>             <dbl>            <dbl>
#>  1             1                 0                0
#>  2             1                 0                0
#>  3             1                 0                0
#>  4             1                 0                0
#>  5             1                 0                0
#>  6             1                 0                0
#>  7             1                 0                0
#>  8             1                 0                0
#>  9             1                 0                0
#> 10             1                 0                0
#> # … with 140 more rows

head(model.matrix(Sepal.Width ~ Species, iris))
#>   (Intercept) Speciesversicolor Speciesvirginica
#> 1           1                 0                0
#> 2           1                 0                0
#> 3           1                 0                0
#> 4           1                 0                0
#> 5           1                 0                0
#> 6           1                 0                0

Note that hardhat still uses model.matrix() under the hood; without an intercept it does something like this:

# model.matrix() also returns all K indicator columns when there is no intercept
head(model.matrix(Sepal.Width ~ Species + 0, iris))
#>   Speciessetosa Speciesversicolor Speciesvirginica
#> 1             1                 0                0
#> 2             1                 0                0
#> 3             1                 0                0
#> 4             1                 0                0
#> 5             1                 0                0
#> 6             1                 0                0
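
To see why both layouts are legitimate, here is a base-R sketch (only stats and the built-in iris data, no hardhat) comparing the two design matrices. It assumes the default treatment contrasts; the variable names are illustrative.

```r
# Treatment (K-1) coding vs. full (K) dummy coding for the 3-level Species factor.
X_trt  <- model.matrix(Sepal.Width ~ Species, iris)      # intercept + K-1 dummies
X_full <- model.matrix(Sepal.Width ~ Species + 0, iris)  # K dummies, no intercept

colnames(X_trt)   # "(Intercept)", "Speciesversicolor", "Speciesvirginica"
colnames(X_full)  # "Speciessetosa", "Speciesversicolor", "Speciesvirginica"

# Both designs span the same column space, so the fitted values agree;
# only the interpretation of the coefficients differs.
fit_trt  <- lm(Sepal.Width ~ Species, iris)
fit_full <- lm(Sepal.Width ~ Species + 0, iris)
all.equal(unname(fitted(fit_trt)), unname(fitted(fit_full)))
#> [1] TRUE

# Without an intercept, each coefficient is simply a group mean.
group_means <- tapply(iris$Sepal.Width, iris$Species, mean)
all.equal(unname(coef(fit_full)), as.vector(group_means))
#> [1] TRUE
```

So the choice of intercept only changes which columns appear and what the coefficients mean, not the fit itself.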

mladenjovanovic commented 4 years ago

Thanks a lot! I have sorted it out with a few tweaks. Here is what I was trying to do: I wrote a 'wrapper' for my package that performs cross-validation (CV) of a user-provided modeling function and performance-metrics function. The modeling function defaults to lm(). Later, I can use new_data and the fitted model to produce PDP+ICE plots, feature importance, etc.

Since the predictors now include an intercept column, I had to do the following to get exactly the same behavior as lm():

# ------------------------------------------------------------------------------
#' Simple linear regression model
#'
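#' This model uses all \code{predictors} to model the \code{outcome} with the
#'     \code{\link[stats]{lm}} function. No intercept is added by \code{lm}
#'     itself (\code{-1} in the formula), so \code{predictors} are expected to
#'     already contain an intercept column if one is wanted.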
#' @inheritParams basic_arguments
#' @returns \code{model} object
#' @export
#' @examples
#' lm_model(
#'   predictors = iris[2:3],
#'   outcome = iris[[1]]
#' )
lm_model <- function(predictors,
                     outcome,
                     SESOI_lower = 0,
                     SESOI_upper = 0,
                     na.rm = FALSE) {

  data <- cbind(.outcome = outcome, predictors)

  stats::lm(.outcome ~ . -1, data)
}

So I used intercept = TRUE in the blueprint, and -1 in the lm() formula so that lm() doesn't add a second intercept.
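
A base-R sketch of this trick, assuming model.matrix() as a stand-in for the predictors that hardhat::mold() returns with intercept = TRUE (the variable names are illustrative, and lm_model() itself is not called here):

```r
# Predictors that already carry an explicit "(Intercept)" column of 1s.
predictors <- as.data.frame(model.matrix(Sepal.Width ~ Species, iris))
outcome <- iris$Sepal.Width

# Same construction as in lm_model(): outcome column plus predictors.
data <- cbind(.outcome = outcome, predictors)
fit_wrapper <- lm(.outcome ~ . - 1, data)

# Reference: the ordinary formula interface with its implicit intercept.
fit_lm <- lm(Sepal.Width ~ Species, iris)
all.equal(unname(coef(fit_wrapper)), unname(coef(fit_lm)))
#> [1] TRUE

# Without "- 1", lm() adds its own intercept; the explicit column is then
# perfectly collinear with it and one coefficient is reported as NA.
fit_double <- lm(.outcome ~ ., data)
anyNA(coef(fit_double))
#> [1] TRUE
```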

Maybe, just maybe, there could be a specific default blueprint that behaves exactly like lm() (i.e. one that formats the predictors the same way the lm() formula interface does with, say, outcome ~ covariate * Group)?

Thanks again for the help and very useful package!!!

mladenjovanovic commented 4 years ago

I also gave the user the option to include an intercept or not:

#' Fit a `cv_model`
#'
#' `cv_model()` fits a model and estimates its predictive performance using
#' cross-validation.
#'
#' @param x Depending on the context:
#'
#'   * A __data frame__ of predictors.
#'   * A __matrix__ of predictors.
#'   * A __recipe__ specifying a set of preprocessing steps
#'     created from [recipes::recipe()].
#'
#' @param y When `x` is a __data frame__ or __matrix__, `y` is the outcome
#' specified as:
#'
#'   * A __data frame__ with 1 numeric column.
#'   * A __matrix__ with 1 numeric column.
#'   * A numeric __vector__.
#'
#' @param data When a __recipe__ or __formula__ is used, `data` is specified as:
#'
#'   * A __data frame__ containing both the predictors and the outcome.
#'
#' @param formula A formula specifying the outcome terms on the left-hand side,
#' and the predictor terms on the right-hand side.
#'
#' @param intercept Should an intercept column be created? Default is \code{TRUE}.
#'
#' @param ... See Details
#' @details
#' Extra parameters passed via \code{...} are forwarded to the implementation
#'   function. These parameters are the following:
#' \describe{
#'   \item{model_func}{Model function. Default is \code{\link{lm_model}}. See also
#'   \code{\link{baseline_model}}}
#'   \item{predict_func}{Predict function. Default is \code{\link{generic_predict}}}
#'   \item{perf_func}{Model performance function. Default is \code{\link{performance_metrics}}}
#'   \item{SESOI_lower}{Function or numeric scalar. Default is \code{\link{SESOI_lower_func}}}
#'   \item{SESOI_upper}{Function or numeric scalar. Default is \code{\link{SESOI_upper_func}}}
#'   \item{control}{Control structure using \code{\link{model_control}}. The parameters
#'   used in \code{cv_model} are \code{cv_folds} and \code{cv_strata}}
#'   \item{na.rm}{Should NAs be removed? Default is FALSE. This is forwarded to
#'   \code{model_func}, \code{predict_func}, \code{perf_func}, \code{SESOI_lower},
#'   and \code{SESOI_upper}}
#' }
#'
#' In summary, \code{cv_model} is a wrapper function that runs \code{model_func} within
#'  the cross-validation loop and provides its predictive performance metrics using \code{perf_func}
#'
#' @return
#'
#' A `bmbstats_cv_model` object.
#'
#' @examples
#' m1 <- cv_model(
#'   Sepal.Length ~ . - Species,
#'   iris
#' )
#' predict(m1, new_data = iris)
#' @export
cv_model <- function(x, ...) {
  UseMethod("cv_model")
}

#' @export
#' @rdname cv_model
cv_model.default <- function(x, ...) {
  stop("`cv_model()` is not defined for a '", class(x)[1], "'.", call. = FALSE)
}

# XY method - data frame

#' @export
#' @rdname cv_model
cv_model.data.frame <- function(x, y, intercept = TRUE, ...) {
  bp <- hardhat::default_xy_blueprint(intercept = intercept)

  processed <- hardhat::mold(x, y, blueprint = bp)
  cv_model_bridge(processed, ...)
}

# XY method - matrix

#' @export
#' @rdname cv_model
cv_model.matrix <- function(x, y, intercept = TRUE, ...) {
  bp <- hardhat::default_xy_blueprint(intercept = intercept)

  processed <- hardhat::mold(x, y, blueprint = bp)
  cv_model_bridge(processed, ...)
}

# Formula method

#' @export
#' @rdname cv_model
cv_model.formula <- function(formula, data, intercept = TRUE, ...) {
  bp <- hardhat::default_formula_blueprint(intercept = intercept)

  processed <- hardhat::mold(formula, data, blueprint = bp)
  cv_model_bridge(processed, ...)
}

# Recipe method

#' @export
#' @rdname cv_model
cv_model.recipe <- function(x, data, intercept = TRUE, ...) {
  bp <- hardhat::default_recipe_blueprint(intercept = intercept)

  processed <- hardhat::mold(x, data, blueprint = bp)
  cv_model_bridge(processed, ...)
}

# ------------------------------------------------------------------------------
# Bridge

cv_model_bridge <- function(processed, ...) {

  predictors <- processed$predictors
  outcome <- processed$outcomes

  # Validate
  hardhat::validate_outcomes_are_univariate(outcome)
  outcome <- outcome[[1]]

  fit <- cv_model_impl(predictors, outcome, ...)

  new_cv_model(
    predictors = fit$predictors,
    outcome = fit$outcome,
    model_func = fit$model_func,
    predict_func = fit$predict_func,
    perf_func = fit$perf_func,
    SESOI_lower = fit$SESOI_lower,
    SESOI_upper = fit$SESOI_upper,
    model = fit$model,
    predicted = fit$predicted,
    performance = fit$performance,
    residual = fit$residual,
    residual_magnitude = fit$residual_magnitude,
    cross_validation = fit$cross_validation,
    control = fit$control,
    na.rm = fit$na.rm,
    blueprint = processed$blueprint
  )
}
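
The bridge keeps processed$blueprint in the fitted object so that predict() can later apply the same preprocessing to new_data (presumably with hardhat::forge(), as in hardhat's package template). As a rough base-R analogue of that mold()/forge() split, not hardhat's actual implementation, the sketch below remembers the training terms and factor levels and re-applies them to new data. prep_train() and prep_new() are hypothetical helpers.

```r
# Hypothetical helpers illustrating the "blueprint" idea with base R only.
prep_train <- function(formula, data) {
  mf <- model.frame(formula, data)
  tt <- terms(mf)
  list(
    x = model.matrix(tt, mf),
    y = model.response(mf),
    # What must be remembered for prediction time:
    terms = delete.response(tt),
    xlevels = .getXlevels(tt, mf)
  )
}

prep_new <- function(prepped, new_data) {
  # xlev restores the training factor levels, so new data holding a single
  # species still expands to the same columns as at training time.
  mf <- model.frame(prepped$terms, new_data, xlev = prepped$xlevels)
  model.matrix(prepped$terms, mf)
}

train <- prep_train(Sepal.Width ~ Species, iris)
new_x <- prep_new(train, data.frame(Species = "virginica"))
colnames(new_x)  # "(Intercept)", "Speciesversicolor", "Speciesvirginica"
```

Without the stored levels, a one-level factor in new data could not be expanded into the training-time columns, which is the kind of bookkeeping a stored blueprint takes care of.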

DavisVaughan commented 4 years ago

> Maybe, just maybe, there might be a specific default_blueprint that behaves exactly like lm

The thing to remember here is that hardhat is generally designed for new modeling packages, not necessarily for wrappers around existing ones, although you can still make it work.

The main idea of hardhat is that it takes care of all the preprocessing work for you, so that you can finally pass the preprocessed X and Y data down to your modeling implementation to do the hard work. That modeling function shouldn't do any preprocessing.

With your setup, you create another formula in your "engine", which is itself a preprocessing step, just to be able to pass the data to lm(). I understand why you do this, since lm() doesn't seem to have an XY interface, but remember that this isn't how hardhat was designed: you shouldn't be doing any preprocessing at this point.

I think it would fit better with the hardhat model to use stats::lm.fit(x, y), which can be viewed as the XY interface to lm, and won't perform any preprocessing on the data.
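
A minimal sketch of what such an lm.fit()-based engine might look like. fit_lm_engine() and predict_lm_engine() are hypothetical names, and model.matrix() stands in for the already-preprocessed predictors that mold() would hand over:

```r
# Hypothetical engine: no formula handling, just the XY fit.
fit_lm_engine <- function(predictors, outcome) {
  x <- as.matrix(predictors)  # works for a matrix or a tibble/data frame
  fit <- stats::lm.fit(x, outcome)
  list(coefficients = fit$coefficients)
}

# lm.fit() returns a plain list rather than an "lm" object, so predictions
# are the product of the (already preprocessed) predictors and coefficients.
predict_lm_engine <- function(fit, predictors) {
  drop(as.matrix(predictors) %*% fit$coefficients)
}

# Stand-in for mold() output: a design matrix that already has an intercept.
X <- model.matrix(Sepal.Width ~ Species, iris)
fit <- fit_lm_engine(X, iris$Sepal.Width)

all.equal(fit$coefficients, coef(lm(Sepal.Width ~ Species, iris)))
#> [1] TRUE
```

One design consequence: because lm.fit() returns a plain list, summary() and predict() methods for "lm" are not available, so predictions are computed from the coefficients as above. For rank-deficient designs lm.fit() reports NA for aliased coefficients, which this sketch does not handle.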

DavisVaughan commented 4 years ago

> I also gave user option to use or not use intercept

I personally think this is the best way to expose whether or not an intercept should be in the model!

mladenjovanovic commented 4 years ago

Completely understood! I have used hardhat in the dorem package for a novel model (one that we use in sport science).

The stats::lm(.outcome ~ . -1, data) call doesn't do any extra preprocessing, since that is already done when calling the wrapper (e.g. cv_model(Sepal.Length ~ ., iris)).

Thanks again for the help

github-actions[bot] commented 3 years ago

This issue has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.