tidymodels / hardhat

Construct Modeling Packages
https://hardhat.tidymodels.org

importance weights not compatible with DALEXtra::model_profile #242

Closed: jamesgrecian closed this 7 months ago

jamesgrecian commented 1 year ago

I've added importance weights to a logistic regression using hardhat::importance_weights. However, when I try to generate partial dependence plots for the regression using DALEXtra::model_profile, it returns an error.

Typically we don't need to use weights when generating predictions from a fitted model, so I'm not sure why this returns an error, unless it's simply that DALEXtra doesn't know how to deal with a column formatted as <importance_weights>.

Here's a reprex with a dummy dataset. I'm trying to extract partial dependence profiles for each fold of the dataset to visually validate the model fit, as discussed in tidymodels/planning#26. The issue may be complicated by the fact that I'm calculating the weights on the fly, depending on the number of points assigned to each fold, as discussed in #240.

set.seed(1107)

# packages
library(sf)
#> Linking to GEOS 3.11.0, GDAL 3.5.3, PROJ 9.1.0; sf_use_s2() is TRUE
library(tidymodels)
library(spatialsample)
library(DALEXtra)
#> Loading required package: DALEX
#> Welcome to DALEX (version: 2.4.3).
#> Find examples and detailed introduction at: http://ema.drwhy.ai/
#> Additional features will be available after installation of: ggpubr.
#> Use 'install_dependencies()' to get all suggested dependencies
#> 
#> Attaching package: 'DALEX'
#> The following object is masked from 'package:dplyr':
#> 
#>     explain
#> Anaconda not found on your computer. Conda related functionality such as create_env.R and condaenv and yml parameters from explain_scikitlearn will not be available

## Data prep:
# pak::pkg_install("Nowosad/spDataLarge")
data("lsl", "study_mask", package = "spDataLarge")
ta <- terra::rast(system.file("raster/ta.tif", package = "spDataLarge"))
lsl <- lsl |> 
  st_as_sf(coords = c("x", "y"), crs = "EPSG:32717")

# convert to 0, 1 as is typical in species distribution modelling
lsl <- lsl |> 
  mutate(lslpts = factor(as.numeric(lslpts)-1)) |>
  # Creating a dummy case weights column, to get past initial verification by recipe
  mutate(cwts = hardhat::importance_weights(NA))

# set up case weights as a recipe step
lsl_recipe <- recipes::recipe(
  lslpts ~ slope + cplan + cprof + elev + log10_carea, 
  data = sf::st_drop_geometry(lsl)
) |> 
  recipes::step_mutate(
    cwts = hardhat::importance_weights(
      ifelse(lslpts == 1, 1, sum(lslpts == 1) / sum(lslpts == 0))
    ),
    # Need to set the "case_weights" role explicitly:
    role = "case_weights"
  )

# split into folds
lsl_folds <- spatial_block_cv(lsl, method = "random", v = 10)

# try GLM
glm_model <- logistic_reg() |> 
  set_engine("glm") |> 
  set_mode("classification")

# Using weights instead: no add_formula, because the formula is in our recipe
glm_wflow_wts <- workflow(preprocessor = lsl_recipe) |> 
  add_model(glm_model) |> 
  add_case_weights(cwts)

# fit model to one fold of the data
glm_fold_fit <- glm_wflow_wts |> fit(lsl_folds$splits[[1]] |> analysis())

# generate partial dependence profile for model
# ideally want to generate profile for each fold to verify model fit
glm_explainer <- explain_tidymodels(glm_fold_fit,
                                    data = lsl_folds$splits[[1]] |> 
                                      analysis() |> 
                                      st_drop_geometry() |> 
                                      dplyr::select(slope, cplan, cprof, elev, log10_carea),
                                    y = lsl_folds$splits[[1]] |> 
                                      analysis() |> 
                                      st_drop_geometry() |> 
                                      pull(lslpts)) |>
  model_profile(N = 100, type = "partial")
#> Preparation of a new explainer is initiated
#>   -> model label       :  workflow  (  default  )
#>   -> data              :  308  rows  5  cols 
#>   -> target variable   :  308  values 
#>   -> predict function  :  yhat.workflow  will be used (  default  )
#>   -> predicted values  :  No value for predict function target column. (  default  )
#>   -> model_info        :  package tidymodels , ver. 1.0.0 , task classification (  default  ) 
#>   -> model_info        :  Model info detected classification task but 'y' is a factor .  (  WARNING  )
#>   -> model_info        :  By deafult classification tasks supports only numercical 'y' parameter. 
#>   -> model_info        :  Consider changing to numerical vector with 0 and 1 values.
#>   -> model_info        :  Otherwise I will not be able to calculate residuals or loss function.
#>   -> predicted values  :  the predict_function returns an error when executed (  WARNING  ) 
#>   -> residual function :  difference between y and yhat (  default  )
#>   -> residuals         :  the residual_function returns an error when executed (  WARNING  ) 
#>   A new explainer has been created!
#> Error in `dplyr::mutate()`:
#> ℹ In argument: `cwts = hardhat::importance_weights(...)`.
#> Caused by error in `ifelse()`:
#> ! object 'lslpts' not found
#> Backtrace:
#>      ▆
#>   1. ├─DALEX::model_profile(...)
#>   2. │ ├─ingredients::ceteris_paribus(...)
#>   3. │ └─ingredients:::ceteris_paribus.explainer(...)
#>   4. │   └─ingredients:::ceteris_paribus.default(...)
#>   5. │     ├─ingredients:::calculate_variable_profile(...)
#>   6. │     └─ingredients:::calculate_variable_profile.default(...)
#>   7. │       └─base::lapply(...)
#>   8. │         └─ingredients (local) FUN(X[[i]], ...)
#>   9. │           ├─DALEX (local) predict_function(model, new_data, ...)
#>  10. │           └─DALEXtra:::yhat.workflow(model, new_data, ...)
#>  11. │             ├─base::as.matrix(predict(X.model, newdata, type = "prob"))
#>  12. │             ├─stats::predict(X.model, newdata, type = "prob")
#>  13. │             └─workflows:::predict.workflow(X.model, newdata, type = "prob")
#>  14. │               └─workflows:::forge_predictors(new_data, workflow)
#>  15. │                 ├─hardhat::forge(new_data, blueprint = mold$blueprint)
#>  16. │                 └─hardhat:::forge.data.frame(new_data, blueprint = mold$blueprint)
#>  17. │                   ├─hardhat::run_forge(blueprint, new_data = new_data, outcomes = outcomes)
#>  18. │                   └─hardhat:::run_forge.default_recipe_blueprint(...)
#>  19. │                     └─hardhat:::forge_recipe_default_process(...)
#>  20. │                       ├─recipes::bake(object = rec, new_data = new_data)
#>  21. │                       └─recipes:::bake.recipe(object = rec, new_data = new_data)
#>  22. │                         ├─recipes::bake(step, new_data = new_data)
#>  23. │                         └─recipes:::bake.step_mutate(step, new_data = new_data)
#>  24. │                           ├─dplyr::mutate(new_data, !!!object$inputs)
#>  25. │                           └─dplyr:::mutate.data.frame(new_data, !!!object$inputs)
#>  26. │                             └─dplyr:::mutate_cols(.data, dplyr_quosures(...), by)
#>  27. │                               ├─base::withCallingHandlers(...)
#>  28. │                               └─dplyr:::mutate_col(dots[[i]], data, mask, new_columns)
#>  29. │                                 └─mask$eval_all_mutate(quo)
#>  30. │                                   └─dplyr (local) eval()
#>  31. ├─hardhat::importance_weights(...)
#>  32. │ └─hardhat:::vec_cast_named(x, to = double(), x_arg = "x")
#>  33. │   └─vctrs::vec_cast(x, to, ..., call = call)
#>  34. ├─base::ifelse(lslpts == 1, 1, sum(lslpts == 1)/sum(lslpts == 0))
#>  35. └─base::.handleSimpleError(...)
#>  36.   └─dplyr (local) h(simpleError(msg, call))
#>  37.     └─rlang::abort(message, class = error_class, parent = parent, call = error_call)

Created on 2023-05-30 with reprex v2.0.2

topepo commented 7 months ago

Your formula lslpts ~ slope + cplan + cprof + elev + log10_carea has lslpts as the outcome, and step_mutate() is using it to construct the case weights.

tidymodels enforces the constraint that the outcome should not be used (in any way) when making predictions, even if that column happens to be available at prediction time. This is to eliminate information leakage, so the outcome column(s) are specifically excluded during prediction.

I would try using skip = TRUE so that the step does not execute outside of processing the training set.
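
Applied to the recipe in the reprex above, that change would look something like this (a sketch; only the skip = TRUE argument is new):

# Same recipe as before, but the weight-creating step is skipped at bake/predict
# time, so the outcome column (lslpts) is not needed when DALEXtra calls predict()
lsl_recipe <- recipes::recipe(
  lslpts ~ slope + cplan + cprof + elev + log10_carea,
  data = sf::st_drop_geometry(lsl)
) |>
  recipes::step_mutate(
    cwts = hardhat::importance_weights(
      ifelse(lslpts == 1, 1, sum(lslpts == 1) / sum(lslpts == 0))
    ),
    role = "case_weights",
    skip = TRUE  # only run this step when processing the training set
  )

With skip = TRUE, the step still runs when the workflow is fit, but bake() on new data no longer tries to recreate cwts from lslpts, so the predict calls made by DALEXtra should no longer require the outcome column.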

topepo commented 7 months ago

Here's a smaller reprex:

library(tidymodels)

tidymodels_prefer()
theme_set(theme_bw())
options(pillar.advice = FALSE, pillar.min_title_chars = Inf)

mtcar_wts <- 
  mtcars %>% 
  mutate(case_wts = hardhat::importance_weights(NA))

car_rec <- 
  recipe(mpg ~ wt + disp + gear, data = mtcar_wts) %>% 
  step_mutate(case_wts = hardhat::importance_weights(1 / mpg), role = "case_weights")

lm_fit <- 
  car_rec %>% 
  workflow(linear_reg()) %>% 
  add_case_weights(case_wts) %>% 
  fit(mtcar_wts)

lm_fit %>% 
  extract_fit_engine() %>% 
  coef()
#> (Intercept)          wt        disp        gear 
#> 34.32324535 -2.69696818 -0.02006457 -0.34745007

predict(lm_fit, mtcar_wts[1:3,])
#> Error in `dplyr::mutate()`:
#> ℹ In argument: `case_wts = hardhat::importance_weights(1/mpg)`.
#> Caused by error in `FUN()`:
#> ! non-numeric argument to binary operator
#> Backtrace:
#>      ▆
#>   1. ├─stats::predict(lm_fit, mtcar_wts[1:3, ])
#>   2. ├─workflows:::predict.workflow(lm_fit, mtcar_wts[1:3, ])
#>   3. │ └─workflows:::forge_predictors(new_data, workflow) at workflows/R/predict.R:63:3
#>   4. │   ├─hardhat::forge(new_data, blueprint = mold$blueprint) at workflows/R/predict.R:70:3
#>   5. │   └─hardhat:::forge.data.frame(new_data, blueprint = mold$blueprint) at hardhat/R/forge.R:68:3
#>   6. │     ├─hardhat::run_forge(blueprint, new_data = new_data, outcomes = outcomes) at hardhat/R/forge.R:81:3
#>   7. │     └─hardhat:::run_forge.default_recipe_blueprint(...) at hardhat/R/forge.R:135:3
#>   8. │       └─hardhat:::forge_recipe_default_process(...) at hardhat/R/blueprint-recipe-default.R:350:3
#>   9. │         ├─recipes::bake(object = rec, new_data = new_data) at hardhat/R/blueprint-recipe-default.R:435:3
#>  10. │         └─recipes:::bake.recipe(object = rec, new_data = new_data)
#>  11. │           ├─recipes::bake(step, new_data = new_data)
#>  12. │           └─recipes:::bake.step_mutate(step, new_data = new_data)
#>  13. │             ├─dplyr::mutate(new_data, !!!object$inputs)
#>  14. │             └─dplyr:::mutate.data.frame(new_data, !!!object$inputs)
#>  15. │               └─dplyr:::mutate_cols(.data, dplyr_quosures(...), by)
#>  16. │                 ├─base::withCallingHandlers(...)
#>  17. │                 └─dplyr:::mutate_col(dots[[i]], data, mask, new_columns)
#>  18. │                   └─mask$eval_all_mutate(quo)
#>  19. │                     └─dplyr (local) eval()
#>  20. ├─hardhat::importance_weights(1/mpg)
#>  21. │ └─hardhat:::vec_cast_named(x, to = double(), x_arg = "x") at hardhat/R/case-weights.R:31:3
#>  22. │   └─vctrs::vec_cast(x, to, ..., call = call) at hardhat/R/util.R:245:3
#>  23. ├─base::Ops.data.frame(1, mpg) at hardhat/R/case-weights.R:31:3
#>  24. │ └─base::eval(f)
#>  25. │   └─base::eval(f)
#>  26. └─base::.handleSimpleError(...)
#>  27.   └─dplyr (local) h(simpleError(msg, call))
#>  28.     └─rlang::abort(message, class = error_class, parent = parent, call = error_call)

car_skip_rec <- 
  recipe(mpg ~ wt + disp + gear, data = mtcar_wts) %>% 
  step_mutate(case_wts = hardhat::importance_weights(1 / mpg), 
              role = "case_weights", skip = TRUE)

lm_skip_fit <- 
  car_skip_rec %>% 
  workflow(linear_reg()) %>% 
  add_case_weights(case_wts) %>% 
  fit(mtcar_wts)

lm_skip_fit %>% 
  extract_fit_engine() %>% 
  coef()
#> (Intercept)          wt        disp        gear 
#> 34.32324535 -2.69696818 -0.02006457 -0.34745007

predict(lm_skip_fit, mtcar_wts[1:3,])
#> # A tibble: 3 × 1
#>   .pred
#>   <dbl>
#> 1  22.7
#> 2  22.0
#> 3  24.5

Created on 2024-01-31 with reprex v2.0.2