tidymodels / workflowsets

Create a collection of modeling workflows
https://workflowsets.tidymodels.org/

Way to only remove rows with missing data if that column is needed in each model? #71

Closed: gdmcdonald closed this issue 2 years ago

gdmcdonald commented 2 years ago

I'm currently specifying around 1,000 different model formulas on one input data frame in a workflow set. For each model, I want to exclude a data row if it has missing data in any of the columns that this particular model requires. That way the model can still run, and I'm not imputing the missing values. Is there a nice way to do this with a workflow set? Something like an `na.action = na.omit` option?
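For comparison, base R's modeling functions already behave this way: `model.frame()`, which `lm()` and `glm()` use to build their data, drops a row under `na.action = na.omit` only when one of the variables named in that formula is missing. A minimal illustration on hypothetical toy data (not the BreastCancer set):

```r
# Hypothetical toy data: x1 is missing in row 2, x2 is missing in row 1
df <- data.frame(
  y  = c(1, 2, 3, 4, 5),
  x1 = c(1, NA, 3, 4, 5),
  x2 = c(NA, 2, 3, 4, 5)
)

# Only the variables named in each formula are checked for NA
n_x1   <- nrow(model.frame(y ~ x1, data = df, na.action = na.omit))       # drops row 2
n_x2   <- nrow(model.frame(y ~ x2, data = df, na.action = na.omit))       # drops row 1
n_both <- nrow(model.frame(y ~ x1 + x2, data = df, na.action = na.omit))  # drops rows 1 and 2

c(n_x1, n_x2, n_both)
```

So `lm(y ~ x1, data = df, na.action = na.omit)` would be fit on 4 rows, while a model using both predictors would be fit on 3.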

library(tidyverse)
library(tidymodels)
library(workflowsets)
library(mlbench)
library(DescTools)
library(doParallel)

# Convenience function: replace a random fraction p of the values in x with NA
miss_vals <- function(x, p = .05) {
  x[sample(seq_along(x), floor(p * length(x)))] <- NA
  x
}

#dataset is mlbench::BreastCancer with some missing values in Bare.nuclei and Epith.c.size columns
data(BreastCancer)
dataset <- BreastCancer %>%
  dplyr::select(-Id) %>%
  mutate(across(-Class,as.numeric)) %>%
  mutate(across(c(Bare.nuclei, Epith.c.size),miss_vals)) %>% #add missing vals
  na.omit() 

I want to run this code without the `na.omit()` on the last line, so that a row is omitted only when its missing value would actually be used by the model.

#specify rf model
rf_spec <-
  rand_forest(mtry = tune(), min_n = tune(), trees = 1000) %>%
  set_engine("ranger") %>%
  set_mode("classification")

#test train split
set.seed(1)
trn_tst_split <- initial_split(dataset, strata = Class)
folds <- vfold_cv(training(trn_tst_split), v=5, strata = Class) #5 folds just to run quickly

# function to generate all possible combinations of formulas with n variables added to a base formula
add_n_var_formulas <- function(y_var, x_vars, data, n = 2, include_base = TRUE, include_full = TRUE){

  potential_x_vars <-
    names(data) %>% #variable names
    {.[!. %in% c(y_var, x_vars)]} #remove y and existing x vars

  combos <-
    potential_x_vars %>%
    DescTools::CombSet(n, repl=FALSE, ord=FALSE) %>%
    as_tibble() %>%
    unite(col = combination, sep = " + ") %>%
    pull(combination, name = combination)

  if (include_base) {combos[length(combos)+1]<-"1"}
  if (include_full) {combos[length(combos)+1]<-"."}

  base_pred <- paste(x_vars, collapse = " + ")

  new_formulas <-
    combos %>%
    {paste(y_var, " ~ ",paste(base_pred,., sep = " + "))} %>%  #make formula strings
    purrr::map(as.formula) %>% #make formula
    purrr::map(workflowsets:::rm_formula_env)

  names(new_formulas) <- combos

  if(include_base){ names(new_formulas)[names(new_formulas) == "1"] <- "Base Model" }
  if(include_full){names(new_formulas)[names(new_formulas) == "."] <- "All Predictors"}

  new_formulas
}

#make 10 possible formulas to predict `Class` from `Normal.nucleoli` + (something else)
formulas <- add_n_var_formulas(
  y_var = "Class",
  x_vars = c("Normal.nucleoli"),
  data = dataset,
  n=1)

The formulas I am testing in this example are:

[Screenshot: list of the formulas being tested]

(My real data set has thousands of formulas.)

# Create workflow set
cancer_workflows <-
  workflow_set(
    preproc = formulas,
    models = list(rf = rf_spec)
  )

# Control options for hyperparameter tuning (the grid itself is set via `grid = 25` below)
grid_ctrl <-
  control_grid(
    save_pred = TRUE,
    parallel_over = "everything",
    save_workflow = TRUE
  )

#Set up for parallel processing
all_cores <- parallel::detectCores(logical = TRUE)
cl <- makePSOCKcluster(all_cores)
registerDoParallel(cl)

# Fit models
cancer_workflows <-
  cancer_workflows %>%
  workflow_map("tune_grid",
               seed = 1503,
               grid = 25,
               control = grid_ctrl,
               resamples = folds,
               verbose = TRUE)

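# Shut down the PSOCK cluster created above (stopImplicitCluster() only stops
# clusters that registerDoParallel() created implicitly)
stopCluster(cl)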

# Plot results
cancer_workflows %>%
  collect_metrics(summarize = FALSE) %>%
  filter(.metric == "roc_auc") %>%
  group_by(wflow_id, model) %>%
  dplyr::summarize(
    ROC = mean(.estimate),
    lower = quantile(.estimate,probs = 0.2),
    upper = quantile(.estimate,probs = 0.8),
    .groups = "drop"
  ) %>%
  mutate(wflow_id = factor(wflow_id),
         wflow_id = reorder(wflow_id, ROC)) %>%
  ggplot(aes(x = ROC, y = wflow_id)) +
  geom_point() +
  geom_errorbar(aes(xmin = lower, xmax = upper), width = .25) +
  labs(title = "Comparing models of the form `Class ~ Normal.nucleoli + ...`",y = "Additional variable")

This should give a nice comparison of how much each additional variable adds to the ROC AUC:

[Plot: ROC AUC for each model of the form `Class ~ Normal.nucleoli + ...`]

mattwarkentin commented 2 years ago

Hi @gdmcdonald,

As I mentioned over in #70, I still think {recipes} is the way to go. In this case, a very simple recipe that only removes missing data should do the trick. The example above stays exactly the same up to where you define `formulas` (except that you can remove the `na.omit()` from the definition of `dataset`), and then I added this:

# Create a recipe from each formula
recipes <-
  map(formulas, function(form) {
    recipe(form, data = training(trn_tst_split)) %>% 
      step_naomit(all_predictors())
  })

# Create workflow set
cancer_workflows <-
  workflow_set(
    preproc = recipes,
    models = list(rf = rf_spec)
  )

Each time the recipe is trained during resampling, it removes rows with missing data from the `analysis()` portion of the CV fold, depending on which predictors are in that model's formula and on the pattern of missing data.

gdmcdonald commented 2 years ago

Thank you @mattwarkentin, that's a great direction to go in. Unfortunately, all the models with missing data still fail (with or without `skip = TRUE` in `step_naomit()`), because the rows with missing values are still in the CV folds and in the test set, which seems to break it. Do I need to somehow map different CV folds to each model as well?

gdmcdonald commented 2 years ago

After some more searching, I think these issues are the culprit: https://github.com/tidymodels/tune/issues/181 and https://github.com/imbs-hl/ranger/issues/94

I'd love to know how I could write a wrapper around ranger so that it handles NA values in the normal R way and fixes the problem, but in the meantime any tips on working around the issue would be appreciated.

juliasilge commented 2 years ago

If you use `step_naomit(all_predictors(), skip = TRUE)`, that will remove NA values from the observations you are using for training. I believe the problem at this point is that you end up with missing data in the observations you want to predict on. I get errors like this (note the "(predictions)" part):

preprocessor 1/1, model 1/24 (predictions): Error: Missing data in columns: Bare.nuclei.
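A minimal sketch of why this happens, using hypothetical toy data rather than the BreastCancer set: `step_naomit()` only looks at the variables that are in the recipe, and because its `skip` argument defaults to `TRUE`, it is applied to the training data during `prep()` but not when `bake()` is later called on new data such as an assessment set. Rows with missing predictors therefore still reach the prediction step.

```r
library(tidymodels)

# Hypothetical toy data: x1 is missing in row 2, x2 is missing in row 1
df <- data.frame(
  y  = factor(c("a", "b", "a", "b", "a")),
  x1 = c(1, NA, 3, 4, 5),
  x2 = c(NA, 2, 3, 4, 5)
)

# Only y and x1 enter this recipe, so only x1's NA can trigger row removal
rec <- recipe(y ~ x1, data = df) %>%
  step_naomit(all_predictors())  # skip = TRUE is the default

prepped <- prep(rec, training = df)

# Training data after prep(): row 2 (NA in x1) is dropped, row 1 is kept
# because x2 is not part of this model
n_train <- nrow(bake(prepped, new_data = NULL))

# New data (e.g. an assessment set): the skipped step is not applied,
# so the row with NA in x1 is still present when predictions are made
n_new <- nrow(bake(prepped, new_data = df))

c(n_train, n_new)
```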

You've probably noticed the info/advice we give on recipe steps that involve removing rows, including step_naomit():

This step can entirely remove observations (rows of data), which can have unintended and/or problematic consequences when applying the step to new data later via bake(). Consider whether skip = TRUE or skip = FALSE is more appropriate in any given use case. In most instances that affect the rows of the data being predicted, this step probably should not be applied at all; instead, execute operations like this outside and before starting a preprocessing recipe().

This behavior is definitely by design and is a safety consideration. It is a pretty strong assumption of the workflowsets package that you are evaluating the different model configurations on the same data; I don't think we're going to want to support a general option for tuning where it is easy to end up with different datasets for different model configurations.

That being said, you are the boss of your own model evaluation, of course! If I really wanted to do this myself, I would probably create a function that took the training set and a set of predictors as arguments and returned a set of metrics. I would remove NA values in this function before I created the resampling folds. I would then loop/purrr/apply through the sets of predictors to evaluate. For now, this is what we recommend you do in this situation.
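A minimal sketch of that approach, under several assumptions: `evaluate_predictors()` is a hypothetical helper (not part of workflowsets), the data are simulated, and a plain `logistic_reg()` is swapped in for the tuned random forest so the example runs without a tuning grid (to keep tuning, `fit_resamples()` could be replaced by `tune_grid()`). Missing values are removed for each predictor set before the folds are created, and the results are combined with purrr.

```r
library(tidymodels)

# Hypothetical helper: resample one predictor set on the rows that are
# complete for that set's columns, and return its metrics
evaluate_predictors <- function(train, outcome, predictors, model_spec, v = 5) {
  # Keep only the outcome and chosen predictors, then drop incomplete rows
  model_data <- train[, c(outcome, predictors), drop = FALSE]
  model_data <- model_data[complete.cases(model_data), , drop = FALSE]

  # Create the folds after removing NAs, so every fold is complete
  folds <- vfold_cv(model_data, v = v, strata = !!rlang::sym(outcome))

  wf <- workflow() %>%
    add_formula(reformulate(predictors, response = outcome)) %>%
    add_model(model_spec)

  res <- fit_resamples(wf, resamples = folds, metrics = metric_set(roc_auc))

  label <- paste(predictors, collapse = " + ")
  collect_metrics(res) %>%
    mutate(predictors = label, n_rows = nrow(model_data))
}

# Simulated training data: only x2 has missing values (20 of 200 rows)
set.seed(1)
n_obs <- 200
train <- data.frame(x1 = rnorm(n_obs), x2 = rnorm(n_obs), x3 = rnorm(n_obs))
train$y <- factor(ifelse(train$x1 + rnorm(n_obs) > 0, "yes", "no"),
                  levels = c("yes", "no"))
train$x2[sample(n_obs, 20)] <- NA

# Loop over the predictor sets and stack the results
predictor_sets <- list("x1", c("x1", "x2"), c("x1", "x3"))
results <- purrr::map_dfr(
  predictor_sets,
  ~ evaluate_predictors(train, outcome = "y", predictors = .x,
                        model_spec = logistic_reg())
)

results %>% select(predictors, n_rows, .metric, mean, std_err)
```

Because each predictor set is resampled on its own complete cases, `n_rows` differs between the rows of `results`; that is exactly the situation described above, where the models are no longer compared on identical data.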

juliasilge commented 2 years ago

Let us know if you have further questions! 🙌

github-actions[bot] commented 2 years ago

This issue has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.