tidymodels / probably

Tools for post-processing class probability estimates
https://probably.tidymodels.org/

Generalize `int_conformal_quantile()` #131

Open brshallo opened 7 months ago

brshallo commented 7 months ago

Currently int_conformal_quantile() seems limited in that it:

- always trains its own quantile regression forest on train_data (via probably:::quant_train() / quantregForest) rather than using quantile predictions from the supplied model/workflow
- is therefore tied to that single internal quantile model, regardless of what the workflow itself is able to output

Ideally, for any model/workflow fit that is set up to output quantiles (or intervals), int_conformal_quantile() would simply use the calibration data (or, as int_conformal_cv() does, the held-out data available when resamples are set up) to adjust the quantiles output by the fitted workflow.
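For reference, the current interface looks roughly like this (a sketch using objects defined in the example further down; train_data is required because the function fits its own quantile model internally):

```r
library(probably)

# Current behavior (sketch): int_conformal_quantile() requires train_data
# because it fits its own quantile regression forest internally, regardless
# of any quantile capabilities of the model inside rf_wf.
ci_current <- int_conformal_quantile(
  rf_wf,
  train_data = ames_train,
  cal_data = ames_cal,
  level = 0.90
)
predict(ci_current, ames_new)
```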

As described from 31:00 to 37:00 by Angelopoulos and Bates here (https://www.youtube.com/watch?v=nql000Lu_iE&list=PLXs7Va5fWFZ72DTVcx4qIvny1xNrl68PK&index=1), the steps (with parsnip / workflows) would then be:

1. Train an arbitrary model that is capable of optimizing the pinball loss function / outputting quantiles or intervals.
2. Pass the resulting object into a generalized version of int_conformal_quantile(), whose responsibility would be to calibrate the quantiles from the model/workflow. This would be a similar set-up to the current one, just without the probably:::quant_train() step, so more like how the other int_conformal_*() functions work.
3. Use the result to produce well-calibrated intervals on new data.
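A minimal sketch of what such a generalized call might look like (the argument names here are illustrative only, not an existing API):

```r
# Hypothetical generalized interface (illustrative, not an existing API):
# no train_data argument and no internal quant_train() step; the fitted
# workflow itself is assumed to supply the quantile predictions, which this
# function would then conformalize against cal_data.
ci_general <- int_conformal_quantile(rf_wf, cal_data = ames_cal, level = 0.90)
predict(ci_general, ames_new)
```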

I imagine this would depend on integrated support for quantiles in parsnip (https://github.com/tidymodels/parsnip/issues/119, https://github.com/tidymodels/parsnip/issues/465), but figured I may as well open an issue in the meantime.

Rough example with a ranger workflow:

### Set up a ranger workflow (with a recipe) for quantile regression forests

```r
library(tidyverse)
library(tidymodels)
library(AmesHousing)

ames <- make_ames() %>% 
  mutate(Years_Old = Year_Sold - Year_Built,
         Years_Old = ifelse(Years_Old < 0, 0, Years_Old))

set.seed(4595)
data_split <- initial_split(ames, strata = "Sale_Price", prop = 0.6)

ames_train <- training(data_split)
ames_holdout  <- testing(data_split) 

rf_recipe <- 
  recipe(
    Sale_Price ~ Lot_Area + Neighborhood  + Years_Old + Gr_Liv_Area + Overall_Qual + Total_Bsmt_SF + Garage_Area, 
    data = ames_train
  ) %>%
  step_log(Sale_Price, base = 10) %>%
  step_other(Neighborhood, Overall_Qual, threshold = 50) %>% 
  step_novel(Neighborhood, Overall_Qual) %>% 
  step_dummy(Neighborhood, Overall_Qual) 

rf_mod <- rand_forest() %>%
  set_engine("ranger", importance = "impurity", seed = 63233, quantreg = TRUE) %>%
  set_mode("regression")

set.seed(63233)
rf_wf <- workflows::workflow() %>% 
  add_model(rf_mod) %>% 
  add_recipe(rf_recipe) %>% 
  fit(ames_train)
```
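As a quick sanity check (a sketch, not part of the proposed API), the ranger fit inside the workflow can already emit quantile predictions directly, which is exactly the capability a generalized int_conformal_quantile() would lean on:

```r
# extract the underlying ranger fit and the prepped recipe from the workflow
rf_engine <- workflows::extract_fit_engine(rf_wf)
rf_rec <- workflows::extract_recipe(rf_wf)

# quantile predictions for a few held-out rows (5th and 95th percentiles)
predict(
  rf_engine,
  data = bake(rf_rec, new_data = head(ames_holdout)),
  type = "quantiles",
  quantiles = c(0.05, 0.95)
)$predictions
```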

### Conformalize quantiles from the model/workflow

```r
set.seed(1243)
val_new_split <- initial_split(ames_holdout, prop = 0.50)
ames_cal <- training(val_new_split)
ames_new <- testing(val_new_split)

# helper to get lower/upper quantile predictions from a ranger fit trained
# with quantreg = TRUE; ranger returns a ranger.prediction object whose
# quantile matrix lives in $predictions
quant_predict <- function(fit, new_data, level) {
  alpha <- 1 - level
  quant_pred <- predict(
    fit, data = new_data, type = "quantiles",
    quantiles = c(alpha / 2, 1 - alpha / 2)
  )$predictions
  quant_pred <- dplyr::as_tibble(quant_pred)
  stats::setNames(quant_pred, c(".pred_lower", ".pred_upper"))
}

level <- 0.90

# bake the calibration and new data so the predictors (and the log10
# Sale_Price) match what the ranger engine saw at training time
cal_data_baked <- workflows::extract_recipe(rf_wf) %>% bake(new_data = ames_cal)
new_data_baked <- workflows::extract_recipe(rf_wf) %>% bake(new_data = ames_new)

preds_q <- bind_cols(
  select(cal_data_baked, Sale_Price),
  quant_predict(extract_fit_engine(rf_wf), cal_data_baked, level = level)
)

# CQR conformity scores: positive when the observed value falls outside the
# predicted interval, negative when it falls inside
scores <- preds_q %>% 
  mutate(R_low = .pred_lower - Sale_Price,
         R_high = Sale_Price - .pred_upper) %>% 
  with(pmax(R_low, R_high))

# empirical level-quantile of the scores (a finite-sample correction would use
# probs = ceiling((nrow(preds_q) + 1) * level) / nrow(preds_q))
q_hat <- quantile(scores, probs = level)

# widen (or narrow) the raw quantile predictions on new data by q_hat
preds_q_new <- bind_cols(
  select(new_data_baked, Sale_Price),
  quant_predict(extract_fit_engine(rf_wf), new_data_baked, level = level)
) %>% 
  mutate(.pred_lower = .pred_lower - q_hat,
         .pred_upper = .pred_upper + q_hat)

# empirical coverage on new data; should be close to the nominal 90%
preds_q_new %>% 
  summarise(coverage = mean(Sale_Price <= .pred_upper & Sale_Price >= .pred_lower))
```

The adaptivity of the intervals would then come from the model in the workflow being able to output quantiles / intervals (rather than from overriding the workflow and retraining a separate model for the interval). Even if the underlying workflow isn't particularly adaptive (say the user has an lm workflow that just returns standard variance-based prediction intervals), the approach described above would likely do a slightly better job of accounting for the epistemic uncertainty in the model estimation than int_conformal_split(), because it allows intervals to widen further from the data centroid, which doesn't happen with int_conformal_split().
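A quick way to see that adaptivity (a sketch, reusing the objects above): the CQR interval widths vary across observations, whereas a split-conformal interval adds the same constant to every point prediction:

```r
# interval widths vary by observation because the underlying model's
# quantiles are adaptive; split-conformal widths would all be identical
preds_q_new %>% 
  mutate(width = .pred_upper - .pred_lower) %>% 
  summarise(min_width = min(width),
            mean_width = mean(width),
            max_width = max(width))
```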