tidymodels / stacks

An R package for tidy stacked ensemble modeling
https://stacks.tidymodels.org

Implement importance weighting in stacked ensembles #233

Open mark-burdon opened 1 day ago

mark-burdon commented 1 day ago

Feature

Similar to https://github.com/tidymodels/probably/issues/159.

When building a stacked ensemble, the base models may have been trained with importance weights, but these weights don't seem to be passed on to the stacking procedure. In my use case, I use importance weights to give extra weight to the minority class in a highly imbalanced dataset.
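To see why ignoring the weights pulls predictions toward the majority class, consider the simplest possible meta-learner: an intercept-only logistic regression. Fitted without weights, it predicts the raw minority share; fitted with importance weights, it predicts the weighted share sum(w * y) / sum(w). The sketch below uses hypothetical toy labels (900 majority rows, 100 minority rows) with the same 0.15/0.85 weights as the reprex. Those weights give both classes equal total weight only when the minority class makes up 15% of the rows.

```r
# Toy labels (hypothetical): 900 majority rows (0) and 100 minority rows (1)
y <- c(rep(0, 900), rep(1, 100))
# Importance weights mirroring the reprex: 0.15 majority, 0.85 minority
w <- ifelse(y == 1, 0.85, 0.15)

# Intercept-only logistic "meta-learners"; quasibinomial() accepts
# non-integer weights without the warning binomial() would give
fit_unweighted <- glm(y ~ 1, family = quasibinomial())
fit_weighted   <- glm(y ~ 1, family = quasibinomial(), weights = w)

p_unweighted <- unname(plogis(coef(fit_unweighted)))  # approximately 0.1
p_weighted   <- unname(plogis(coef(fit_weighted)))    # approximately 85 / 220 = 0.386
```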

Here's a reprex. It's not a perfect test (it compares the two "best" base learners with a stacked ensemble of several candidates), but it demonstrates that the weights aren't being passed on, and that the stacked ensemble is therefore much less likely than the base learners to predict the (highly weighted) minority class:

library(tidyverse)
library(tidymodels)
library(stacks)  # stacks isn't attached by library(tidymodels)
set.seed(100)

# Create noisy, imbalanced data, add weight column giving roughly equal overall weight
df <- caret::twoClassSim(n = 1000, intercept = -12, linearVars = 2, noiseVars = 5,
                         corrVars = 3, corrType = "AR1") |>
  dplyr::mutate(weights = dplyr::if_else(Class == "Class1",
                                         true = 0.15,
                                         false = 0.85)) |>
  dplyr::mutate(weights = hardhat::importance_weights(weights),
                # Keep the outcome a factor, as classification models expect
                Class = factor(case_match(Class, "Class1" ~ "Majority",
                                                 "Class2" ~ "Minority")))

# Create recipe and logistic regression specification
glm_recipe <- recipes::recipe(x = df, formula = Class ~ .)

glm_spec <- parsnip::logistic_reg(mode = "classification",
                                  engine = "glm")

# Combine into workflow
glm_wf <- workflows::workflow(preprocessor = glm_recipe,
                              spec = glm_spec) |>
  workflows::add_case_weights(col = weights)

# Now an xgboost model
# Don't include all variables, so it's not amazingly accurate
xgb_recipe <- recipes::recipe(x = df,
                              formula = Class ~ Linear1 + Linear2 + Nonlinear1 + Noise1 + Noise2 + weights)

xgb_spec <- parsnip::boost_tree(mode = "classification",
                                engine = "xgboost",
                                trees = 100,
                                tree_depth = tune(),
                                learn_rate = 0.1)

xgb_wf <- workflows::workflow(preprocessor = xgb_recipe,
                              spec = xgb_spec) |>
  workflows::add_case_weights(col = weights)

# Create resamples for model fitting
resamples <- rsample::vfold_cv(data = df,
                               v = 5,
                               strata = Class)

# Cross-validate the models
glm_cv <- fit_resamples(glm_wf, resamples = resamples, control = control_stack_resamples())
xgb_cv <- tune_grid(xgb_wf, resamples = resamples, control = control_stack_resamples())

# Stack the ensemble
model_stack <- stacks() |>
  add_candidates(candidates = glm_cv) |>
  add_candidates(candidates = xgb_cv) |>
  blend_predictions() |>
  fit_members()

# Add stacked predictions to dataframe
df <- df |>
  bind_cols(predict(model_stack, df, type = "prob") |>
              rename(.pred_Majority_stack = .pred_Majority,
                     .pred_Minority_stack = .pred_Minority))

# Fit best individual models
best_glm <- fit_best(glm_cv)
best_xgb <- fit_best(xgb_cv)

# Append individual model predictions to the data
df <- df |>
  bind_cols(predict(best_glm, df, type = "prob") |>
              rename(.pred_Majority_glm = .pred_Majority,
                     .pred_Minority_glm = .pred_Minority)) |>
  bind_cols(predict(best_xgb, df, type = "prob") |>
              rename(.pred_Majority_xgb = .pred_Majority,
                     .pred_Minority_xgb = .pred_Minority))

# Bring together and visualise the predictions
df |>
  select(Class, starts_with(".pred_Min")) |>
  pivot_longer(starts_with(".pred_Min"),
               names_to = "model",
               values_to = ".pred_min",
               names_prefix = ".pred_Minority_") |>
  ggplot(aes(x = .pred_min, colour = model, fill = model)) +
  geom_density(alpha = 0.2) +
  theme_bw() +
  labs(x = "Predicted probability of being in the minority class") +
  facet_wrap(~Class, nrow = 2)

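For reference, here is a minimal base-R sketch of what passing importance weights to the blending step could look like. It is a stand-in under stated assumptions, not the stacks implementation: blend_predictions() fits a regularized model via glmnet (non-negative coefficients by default), whereas this sketch swaps in an unpenalized glm() meta-learner with quasibinomial() to allow non-integer weights. The member predictions (p_glm, p_xgb) are simulated and hypothetical, not taken from the reprex.

```r
# Hypothetical out-of-fold member predictions (simulated, NOT from the reprex)
set.seed(1)
n <- 1000
y <- rbinom(n, size = 1, prob = 0.1)          # 1 = minority class
members <- data.frame(
  y     = y,
  p_glm = plogis(-2 + 2.0 * y + rnorm(n)),    # stand-in for a GLM member
  p_xgb = plogis(-2 + 1.5 * y + rnorm(n)),    # stand-in for an xgboost member
  w     = ifelse(y == 1, 0.85, 0.15)          # same weighting scheme as the reprex
)

# Meta-learner that ignores the weights (the behaviour the issue describes)
meta_unweighted <- glm(y ~ p_glm + p_xgb, data = members,
                       family = quasibinomial())

# Meta-learner that uses the importance weights (the requested behaviour)
meta_weighted <- glm(y ~ p_glm + p_xgb, data = members,
                     family = quasibinomial(), weights = w)

pred_unweighted <- fitted(meta_unweighted)
pred_weighted   <- fitted(meta_weighted)

# Mean predicted minority-class probability among true minority rows
mean(pred_unweighted[members$y == 1])
mean(pred_weighted[members$y == 1])
```

With an intercept and the canonical logit link, the unweighted fit's mean prediction equals the raw minority share, while the weighted fit's weighted mean prediction equals the weighted minority share, so predictions for minority rows move up.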
simonpcouch commented 1 day ago

Ah, heard! This is worth looking into.