jamesgrecian closed this issue 7 months ago.
Your formula `lslpts ~ slope + cplan + cprof + elev + log10_carea` has `lslpts` as the outcome, and `step_mutate()` is using it to construct the case weights.
tidymodels enforces the constraint that the outcome is not used (in any way) when making predictions, even if that column happens to be available at prediction time. This is to prevent information leakage, so the outcome column(s) are specifically excluded from the data that is processed during prediction.
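A minimal sketch of that exclusion at the hardhat level, which is the layer that `predict()` on a workflow goes through (`hardhat::forge()`, as the backtrace in the reprex shows). It assumes only the recipes and hardhat packages and uses a plain recipe on `mtcars`; the object names are illustrative. `forge()` defaults to `outcomes = FALSE`, so `mpg` is not returned even though `new_rows` still contains it.

```r
library(recipes)
library(hardhat)

# Outcome mpg and three predictors; no steps are needed for this illustration
rec <- recipe(mpg ~ wt + disp + gear, data = mtcars)

# mold() preps the recipe on the training data and records a blueprint
processed <- mold(rec, mtcars)

# New data that still contains the outcome column
new_rows <- mtcars[1:3, ]

# With the default outcomes = FALSE, only the predictors are processed
pred_only <- forge(new_rows, processed$blueprint)
names(pred_only$predictors)
pred_only$outcomes

# The outcome is only processed when it is explicitly requested
with_outcome <- forge(new_rows, processed$blueprint, outcomes = TRUE)
names(with_outcome$outcomes)
```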
I would try using `skip = TRUE` so that the step is only executed when the training set is processed; it is then skipped when new data are baked at prediction time.
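To see what `skip = TRUE` changes without a workflow in the way, here is a small sketch at the recipe level, assuming only the recipes package. The hypothetical `inv_mpg` column stands in for the case-weight column and is kept as a plain numeric to keep the example short. The step is applied to the training data during `prep()`, so `bake(new_data = NULL)` returns it, while `bake()` on new data skips the step and therefore never needs `mpg`.

```r
library(recipes)

# inv_mpg (hypothetical) is derived from the outcome, so it can only be
# computed when the outcome is available, i.e. on the training set
rec <- recipe(mpg ~ wt + disp + gear, data = mtcars)
rec <- step_mutate(rec, inv_mpg = 1 / mpg, skip = TRUE)

prepped <- prep(rec, training = mtcars)

# Training data: steps were applied during prep(), so inv_mpg is present
train_baked <- bake(prepped, new_data = NULL)

# New data without the outcome: bake() does not run the skipped step
new_rows <- mtcars[1:3, c("wt", "disp", "gear")]
new_baked <- bake(prepped, new_data = new_rows)

"inv_mpg" %in% names(train_baked)
"inv_mpg" %in% names(new_baked)
```

This is only safe because the weights are needed solely for fitting; a skipped step that creates a predictor would leave new data without that column.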
Here's a smaller reprex:
library(tidymodels)
tidymodels_prefer()
theme_set(theme_bw())
options(pillar.advice = FALSE, pillar.min_title_chars = Inf)
mtcar_wts <-
  mtcars %>%
  mutate(case_wts = hardhat::importance_weights(NA))
car_rec <-
  recipe(mpg ~ wt + disp + gear, data = mtcar_wts) %>%
  step_mutate(case_wts = hardhat::importance_weights(1 / mpg), role = "case_weights")
lm_fit <-
  car_rec %>%
  workflow(linear_reg()) %>%
  add_case_weights(case_wts) %>%
  fit(mtcar_wts)
lm_fit %>%
  extract_fit_engine() %>%
  coef()
#> (Intercept) wt disp gear
#> 34.32324535 -2.69696818 -0.02006457 -0.34745007
predict(lm_fit, mtcar_wts[1:3,])
#> Error in `dplyr::mutate()`:
#> ℹ In argument: `case_wts = hardhat::importance_weights(1/mpg)`.
#> Caused by error in `FUN()`:
#> ! non-numeric argument to binary operator
#> Backtrace:
#> ▆
#> 1. ├─stats::predict(lm_fit, mtcar_wts[1:3, ])
#> 2. ├─workflows:::predict.workflow(lm_fit, mtcar_wts[1:3, ])
#> 3. │ └─workflows:::forge_predictors(new_data, workflow) at workflows/R/predict.R:63:3
#> 4. │ ├─hardhat::forge(new_data, blueprint = mold$blueprint) at workflows/R/predict.R:70:3
#> 5. │ └─hardhat:::forge.data.frame(new_data, blueprint = mold$blueprint) at hardhat/R/forge.R:68:3
#> 6. │ ├─hardhat::run_forge(blueprint, new_data = new_data, outcomes = outcomes) at hardhat/R/forge.R:81:3
#> 7. │ └─hardhat:::run_forge.default_recipe_blueprint(...) at hardhat/R/forge.R:135:3
#> 8. │ └─hardhat:::forge_recipe_default_process(...) at hardhat/R/blueprint-recipe-default.R:350:3
#> 9. │ ├─recipes::bake(object = rec, new_data = new_data) at hardhat/R/blueprint-recipe-default.R:435:3
#> 10. │ └─recipes:::bake.recipe(object = rec, new_data = new_data)
#> 11. │ ├─recipes::bake(step, new_data = new_data)
#> 12. │ └─recipes:::bake.step_mutate(step, new_data = new_data)
#> 13. │ ├─dplyr::mutate(new_data, !!!object$inputs)
#> 14. │ └─dplyr:::mutate.data.frame(new_data, !!!object$inputs)
#> 15. │ └─dplyr:::mutate_cols(.data, dplyr_quosures(...), by)
#> 16. │ ├─base::withCallingHandlers(...)
#> 17. │ └─dplyr:::mutate_col(dots[[i]], data, mask, new_columns)
#> 18. │ └─mask$eval_all_mutate(quo)
#> 19. │ └─dplyr (local) eval()
#> 20. ├─hardhat::importance_weights(1/mpg)
#> 21. │ └─hardhat:::vec_cast_named(x, to = double(), x_arg = "x") at hardhat/R/case-weights.R:31:3
#> 22. │ └─vctrs::vec_cast(x, to, ..., call = call) at hardhat/R/util.R:245:3
#> 23. ├─base::Ops.data.frame(1, mpg) at hardhat/R/case-weights.R:31:3
#> 24. │ └─base::eval(f)
#> 25. │ └─base::eval(f)
#> 26. └─base::.handleSimpleError(...)
#> 27. └─dplyr (local) h(simpleError(msg, call))
#> 28. └─rlang::abort(message, class = error_class, parent = parent, call = error_call)
car_skip_rec <-
  recipe(mpg ~ wt + disp + gear, data = mtcar_wts) %>%
  step_mutate(case_wts = hardhat::importance_weights(1 / mpg),
              role = "case_weights", skip = TRUE)
lm_skip_fit <-
  car_skip_rec %>%
  workflow(linear_reg()) %>%
  add_case_weights(case_wts) %>%
  fit(mtcar_wts)
lm_skip_fit %>%
  extract_fit_engine() %>%
  coef()
#> (Intercept) wt disp gear
#> 34.32324535 -2.69696818 -0.02006457 -0.34745007
predict(lm_skip_fit, mtcar_wts[1:3,])
#> # A tibble: 3 × 1
#> .pred
#> <dbl>
#> 1 22.7
#> 2 22.0
#> 3 24.5
Created on 2024-01-31 with reprex v2.0.2
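As for why the first error reads "non-numeric argument to binary operator" rather than "object 'mpg' not found": once the outcome column has been dropped, `mpg` is looked up outside the data, and with tidymodels attached it appears to resolve to the `ggplot2::mpg` data set (the backtrace shows `Ops.data.frame(1, mpg)`). The sketch below reproduces that lookup with base R only; the `mpg` data frame here is a hypothetical stand-in for `ggplot2::mpg`, and `with()` is used in place of the `dplyr::mutate()` data mask.

```r
# Predictor-only data: the outcome column has already been removed
new_data <- data.frame(wt = c(2.620, 2.875), disp = c(160, 160))

# Hypothetical stand-in for ggplot2::mpg, which has character columns
mpg <- data.frame(displ = c(1.8, 1.8), class = c("compact", "compact"))

# `mpg` is not a column of new_data, so with() falls back to the calling
# environment, finds the data frame, and 1 / <data frame> dispatches to
# Ops.data.frame, which fails on the character column
msg <- tryCatch(with(new_data, 1 / mpg), error = conditionMessage)
msg
```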
I've added importance weights to a logistic regression using `hardhat::importance_weights()`. However, when I try to generate partial dependence plots for the regression using `DALEXtra::model_profile()`, it returns an error. Typically when we generate predictions from a fitted model we don't need to use weights, so I'm not sure why this is returning an error, unless it's simply that DALEXtra doesn't know how to deal with a column formatted as `<importance_weights>`?
Here's a reprex with a dummy dataset. I'm trying to extract partial dependence profiles for each fold of the dataset to visually validate the model fit, as discussed in tidymodels/planning/issues/26. The issue may be complicated by the fact that I'm calculating the weights on the fly, depending on the number of points assigned to each fold, as discussed in #240.
Created on 2023-05-30 with reprex v2.0.2