ModelOriented / DALEX

moDel Agnostic Language for Exploration and eXplanation
https://dalex.drwhy.ai
GNU General Public License v3.0

Error when examining aggregated SHAP values in R using tidymodels #570

Open e05bf027 opened 2 months ago

e05bf027 commented 2 months ago

I have trained and fitted a C5.0 model with adaptive boosting to predict a binary outcome. This was performed using tidymodels and the parsnip package.

I am now trying to evaluate the model in more detail using DALEX and DALEXtra.

The outcome variable is "yes/no" so I have converted these to binary "1/0". I have created an explainer using the following code with the DALEXtra package:

explain_c50 <- 
  explain_tidymodels(model = c5_final_fit, 
                     data = final_data_test, 
                     y = y_test, 
                     verbose = FALSE)

I have also created an explainer using only the DALEX package, and encounter the exact same issue:

custom_predict <- function(object, newdata) {
  pred <- 
    predict(object, newdata, type = 'prob')[1] %>% 
    pull(.pred_F)

  return(pred)
}

DALEX_explainerTest <- DALEX::explain(model = c5_final_fit, 
                                  data = final_data_test,
                                  predict_function = custom_predict,
                                  y = y_test, 
                                  label = "c50-train")

Both explainers are created without error. However, when I call shap_aggregated() to compute aggregated SHAP values, the call fails:

DALEX::shap_aggregated(explainer = explain_c50, new_observations = final_data_test[1:10, ])

The error message is always the same, no matter what I try:

Error in `[<-.data.frame`(`*tmp*`, , candidate, value = list(pure_model_prediction = list( : 
  replacement element 1 is a matrix/data frame of 1 row, need 21

I have tried processing the data slightly differently, using the training data rather than the test data, and so on. The error message never changes, so I assume I am making a fairly fundamental mistake, but I cannot work out what it is.

Does anyone reading this have any idea?

e05bf027 commented 2 months ago

@hbaniecki I see an "Invalid" label has been added to my post. Have I posted incorrectly?

hbaniecki commented 2 months ago

Hi, thanks for raising the issue. Invalid is meant to denote a bug / something not working as intended.

e05bf027 commented 2 months ago

Oh I understand now. As you might guess, I am an enthusiastic amateur in this world. I am a medical doctor exploring machine learning in Critical Care. Thank you for clarifying!

mayer79 commented 2 months ago

@hbaniecki According to https://docs.github.com/en/issues/using-labels-and-milestones-to-track-work/managing-labels, "invalid" means that the issue/PR is no longer relevant. Maybe we can replace it with "Bug"?

hbaniecki commented 2 months ago

Labels are described in https://github.com/ModelOriented/DALEX/labels and you can hover over them to read the description. I wouldn't change their names, if only because renaming would override the labels on all previous issues.

maksymiuks commented 2 months ago

Hi @e05bf027

Thank you for raising the issue. Unfortunately, I'm afraid that without a reproducible example I won't be able to help. I've tried to create a parsnip C5 model using a dummy dataset, and it worked without any issue, both using DALEXtra and using the custom function you've provided:

library(parsnip)
library(tidymodels)
library(rules)
library(DALEXtra)

data <- iris
data$Species <- as.factor(ifelse(data$Species == "setosa", "yes", "no"))

model <- C5_rules(
  trees = 1,
  min_n = 1
) |>
  set_engine("C5.0") |>
  set_mode("classification") |>
  fit(Species ~ ., data = data)

explain_c50 <- 
  explain_tidymodels(model = model, 
                     data = data, 
                     y = as.numeric(as.factor(data$Species)) - 1)

shap <- DALEX::shap_aggregated(explainer = explain_c50, new_observations = data[1:10, ])

custom_predict <- function(object, newdata) {
  pred <- 
    predict(object, newdata, type = 'prob') %>% 
    pull(.pred_yes)

  return(pred)
}

DALEX_explainerTest <- DALEX::explain(model = model, 
                                      data = data,
                                      predict_function = custom_predict,
                                      y = as.numeric(as.factor(data$Species)) - 1, 
                                      label = "c50-train")

DALEX::shap_aggregated(explainer = DALEX_explainerTest, new_observations = data[1:10, ])

Can you please provide some more details about your issue?
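As a possible starting point for narrowing this down: the message in the report is raised by base R's `[<-.data.frame`, which refuses to assign a matrix or data frame whose row count differs from that of the target. The sketch below does not use DALEX or C5.0 and does not identify the cause in the original model. It only shows what the message means, using a hypothetical three-row data frame `df` and a one-row replacement.

```r
# Minimal base-R sketch (no DALEX involved): assigning a 1-row data frame
# into a column of a 3-row data frame triggers the same class of error.
df <- data.frame(a = 1:3)  # hypothetical target with 3 rows

msg <- tryCatch({
  # The replacement is a list holding a data frame with only 1 row.
  df[, "b"] <- list(data.frame(z = 1))
  "no error"
}, error = function(e) conditionMessage(e))

print(msg)
```

If the same mechanism is at work in the original report, comparing `nrow()` of the objects being combined (for example the rows passed as `new_observations` and the output of the predict function on those rows) may help locate the mismatch.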