ModelOriented / ingredients

Effects and Importances of Model Ingredients
https://modeloriented.github.io/ingredients/
GNU General Public License v3.0

`calculate_variable_profile` coerces `integer`s to `numeric`s #145

Closed simonpcouch closed 1 year ago

simonpcouch commented 1 year ago

The tidymodels team recently introduced support for finer-grained numeric classes in recipes. A user recently pointed out on our community forums that this introduced issues with `model_profile()` in some cases. Here's a reprex:

```r
library(tidymodels)
library(DALEXtra)
#> Loading required package: DALEX
#> Welcome to DALEX (version: 2.4.2).
#> Find examples and detailed introduction at: http://ema.drwhy.ai/
#> Additional features will be available after installation of: ggpubr.
#> Use 'install_dependencies()' to get all suggested dependencies
#> 
#> Attaching package: 'DALEX'
#> The following object is masked from 'package:dplyr':
#> 
#>     explain

ames_split <- initial_split(ames)
ames_train <- training(ames_split)

vip_features <- c("Neighborhood", "Gr_Liv_Area", "Year_Built", 
                  "Bldg_Type", "Latitude", "Longitude")

vip_train <- 
  ames_train %>% 
  select(all_of(vip_features))

rf_model <- 
  rand_forest(trees = 1000) %>% 
  set_engine("ranger") %>% 
  set_mode("regression")

rf_wflow <- 
  workflow() %>% 
  add_formula(
    Sale_Price ~ Neighborhood + Gr_Liv_Area + Year_Built + Bldg_Type + 
      Latitude + Longitude) %>% 
  add_model(rf_model) 

rf_fit <- fit(rf_wflow, ames_train)

explainer_rf <- 
  explain_tidymodels(
    rf_fit, 
    data = vip_train, 
    y = ames_train$Sale_Price,
    label = "random forest",
    verbose = FALSE
  )

model_profile(explainer_rf, N = 500, variables = "Year_Built")
#> Error in `scream()`:
#> ! Can't convert from `data$Year_Built` <double> to `Year_Built` <integer> due to loss of precision.
#> • Locations: 3, 13, 23, 72, 76, 86, 96, 145, 149, 159, 169, 218, 222, 232, 242,...

#> Backtrace:
#>      ▆
#>   1. └─DALEX::model_profile(explainer_rf, N = 500, variables = "Year_Built")
#>   2.   ├─ingredients::ceteris_paribus(...)
#>   3.   └─ingredients:::ceteris_paribus.explainer(...)
#>   4.     └─ingredients:::ceteris_paribus.default(...)
#>   5.       ├─ingredients:::calculate_variable_profile(...)
#>   6.       └─ingredients:::calculate_variable_profile.default(...)
#>   7.         └─base::lapply(...)
#>   8.           └─ingredients (local) FUN(X[[i]], ...)
#>   9.             ├─DALEX (local) predict_function(model, new_data, ...)
#>  10.             └─DALEXtra:::yhat.workflow(model, new_data, ...)
#>  11.               ├─stats::predict(X.model, newdata)
#>  12.               └─workflows:::predict.workflow(X.model, newdata)
#>  13.                 └─workflows:::forge_predictors(new_data, workflow)
#>  14.                   ├─hardhat::forge(new_data, blueprint = mold$blueprint)
#>  15.                   └─hardhat:::forge.data.frame(new_data, blueprint = mold$blueprint)
#>  16.                     ├─hardhat::run_forge(blueprint, new_data = new_data, outcomes = outcomes)
#>  17.                     └─hardhat:::run_forge.default_formula_blueprint(...)
#>  18.                       └─hardhat:::forge_formula_default_clean(...)
#>  19.                         └─hardhat::scream(predictors, blueprint$ptypes$predictors, allow_novel_levels = blueprint$allow_novel_levels)
```

Created on 2022-12-05 with reprex v2.0.2

The issue arises here, where the numeric `split_points` are dropped into the (possibly integer) variable:

https://github.com/ModelOriented/ingredients/blob/a44ad390cf07c4a9d520ce1213e4c57ae9164586/R/calculate_variable_profile.R#L39
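The mechanism can be illustrated standalone (this is a sketch, not the package code; the `year_built` column and the 101-point grid are illustrative):

```r
# A grid of split points over an integer column is numeric, not integer:
year_built <- c(1950L, 1961L, 1999L, 2005L)          # imagine the training column
grid <- seq(min(year_built), max(year_built), length.out = 101)
is.integer(grid)          # FALSE: the split points are doubles like 1950.55
# When these doubles are written back into Year_Built and the workflow's
# blueprint tries to cast them to the <integer> ptype recorded at training
# time, the cast fails with "loss of precision", because non-whole values
# are present:
any(grid != round(grid))  # TRUE
```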

hbaniecki commented 1 year ago

Hi Simon, thanks for this report.

A quick workaround is to pass `variable_splits` explicitly:

```r
# ok
ingredients::ceteris_paribus(
  explainer_rf, 
  explainer_rf$data,
  variable_splits = list(Year_Built = unique(vip_train$Year_Built))
)
```

The error occurs because of the default `calculate_variable_split()`:

```r
# error
ingredients::ceteris_paribus(
  explainer_rf, 
  explainer_rf$data,
  variable_splits = ingredients:::calculate_variable_split.default(
    explainer_rf$data, variables = c("Year_Built"))
)

# float, not an integer
ingredients:::calculate_variable_split.default(explainer_rf$data, variables = c("Year_Built"))
```

Fixing this issue requires adding a `!is.integer(selected_column)` condition at https://github.com/ModelOriented/ingredients/blob/a44ad390cf07c4a9d520ce1213e4c57ae9164586/R/calculate_variable_profile.R#L85, which would treat integer features like categorical features, i.e. split on their `unique()` values.
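The proposed branching could look roughly like this (a sketch, not the actual patch; the function name and arguments are illustrative):

```r
# Sketch of the proposed split logic (illustrative, not the package code):
calculate_split_sketch <- function(selected_column, grid_points = 101) {
  if (is.numeric(selected_column) && !is.integer(selected_column)) {
    # continuous doubles: an evenly spaced quantile grid
    probs <- seq(0, 1, length.out = grid_points)
    unname(quantile(selected_column, probs = probs, na.rm = TRUE))
  } else {
    # factors, characters and (with this fix) integers: observed unique values
    sort(unique(selected_column))
  }
}

calculate_split_sketch(c(2005L, 1950L, 1999L, 1999L))
#> [1] 1950 1999 2005
```

This way integer columns keep their integer type, so the later cast back to the blueprint's `<integer>` ptype cannot lose precision.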


@pbiecek what do you think?

pbiecek commented 1 year ago

Thanks for tracking down this tricky error!

Treating an integer as a categorical variable is a good idea, as long as it doesn't have too many distinct levels (e.g. if someone has an id column with 10000 different values, that would kill our profile calculation). So maybe an extra condition in the if statement: if the variable is an integer and the number of distinct values is under 100, treat it as categorical (i.e. `unique()`), but if there are a lot of values, convert it to float?

hbaniecki commented 1 year ago

I implemented the fix, and it actually still fails ungracefully in the above scenario, because there are 113 unique year values.

This got me thinking that with categorical variables, we don't have a threshold on how many unique values there should be.

We can either:

1. Remove the auxiliary threshold so that all integer variables are treated as categorical. This removes the error, but users need to pay attention to the results and to why computation may take so long.
2. Set the threshold to the value of `grid_points` (101 by default):
   1. only for integer variables — this leads to the same uninformative error for the user;
   2. for both integer and categorical variables, and raise an informative error telling the user to increase `grid_points` when the threshold is reached. This breaks some previous code but improves the quality of the user's experience interacting with our API.

pbiecek commented 1 year ago

Great idea, let's do 1, with an additional warning if there are more than 201 unique values.
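The agreed-upon behavior could be sketched like this (the function name and message wording are illustrative, not the shipped code):

```r
# Sketch of option 1 plus the warning (illustrative):
split_as_categorical <- function(x, warn_levels = 201) {
  ux <- sort(unique(x))
  if (length(ux) > warn_levels) {
    warning("Variable has ", length(ux),
            " unique values; computing profiles over all of them may be slow.")
  }
  ux
}

length(split_as_categorical(1:150))   # 150 split points, no warning
```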

hbaniecki commented 1 year ago

This is hopefully fixed now on CRAN.
