Support for `stacks` #68

Closed simonschoe closed 2 years ago

simonschoe commented 2 years ago

Hi there,

I am wondering if you could provide support for using DALEX in conjunction with stacks from the tidymodels ecosystem to enable model explanations for model stacks?

Best, Simon

topepo commented 2 years ago

Please contact me and @simonpcouch if there is anything that we can do to help.

Jeffrothschild commented 2 years ago

I would also find it incredibly useful to be able to use DALEXtra with model stacks. Thank you!

maksymiuks commented 2 years ago

Thank you for that one. I'll try to provide this functionality over the course of this week and share it with you for testing.

Jeffrothschild commented 2 years ago

Hi, has there been any update on this? Thank you

maksymiuks commented 2 years ago

@Jeffrothschild @topepo

I think the support should already be there. Please take a look at this example:

fifa_small <- fifa |>
  select(value_eur, age, 

rec_pca <- recipe(value_eur ~ ., data = fifa_small) |>
  step_cut(age, breaks = c(20, 25, 30)) |>
  step_dummy(age) |>
  step_pca(starts_with("attacking"), num_comp = 1, prefix = "attacking") |>
  step_pca(starts_with("defending"), num_comp = 1, prefix = "defending") 

model <- boost_tree(trees = 100, tree_depth = 3) |>
  set_engine("xgboost") |>

fifa_small_pca <- rec_pca |> prep() |> bake(fifa_small) |>

model_pca_raw <- workflow() |>
  add_formula(value_eur ~ .) |>
  add_model(model)  |>
  fit(data = fifa_small_pca)

fifa_ex_raw <- DALEXtra::explain_tidymodels(model_pca_raw, 
                                            data = fifa_small_pca[,-1], 
                                            y = fifa_small_pca$value_eur,
                                            label = "Tidy-Boosting-Raw") 

model_performance(fifa_ex_raw) |> plot(geom = "histogram")
model_parts(fifa_ex_raw) |> plot()
model_profile(fifa_ex_raw) |> plot()
model_diagnostics(fifa_ex_raw) |> plot(variable = "y", yvariable = "y_hat")

lewandowski <- fifa_small_pca[21, ]

predict_parts(fifa_ex_raw, lewandowski) |> plot()
predict_profile(fifa_ex_raw, lewandowski) |> plot()
predict_diagnostics(fifa_ex_raw, lewandowski) |> plot()

This means DALEXtra works for the workflow class. Is there any use case where the model has a different class and the above example fails?

Jeffrothschild commented 2 years ago

I think the issue is when using model stacks to combine models. Max could certainly explain this better, but here is an example...

fifa_small <- fifa |>
  select(value_eur, age, 


fifa_small_folds <- vfold_cv(fifa_small, v = 2, repeats = 1)

rec_pca <- recipe(value_eur ~ ., data = fifa_small) |>
  step_cut(age, breaks = c(20, 25, 30)) |>
  step_dummy(age) |>
  step_pca(starts_with("attacking"), num_comp = 1, prefix = "attacking") |>
  step_pca(starts_with("defending"), num_comp = 1, prefix = "defending") 

model <- boost_tree(trees = 100, tree_depth = 3) |>
  set_engine("xgboost") |>

model2 <-
  linear_reg(penalty = 0.1, mixture = 1) %>%

wfs <- 
    preproc = list(rec_pca),  
    models = list(model, model2), 
    cross = TRUE)


wfs_rs <-
    resamples = fifa_small_folds,
    control = control_grid(save_pred = TRUE,
                           parallel_over = "everything",
                           save_workflow = TRUE            
  ) )



wfs_stack <- 
  stacks() %>% 

blend_ens <- blend_predictions(wfs_stack, penalty = 10^seq(-2, 0, length = 10))

ens_fit <- fit_members(blend_ens)

The question is what would go into the explainer function?

maksymiuks commented 2 years ago

@Jeffrothschild Thanks for the update

The general rule is that explain needs an object that a predict-like function can be called on. It does not have to be the predict generic; any function that extracts predictions for new data from the model will work. Therefore, based on your example, ens_fit should go into explain.

That being said, I now see where the problem is. explain_tidymodels assumes the workflow metadata structure and checks the trained field, which does not exist for model_stack. I'll address this and get back to you with a development version.
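To illustrate the general rule above, here is a minimal sketch of handing DALEX a custom predict-like function instead of relying on the predict generic. It is not taken from the thread: the built-in mtcars data and an lm fit stand in for the stack, and the names fit, predict_numeric and explainer are hypothetical.

```r
library(DALEX)

# Stand-in model (assumption): a plain linear model on built-in data,
# used only because it needs no extra packages.
fit <- lm(mpg ~ wt + hp, data = mtcars)

# Hypothetical predict-like function: takes a model and new data and
# returns a plain numeric vector of predictions.
predict_numeric <- function(model, newdata) {
  as.numeric(predict(model, newdata = newdata))
}

# explain() uses the supplied function to obtain predictions.
explainer <- DALEX::explain(
  fit,
  data = mtcars[, c("wt", "hp")],
  y = mtcars$mpg,
  predict_function = predict_numeric,
  label = "lm-custom-predict",
  verbose = FALSE
)

model_performance(explainer)
```

For the stack in the example above, the analogous object to pass would be ens_fit, once explain_tidymodels accepts the model_stack class.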

maksymiuks commented 2 years ago

@Jeffrothschild Hi!

The requested changes are in this PR and will be merged into master once the automated checks pass. I'd love to hear your opinion if you could test it.

Jeffrothschild commented 2 years ago

@maksymiuks this looks fantastic, thank you!!!!

I will stress test it a bit more this week, but I really appreciate you making these updates.

@topepo may be a better person to test it out as well

simonpcouch commented 2 years ago

Will give this a look tomorrow or Tuesday! Thanks!

simonpcouch commented 2 years ago

Thumbs up from our end!

Only note here is that the default referenced here:

results in a somewhat uninformative default Model label. From the tests:

> explain_tidymodels(ens_fit, data = titanic_imputed, y = titanic_imputed$survived, verbose = FALSE)
Model label:  list 
Model class:  linear_stack,model_stack,list 
Data head  :
  gender age class    embarked  fare sibsp parch survived
1   male  42   3rd Southampton  7.11     0     0        0
2   male  13   3rd Southampton 20.05     0     2        0

Thanks for yall's work here. :)

maksymiuks commented 2 years ago

Thank you all for the support!

I'll investigate the issue with the label and then send the newest version to CRAN.

maksymiuks commented 2 years ago

@simonpcouch Regarding the weird default label, it's actually related to how DALEX handles default values. It takes the last value of the model's class vector, which is list in this case. It happens there
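The behaviour described here can be reproduced in base R. This is only a sketch of the described rule, not DALEX's actual code: stack_like is a dummy object carrying the class vector printed in the output above, and default_label is a hypothetical helper.

```r
# Dummy object with the class vector reported for the fitted stack.
stack_like <- structure(
  list(),
  class = c("linear_stack", "model_stack", "list")
)

# Hypothetical helper mimicking the described default:
# use the last element of the class vector as the label.
default_label <- function(model) {
  utils::tail(class(model), 1)
}

default_label(stack_like)
```

The helper returns "list" for this object, which matches the uninformative label in the test output. Passing label explicitly, as in the "Tidy-Boosting-Raw" example earlier in the thread, avoids the default.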