importance weights not compatible with DALEXtra::model_profile #242

I've added importance weights to a logistic regression using hardhat::importance_weights. However, when I try to generate partial dependence plots for the regression using DALEXtra::model_profile it returns an error.

Typically when we generate predictions from a fitted model we don't need to use weights, so I'm not sure why this is returning an error. Unless it's simply that DALEXtra doesn't know how to deal with a column formatted as <importance_weights>?
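
As a minimal illustration of that point, here is a sketch using a plain `stats::glm()` fit rather than the workflow from this issue. The toy data and the names `toy`, `w`, and `new_points` are made up for the example. The weights are used only when fitting, and `predict()` is called on new data that has no weight column at all:

```r
# Toy data; variable names and weight values are hypothetical.
set.seed(1)
toy <- data.frame(x = rnorm(50))
toy$y <- rbinom(50, size = 1, prob = plogis(toy$x))
toy$w <- ifelse(toy$y == 1, 1, 0.5)

# The weights enter only at fitting time.
wt_fit <- glm(y ~ x, family = binomial(), data = toy, weights = w)

# New data contains the predictor only: no outcome, no weights.
new_points <- data.frame(x = c(-1, 0, 1))
preds <- predict(wt_fit, newdata = new_points, type = "response")
preds
```

So the base engine itself does not need weights to predict; any failure would have to come from a layer above it.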

Here's a reprex with a dummy dataset. I'm trying to extract partial dependence profiles for each fold of the dataset to visually validate the model fit as discussed here #tidymodels/planning/issues/26. The issue may be complicated as I'm calculating the weights on the fly, depending on the number of points assigned to each fold as discussed here #240.


# packages
library(sf)
library(dplyr)
library(tidymodels)
library(spatialsample)
library(DALEXtra)

## Data prep:
# pak::pkg_install("Nowosad/spDataLarge")
data("lsl", "study_mask", package = "spDataLarge")
ta <- terra::rast(system.file("raster/ta.tif", package = "spDataLarge"))
lsl <- lsl |> 
  st_as_sf(coords = c("x", "y"), crs = "EPSG:32717")

# convert to 0, 1 as is typical in species distribution modelling
lsl <- lsl |> 
  mutate(lslpts = factor(as.numeric(lslpts)-1)) |>
  # Creating a dummy case weights column, to get past initial verification by recipe
  mutate(cwts = hardhat::importance_weights(NA))

# set up case weights as a recipe step
lsl_recipe <- recipes::recipe(
  lslpts ~ slope + cplan + cprof + elev + log10_carea, 
  data = sf::st_drop_geometry(lsl)
) |> 
  recipes::step_mutate(
    cwts = hardhat::importance_weights(
      ifelse(lslpts == 1, 1, sum(lslpts == 1) / sum(lslpts == 0))
    ),
    # Need to set the "case_weights" role explicitly:
    role = "case_weights"
  )

# split into folds
lsl_folds <- spatial_block_cv(lsl, method = "random", v = 10)

# try GLM
glm_model <- logistic_reg() |> 
  set_engine("glm")

# Using weights instead: no add_formula, because the formula is in our recipe
glm_wflow_wts <- workflow(preprocessor = lsl_recipe) |> 
  add_model(glm_model) |> 
  add_case_weights(cwts)

# fit model to one fold of the data
glm_fold_fit <- glm_wflow_wts |> fit(lsl_folds$splits[[1]] |> analysis())

# generate partial dependence profile for model
# ideally want to generate profile for each fold to verify model fit
glm_explainer <- explain_tidymodels(glm_fold_fit,
                                    data = lsl_folds$splits[[1]] |> 
                                      analysis() |> 
                                      st_drop_geometry() |> 
                                      dplyr::select(slope, cplan, cprof, elev, log10_carea),
                                    y = lsl_folds$splits[[1]] |> 
                                      analysis() |> 
                                      st_drop_geometry() |> 
                                      pull(lslpts)) |>
  model_profile(N = 100, type = "partial")
Created on 2023-05-30 with reprex v2.0.2
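
The on-the-fly weight expression from the recipe above can be checked in isolation. This is a base R sketch on a hypothetical toy outcome (the vector `y` is made up, and the `hardhat::importance_weights()` wrapper is left out so that only the arithmetic is shown). Presences keep a weight of 1 and absences are down-weighted by the ratio of presences to absences, so both classes carry the same total weight:

```r
# Hypothetical toy outcome: 2 presences (1) and 8 absences (0).
y <- factor(c(1, 1, 0, 0, 0, 0, 0, 0, 0, 0))

# Same expression as in the recipe step, applied to the toy outcome.
w <- ifelse(y == 1, 1, sum(y == 1) / sum(y == 0))

# Total weight per class: 2 for the absences (8 * 0.25) and 2 for the presences.
class_totals <- tapply(w, y, sum)
class_totals
```

Because the ratio depends on how many points of each class land in a fold, the weights differ from fold to fold, which is why they are computed inside the recipe rather than once up front.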

topepo commented 7 months ago

Your formula lslpts ~ slope + cplan + cprof + elev + log10_carea has lslpts as the outcome and step_mutate() is using it to construct the case weights.

tidymodels enforces the constraint that the outcome should not be used (in any way) when making predictions, even if that column is available at prediction time. This is to eliminate information leakage. It specifically excludes the outcome column(s) during prediction.

I would try using skip = TRUE so that the step does not execute outside of processing the training set.
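
To show what `skip = TRUE` does at the recipe level, here is a minimal sketch that uses recipes only, without a workflow. The derived column `inv_mpg` and the data frame `new_cars` are hypothetical names for this example. The skipped step runs when the recipe is prepped on the training set, but is not executed when baking new data, so the new data does not need to contain the outcome:

```r
library(recipes)

# Toy recipe: a step that derives a column from the outcome `mpg`.
rec <- recipe(mpg ~ wt + disp, data = mtcars) |>
  step_mutate(inv_mpg = 1 / mpg, skip = TRUE)

prepped <- prep(rec, training = mtcars)

# Training set: the step was applied during prep().
train_cols <- names(bake(prepped, new_data = NULL))

# New data without the outcome: the skipped step is not run.
new_cars <- mtcars[1:3, c("wt", "disp")]
new_cols <- names(bake(prepped, new_data = new_cars))

"inv_mpg" %in% train_cols
"inv_mpg" %in% new_cols
```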

topepo commented 7 months ago

Here's a smaller reprex:


library(tidymodels)

options(pillar.advice = FALSE, pillar.min_title_chars = Inf)

mtcar_wts <- 
  mtcars %>% 
  mutate(case_wts = hardhat::importance_weights(NA))

car_rec <- 
  recipe(mpg ~ wt + disp + gear, data = mtcar_wts) %>% 
  step_mutate(case_wts = hardhat::importance_weights(1 / mpg), role = "case_weights")

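lm_fit <- 
  car_rec %>% 
  workflow(linear_reg()) %>% 
  add_case_weights(case_wts) %>% 
  fit(mtcar_wts)

lm_fit %>% 
  extract_fit_engine() %>% 
  coef()
#> (Intercept)          wt        disp        gear 
#> 34.32324535 -2.69696818 -0.02006457 -0.34745007

# Expected to fail: the outcome `mpg` is excluded at prediction time,
# so step_mutate() cannot recompute the case weights.
predict(lm_fit, mtcar_wts[1:3,])

# Same recipe, but the weighting step is skipped outside of training: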
car_skip_rec <- 
  recipe(mpg ~ wt + disp + gear, data = mtcar_wts) %>% 
  step_mutate(case_wts = hardhat::importance_weights(1 / mpg), 
              role = "case_weights", skip = TRUE)

lm_skip_fit <- 
  car_skip_rec %>% 
  workflow(linear_reg()) %>% 
  add_case_weights(case_wts) %>% 
  fit(mtcar_wts)

lm_skip_fit %>% 
  extract_fit_engine() %>% 
  coef()
#> (Intercept)          wt        disp        gear 
#> 34.32324535 -2.69696818 -0.02006457 -0.34745007

predict(lm_skip_fit, mtcar_wts[1:3,])
#> # A tibble: 3 × 1
#>   .pred
#>   <dbl>
#> 1  22.7
#> 2  22.0
#> 3  24.5

Created on 2024-01-31 with reprex v2.0.2