tidymodels / hardhat

Construct Modeling Packages
https://hardhat.tidymodels.org

forge fails with non-standard roles #197

Closed topepo closed 2 years ago

topepo commented 2 years ago

Example from this Community thread.

library(tidymodels)
library(survival)

colon$sex <- ifelse(colon$sex==1,"male","female")
colon$obstruct <- ifelse(colon$obstruct ==1,"yes","no")
colon$perfor <- ifelse(colon$perfor ==1,"yes","no")
colon$adhere <- ifelse(colon$adhere ==1,"yes","no")
colon$status <- ifelse(colon$status ==1,"death","alive")
colon$node4 <- ifelse(colon$node4 ==1,"yes","no")

colon <- select(colon,id,age,rx,sex,age,obstruct,perfor,adhere,nodes,status)
colon <- na.omit(colon)

data_split<- initial_split(colon,
                           prop = 3/4,
                           strata = status)

train_data <- training(data_split)
test_data<- testing(data_split)

train_rec <-
  recipe(status ~., data = train_data) %>%
  update_role(id, new_role = "ID")%>%
  step_zv(all_numeric(),-all_outcomes()) %>%
  step_normalize(all_numeric(),-all_outcomes())%>%
  step_novel(all_nominal(),-all_outcomes()) %>%
  step_dummy(all_nominal(),-all_outcomes())

set.seed(100)

cv_folds <-
  vfold_cv(train_data,
           v = 5,
           strata = status)

log_spec <- # your model specification
  logistic_reg() %>% # model type
  set_engine(engine = "glm") %>% # model engine
  set_mode("classification") # model mode

log_wflow <- # new workflow object
  workflow() %>% # use workflow function
  add_recipe(train_rec) %>% # use the new recipe
  add_model(log_spec) # add your model spec

log_res <-
  log_wflow %>%
  fit_resamples(
    resamples = cv_folds,
    metrics = metric_set(
      precision, f_meas,
      accuracy, kap,
      roc_auc, sens, spec),
    control = control_resamples(
      save_pred = TRUE)
  )
#> x Fold1: preprocessor 1/1, model 1/1 (predictions): Error in `vectbl_as_col_locat...
#> x Fold2: preprocessor 1/1, model 1/1 (predictions): Error in `vectbl_as_col_locat...
#> x Fold3: preprocessor 1/1, model 1/1 (predictions): Error in `vectbl_as_col_locat...
#> x Fold4: preprocessor 1/1, model 1/1 (predictions): Error in `vectbl_as_col_locat...
#> x Fold5: preprocessor 1/1, model 1/1 (predictions): Error in `vectbl_as_col_locat...
#> Warning: All models failed. See the `.notes` column.

log_res %>% 
  collect_notes() %>% 
  pluck("note") %>% 
  cat(sep = "\n")
#> Error in `vectbl_as_col_location()`:
#> ! Can't subset columns that don't exist.
#> ✖ Column `id` doesn't exist.
#> Error in `vectbl_as_col_location()`:
#> ! Can't subset columns that don't exist.
#> ✖ Column `id` doesn't exist.
#> Error in `vectbl_as_col_location()`:
#> ! Can't subset columns that don't exist.
#> ✖ Column `id` doesn't exist.
#> Error in `vectbl_as_col_location()`:
#> ! Can't subset columns that don't exist.
#> ✖ Column `id` doesn't exist.
#> Error in `vectbl_as_col_location()`:
#> ! Can't subset columns that don't exist.
#> ✖ Column `id` doesn't exist.

Created on 2022-05-18 by the reprex package (v2.0.1)

The error is in tune:::forge_from_workflow(), where

Browse[2]> str(new_data)
'data.frame':   274 obs. of  9 variables:
 $ id      : num  9 16 20 21 38 46 49 59 67 78 ...
 $ age     : num  46 68 50 64 64 63 59 65 49 63 ...
 $ rx      : Factor w/ 3 levels "Obs","Lev","Lev+5FU": 2 1 2 1 1 1 3 1 3 3 ...
 $ sex     : chr  "male" "male" "female" "male" ...
 $ obstruct: chr  "no" "no" "no" "yes" ...
 $ perfor  : chr  "no" "no" "no" "no" ...
 $ adhere  : chr  "yes" "no" "yes" "no" ...
 $ nodes   : num  2 1 1 1 1 2 6 6 6 2 ...
 $ status  : chr  "alive" "alive" "alive" "alive" ...
 - attr(*, "na.action")= 'omit' Named int [1:36] 187 188 197 198 285 286 377 378 397 398 ...
  ..- attr(*, "names")= chr [1:36] "187" "188" "197" "198" ...

I think that the non-predictor columns should be in new_data (so they show up in the saved predictions) but

forged <- hardhat::forge(new_data, blueprint, outcomes = TRUE)

fails since id is there.

I think that we would want forge() to let those be. The model prediction shouldn't fail if they are there since parsnip removes any columns that are not required.
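For readers unfamiliar with the mold()/forge() pair referenced above, here is a minimal sketch of the call shape. The toy data frame is hypothetical and only stands in for the colon data, and the recipe deliberately does not touch id in any step. How non-standard roles are handled differs between hardhat versions, so this shows only that id can be passed along in new_data and that the forged predictors are limited to predictor-role columns.

```r
library(recipes)
library(hardhat)

# Hypothetical toy data, standing in for the colon data set
toy <- data.frame(
  id     = c(9, 16, 20, 21, 38, 46),
  age    = c(46, 68, 50, 64, 64, 63),
  sex    = c("male", "male", "female", "male", "female", "female"),
  status = c("alive", "death", "alive", "death", "alive", "death")
)

# id gets a non-standard role and is not selected by any step
rec <- recipe(status ~ ., data = toy) %>%
  update_role(id, new_role = "ID") %>%
  step_normalize(all_numeric_predictors())

# mold() preps the recipe and records a blueprint for later use
molded <- mold(rec, toy)

# forge() applies the blueprint to new data; id is present in new_data here
forged <- forge(toy, molded$blueprint)

predictor_names <- names(forged$predictors)
predictor_names
```

The blueprint stored in molded$blueprint is what tune passes to forge() at prediction time.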

EmilHvitfeldt commented 2 years ago

If you are using the dev version, you need to set the bake_dependent_roles argument in the recipes blueprint to make sure that the ID variables are passed through. See https://hardhat.tidymodels.org/dev/reference/default_recipe_blueprint.html#arguments

with

id_blueprint <- hardhat::default_recipe_blueprint(bake_dependent_roles = "ID")

library(tidymodels)
library(survival)

colon$sex <- ifelse(colon$sex==1,"male","female")
colon$obstruct <- ifelse(colon$obstruct ==1,"yes","no")
colon$perfor <- ifelse(colon$perfor ==1,"yes","no")
colon$adhere <- ifelse(colon$adhere ==1,"yes","no")
colon$status <- ifelse(colon$status ==1,"death","alive")
colon$node4 <- ifelse(colon$node4 ==1,"yes","no")

colon <- select(colon,id,age,rx,sex,age,obstruct,perfor,adhere,nodes,status)
colon <- na.omit(colon)

data_split<- initial_split(colon,
                           prop = 3/4,
                           strata = status)

train_data <- training(data_split)
test_data<- testing(data_split)

train_rec <-
  recipe(status ~., data = train_data) %>%
  update_role(id, new_role = "ID")%>%
  step_zv(all_numeric(),-all_outcomes()) %>%
  step_normalize(all_numeric(),-all_outcomes())%>%
  step_novel(all_nominal(),-all_outcomes()) %>%
  step_dummy(all_nominal(),-all_outcomes())

set.seed(100)

cv_folds <-
  vfold_cv(train_data,
           v = 5,
           strata = status)

log_spec <- # your model specification
  logistic_reg() %>% # model type
  set_engine(engine = "glm") %>% # model engine
  set_mode("classification") # model mode

id_blueprint <- hardhat::default_recipe_blueprint(bake_dependent_roles = "ID")

log_wflow <- # new workflow object
  workflow() %>% # use workflow function
  add_recipe(train_rec, blueprint = id_blueprint) %>% # use the new recipe
  add_model(log_spec) # add your model spec

log_res <-
  log_wflow %>%
  fit_resamples(
    resamples = cv_folds,
    metrics = metric_set(
      precision, f_meas,
      accuracy, kap,
      roc_auc, sens, spec),
    control = control_resamples(
      save_pred = TRUE)
  )
#> ! Fold1: preprocessor 1/1, model 1/1 (predictions): prediction from a rank-defici...
#> ! Fold2: preprocessor 1/1, model 1/1 (predictions): prediction from a rank-defici...
#> ! Fold3: preprocessor 1/1, model 1/1 (predictions): prediction from a rank-defici...
#> ! Fold4: preprocessor 1/1, model 1/1 (predictions): prediction from a rank-defici...
#> ! Fold5: preprocessor 1/1, model 1/1 (predictions): prediction from a rank-defici...

log_res %>% 
  collect_notes() %>% 
  pluck("note") %>% 
  cat(sep = "\n")
#> prediction from a rank-deficient fit may be misleading
#> prediction from a rank-deficient fit may be misleading
#> prediction from a rank-deficient fit may be misleading
#> prediction from a rank-deficient fit may be misleading
#> prediction from a rank-deficient fit may be misleading

Created on 2022-05-18 by the reprex package (v2.0.1)

DavisVaughan commented 2 years ago

This is exactly as expected. This is how it should have worked from the beginning and is required for case weights to work. The only role that forge() looks for by default is the "predictor" role.

If you really require the "ID" role in the recipe then you should use bake_dependent_roles as Emil suggested.

But I think that the user doesn't actually want id to be involved in the recipe at all, so they should have used all_numeric_predictors() and all_nominal_predictors(), like this:

train_rec <-
  recipe(status ~., data = train_data) %>%
  update_role(id, new_role = "ID")%>%
  step_zv(all_numeric_predictors()) %>%
  step_normalize(all_numeric_predictors())%>%
  step_novel(all_nominal_predictors()) %>%
  step_dummy(all_nominal_predictors())

Then it works as expected.
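To make the difference between the two selector styles concrete, here is a small sketch on hypothetical toy data (not the colon data). The helper normalized_columns() is made up for this illustration. It preps a recipe and reports which columns step_normalize() picked up, using tidy() on the first step. With all_numeric(), -all_outcomes() the id column should be swept into the step even though its role is "ID". With all_numeric_predictors() it should be left out.

```r
library(recipes)

# Hypothetical toy data, standing in for the colon data set
toy <- data.frame(
  id     = c(9, 16, 20, 21, 38, 46),
  age    = c(46, 68, 50, 64, 64, 63),
  sex    = c("male", "male", "female", "male", "female", "female"),
  status = c("alive", "death", "alive", "death", "alive", "death")
)

# Illustrative helper: columns used by the first step of a prepped recipe
normalized_columns <- function(rec) {
  unique(tidy(prep(rec), number = 1)$terms)
}

# Selector style from the original report
rec_old <- recipe(status ~ ., data = toy) %>%
  update_role(id, new_role = "ID") %>%
  step_normalize(all_numeric(), -all_outcomes())

# Selector style suggested in this comment
rec_new <- recipe(status ~ ., data = toy) %>%
  update_role(id, new_role = "ID") %>%
  step_normalize(all_numeric_predictors())

old_cols <- normalized_columns(rec_old)
new_cols <- normalized_columns(rec_new)

old_cols
new_cols
```

Because the second recipe never uses id in a step, baking it does not depend on id being available at prediction time.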

EmilHvitfeldt commented 2 years ago

Yes, -all_outcomes() is definitely an anti-pattern.
