ModelOriented / DALEXtra

Extensions for the DALEX package
https://ModelOriented.github.io/DALEXtra/

Support for `stacks` #68

Closed simonschoe closed 2 years ago

simonschoe commented 2 years ago

Hi there,

I am wondering if you could provide support for using DALEX in conjunction with `stacks` from the tidymodels ecosystem, to enable model explanations for model stacks.

Best, Simon

topepo commented 2 years ago

Please contact me and @simonpcouch if there is anything we can do to help.

Jeffrothschild commented 2 years ago

I would also find it incredibly useful to be able to use DALEXtra with model stacks. Thank you!

maksymiuks commented 2 years ago

Thank you for that one. I'll try to provide this functionality over the course of this week and send it to you for testing.

Jeffrothschild commented 2 years ago

Hi, has there been any update on this? Thank you!

maksymiuks commented 2 years ago

@Jeffrothschild @topepo

I think the support should already be there; please take a look at this example:

library("DALEX")
library("dplyr")
colnames(fifa)
fifa_small <- fifa |>
  select(value_eur, age, 
         attacking_crossing:attacking_volleys, 
         defending_marking:defending_sliding_tackle)

library("tidymodels")
library("recipes")
rec_pca <- recipe(value_eur ~ ., data = fifa_small) |>
  step_cut(age, breaks = c(20, 25, 30)) |>
  step_dummy(age) |>
  step_pca(starts_with("attacking"), num_comp = 1, prefix = "attacking") |>
  step_pca(starts_with("defending"), num_comp = 1, prefix = "defending") 

model <- boost_tree(trees = 100, tree_depth = 3) |>
  set_engine("xgboost") |>
  set_mode("regression")

fifa_small_pca <- rec_pca |> prep() |> bake(fifa_small) |> as.data.frame()

model_pca_raw <- workflow() |>
  add_formula(value_eur ~ .) |>
  add_model(model)  |>
  fit(data = fifa_small_pca)

fifa_ex_raw <- DALEXtra::explain_tidymodels(model_pca_raw, 
                                            data = fifa_small_pca[,-1], 
                                            y = fifa_small_pca$value_eur,
                                            label = "Tidy-Boosting-Raw") 

model_performance(fifa_ex_raw) |> plot(geom = "histogram")
model_parts(fifa_ex_raw) |> plot()
model_profile(fifa_ex_raw) |> plot()
model_diagnostics(fifa_ex_raw) |> plot(variable = "y", yvariable = "y_hat")

lewandowski <- fifa_small_pca[21, ]

predict_parts(fifa_ex_raw, lewandowski) |> plot()
predict_profile(fifa_ex_raw, lewandowski) |> plot()
predict_diagnostics(fifa_ex_raw, lewandowski) |> plot()

This means DALEXtra works with the `workflow` class. Is there a use case where the model has a different class and the example above fails?

Jeffrothschild commented 2 years ago

I think the issue arises when model stacks are used to combine models. Max could certainly explain this better, but here is an example:

library("DALEX")
library("dplyr")
colnames(fifa)
fifa_small <- fifa |>
  select(value_eur, age, 
         attacking_crossing:attacking_volleys, 
         defending_marking:defending_sliding_tackle)

library("tidymodels")
library("recipes")

fifa_small_folds <- vfold_cv(fifa_small, v = 2, repeats = 1)
fifa_small_folds

rec_pca <- recipe(value_eur ~ ., data = fifa_small) |>
  step_cut(age, breaks = c(20, 25, 30)) |>
  step_dummy(age) |>
  step_pca(starts_with("attacking"), num_comp = 1, prefix = "attacking") |>
  step_pca(starts_with("defending"), num_comp = 1, prefix = "defending") 

model <- boost_tree(trees = 100, tree_depth = 3) |>
  set_engine("xgboost") |>
  set_mode("regression")

model2 <-
  linear_reg(penalty = 0.1, mixture = 1) %>%
  set_engine('glmnet')

wfs <- 
  workflow_set(
    preproc = list(rec_pca),  
    models = list(model, model2), 
    cross = TRUE)

wfs

wfs_rs <-
  workflow_map(
    wfs,
    "fit_resamples",
    resamples = fifa_small_folds,
    control = control_grid(save_pred = TRUE,
                           parallel_over = "everything",
                           save_workflow = TRUE)
  )

wfs_rs

library(stacks)
tidymodels_prefer()

wfs_stack <- 
  stacks() %>% 
  add_candidates(wfs_rs)

blend_ens <- blend_predictions(wfs_stack, penalty = 10^seq(-2, 0, length = 10))
blend_ens

ens_fit <- fit_members(blend_ens)
ens_fit

The question is: what should be passed to the explainer function?

maksymiuks commented 2 years ago

@Jeffrothschild Thanks for the update

The general rule is that `explain` needs an object on which a predict-like function can be called. It does not have to be the `predict` generic; any function that extracts predictions for new data from the model will work. Therefore, based on your example, `ens_fit` is what should be passed to `explain`.

That being said, I now see where the problem is: `explain_tidymodels` assumes the metadata structure of a `workflow` and checks the `trained` field, which does not exist for a `model_stack`. I'll address this and get back to you with a development version.
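Until such a fix is available, the rule above can be illustrated with plain `DALEX::explain` and a custom `predict_function`. The sketch below is a self-contained stand-in: it uses a parsnip linear model on the built-in `mtcars` data rather than the stack from this thread, and the wrapper name `pred_fun` and the label are illustrative. It assumes that `predict()` on a fitted regression stack, like parsnip's `predict()`, returns a tibble with a `.pred` column, so the same wrapper should presumably work for `ens_fit` as well.

```r
library("DALEX")
library("parsnip")

# Fit a simple parsnip regression model on a built-in dataset
fit_lm <- linear_reg() |>
  set_engine("lm") |>
  fit(mpg ~ ., data = mtcars)

# Any function with (model, newdata) arguments that returns a numeric vector
# of predictions can serve as predict_function. Here the `.pred` column is
# pulled out of the tibble returned by tidymodels-style predict().
pred_fun <- function(model, newdata) {
  predict(model, new_data = newdata)$.pred
}

ex_lm <- DALEX::explain(
  fit_lm,
  data = mtcars[, -1],          # predictors only (mpg, the outcome, dropped)
  y = mtcars$mpg,
  predict_function = pred_fun,
  label = "parsnip lm",
  verbose = FALSE
)

# Assumed analogous call for the stack built earlier in this thread
# (requires ens_fit and fifa_small to exist):
# DALEX::explain(ens_fit, data = fifa_small[, -1], y = fifa_small$value_eur,
#                predict_function = pred_fun, label = "Stack")
```

Keeping the wrapper separate from the model means `explain` never has to know the model's class; it only relies on the `(model, newdata)` signature returning one number per row.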

maksymiuks commented 2 years ago

@Jeffrothschild Hi!

The requested changes are in this PR: https://github.com/ModelOriented/DALEXtra/pull/77. They will be merged into master once the automated checks pass. If you could test it, I'd love to hear your opinion.

Jeffrothschild commented 2 years ago

@maksymiuks this looks fantastic, thank you!!!!

I will stress test it a bit more this week, but I really appreciate you making these updates.

@topepo may be a better person to test it out as well.

simonpcouch commented 2 years ago

Will give this a look tomorrow or Tuesday! Thanks!

simonpcouch commented 2 years ago

Thumbs up from our end!

The only note is that the default referenced here, https://github.com/ModelOriented/DALEXtra/blob/53e643b9df14b10a96f34b7d57c9458ea833597a/R/explain_tidymodels.R#L16, results in a somewhat uninformative default model label. From the tests:

> explain_tidymodels(ens_fit, data = titanic_imputed, y = titanic_imputed$survived, verbose = FALSE)
Model label:  list 
Model class:  linear_stack,model_stack,list 
Data head  :
  gender age class    embarked  fare sibsp parch survived
1   male  42   3rd Southampton  7.11     0     0        0
2   male  13   3rd Southampton 20.05     0     2        0

Thanks for y'all's work here. :)

maksymiuks commented 2 years ago

Thank you all for the support!

I'll investigate the label issue and then submit the newest version to CRAN.

maksymiuks commented 2 years ago

@simonpcouch Regarding the weird default label: it's actually related to how DALEX handles default values. It takes the last element of the model's class vector, which for this object is `list`. This happens here: https://github.com/ModelOriented/DALEX/blob/master/R/explain.R#L145
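A minimal sketch of that fallback and of the straightforward workaround, passing `label` explicitly. The object below is a hypothetical stand-in, not a real stack: a plain list given the class vector reported in the test output (`linear_stack`, `model_stack`, `list`), paired with a made-up `fake_predict` function so that `explain` has something to call.

```r
library("DALEX")

# Hypothetical stand-in object carrying the same class vector as a fitted stack
fake_stack <- structure(
  list(intercept = 1, slope = 2),
  class = c("linear_stack", "model_stack", "list")
)

# Hypothetical predict function for the stand-in (a simple linear formula)
fake_predict <- function(model, newdata) {
  model$intercept + model$slope * newdata$x
}

df <- data.frame(x = c(1, 2, 3, 4, 5))
y <- 1 + 2 * df$x

# The last element of the class vector is what ends up as the default label
tail(class(fake_stack), 1)

# Without `label`, DALEX falls back to that last class element
ex_default <- explain(fake_stack, data = df, y = y,
                      predict_function = fake_predict, verbose = FALSE)

# Supplying `label` avoids the uninformative "list"
ex_named <- explain(fake_stack, data = df, y = y,
                    predict_function = fake_predict,
                    label = "Linear stack", verbose = FALSE)
```

Because the fallback only looks at `class()`, any S3 object whose class vector ends in a generic type such as `list` will get an equally uninformative default, so setting `label` is the simplest user-side fix.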