ModelOriented / iBreakDown

Break Down with interactions for local explanations (SHAP, BreakDown, iBreakDown)
https://ModelOriented.github.io/iBreakDown/
GNU General Public License v3.0
82 stars 15 forks source link

local_attributions fails for classification: incorrect number of subscripts on matrix #69

Closed agilebean closed 4 years ago

agilebean commented 4 years ago

For classification, local_attributions() returns the error:

Error in contribution[nrow(contribution), ] <- cummulative[nrow(contribution),  : 
  incorrect number of subscripts on matrix

One hint about the root cause may be the warning thrown by the explainer: it tries to calculate numeric residuals, which it cannot do for a classification model:

      # Build a DALEX explainer for the fitted model. `y` is the logical
      # indicator "outcome equals TARGET.VALUE".
      DALEX.explainer <- DALEX::explain(
        model = model_object,
        data = features,
        y = training.set$.outcome == TARGET.VALUE,
        # paste() already separates its arguments with a space; the previous
        # " model" argument produced a double space (e.g. "rf  model").
        label = paste(model_object$method, "model"),
        colorize = TRUE
      )

  A new explainer has been created!  
Warning message:
In mean.default(residuals) :
  argument is not numeric or logical: returning NA

Reproducible example:

# One-row data frame holding the observation to explain; its columns match
# the predictor columns of `training.set`.
random.case <- structure(
  list(
    anger = 0.166666666666667,
    anticipation = 0,
    disgust = 0.166666666666667,
    fear = 0.166666666666667,
    joy = 0,
    negative = 0.25,
    positive = 0.0833333333333333,
    sadness = 0.0833333333333333,
    surprise = 0.0833333333333333,
    trust = 0
  ),
  class = "data.frame",
  row.names = c(NA, -1L)
)

# Training data: 25 rows with a factor outcome `.outcome` (levels "1" to "5")
# followed by ten numeric predictor columns.
training.set <- structure(
  list(
    .outcome = structure(
      c(3L, 4L, 5L, 4L, 4L, 5L, 5L, 4L, 3L, 3L, 3L, 5L, 4L, 3L, 3L, 1L, 4L,
        3L, 4L, 5L, 3L, 2L, 5L, 5L, 5L),
      .Label = c("1", "2", "3", "4", "5"),
      class = "factor"
    ),
    anger = c(
      0, 0.0434782608695652, 0, 0, 0, 0.1, 0, 0.037037037037037,
      0.0192307692307692, 0, 0, 0, 0, 0.0673076923076923, 0.181818181818182,
      0.0408163265306122, 0, 0, 0, 0.0285714285714286, 0.0526315789473684,
      0.0952380952380952, 0, 0.0441176470588235, 0
    ),
    anticipation = c(
      0.333333333333333, 0.217391304347826, 0.125, 0.15, 0.2, 0.2,
      0.217391304347826, 0.111111111111111, 0.173076923076923,
      0.166666666666667, 0.111111111111111, 0.157894736842105,
      0.214285714285714, 0.115384615384615, 0.0909090909090909,
      0.0408163265306122, 0, 0.166666666666667, 0, 0.114285714285714,
      0.184210526315789, 0.0476190476190476, 0.133333333333333,
      0.102941176470588, 0.176470588235294
    ),
    disgust = c(
      0, 0, 0, 0, 0, 0, 0, 0.0185185185185185, 0.0192307692307692,
      0.0833333333333333, 0.0740740740740741, 0, 0, 0.0288461538461538, 0,
      0.0204081632653061, 0, 0, 0.111111111111111, 0, 0, 0.0952380952380952, 0,
      0.0294117647058824, 0
    ),
    fear = c(
      0, 0.0434782608695652, 0, 0.05, 0, 0, 0, 0.0185185185185185, 0, 0, 0, 0,
      0, 0.0673076923076923, 0, 0.0408163265306122, 0, 0.0833333333333333,
      0.111111111111111, 0, 0.0263157894736842, 0.0952380952380952, 0,
      0.0294117647058824, 0
    ),
    joy = c(
      0, 0.130434782608696, 0.166666666666667, 0.15, 0.233333333333333, 0.2,
      0.173913043478261, 0.166666666666667, 0.0961538461538462,
      0.166666666666667, 0.037037037037037, 0.210526315789474,
      0.214285714285714, 0.0961538461538462, 0.181818181818182,
      0.0204081632653061, 0.333333333333333, 0.0833333333333333,
      0.222222222222222, 0.2, 0.105263157894737, 0.0952380952380952, 0.2,
      0.147058823529412, 0.176470588235294
    ),
    negative = c(
      0, 0.0869565217391304, 0.0833333333333333, 0.1, 0, 0, 0,
      0.0555555555555556, 0.0769230769230769, 0.166666666666667,
      0.0740740740740741, 0.0526315789473684, 0.0714285714285714,
      0.105769230769231, 0.181818181818182, 0.204081632653061, 0,
      0.166666666666667, 0.222222222222222, 0.0285714285714286,
      0.105263157894737, 0.19047619047619, 0, 0.102941176470588,
      0.0294117647058824
    ),
    positive = c(
      0.333333333333333, 0.217391304347826, 0.291666666666667, 0.4, 0.3, 0.3,
      0.347826086956522, 0.333333333333333, 0.326923076923077, 0.25,
      0.259259259259259, 0.315789473684211, 0.285714285714286,
      0.240384615384615, 0.181818181818182, 0.244897959183673,
      0.333333333333333, 0.25, 0.222222222222222, 0.4, 0.342105263157895,
      0.238095238095238, 0.4, 0.235294117647059, 0.352941176470588
    ),
    sadness = c(
      0.333333333333333, 0.0434782608695652, 0.0416666666666667, 0, 0, 0, 0,
      0.0185185185185185, 0.0576923076923077, 0, 0.0740740740740741, 0, 0,
      0.0480769230769231, 0.0909090909090909, 0.142857142857143, 0, 0,
      0.111111111111111, 0, 0.0526315789473684, 0.0952380952380952, 0,
      0.0441176470588235, 0.0294117647058824
    ),
    surprise = c(
      0, 0.0434782608695652, 0.0833333333333333, 0.05, 0.0666666666666667, 0,
      0.0434782608695652, 0.037037037037037, 0.0192307692307692, 0,
      0.111111111111111, 0.0526315789473684, 0, 0.0865384615384615, 0,
      0.0408163265306122, 0, 0, 0, 0.0285714285714286, 0.0526315789473684, 0,
      0.0666666666666667, 0.0735294117647059, 0.0294117647058824
    ),
    trust = c(
      0, 0.173913043478261, 0.208333333333333, 0.1, 0.2, 0.2,
      0.217391304347826, 0.203703703703704, 0.211538461538462,
      0.166666666666667, 0.259259259259259, 0.210526315789474,
      0.214285714285714, 0.144230769230769, 0.0909090909090909,
      0.204081632653061, 0.333333333333333, 0.25, 0, 0.2, 0.0789473684210526,
      0.0476190476190476, 0.2, 0.191176470588235, 0.205882352941176
    )
  ),
  row.names = c(NA, 25L),
  class = "data.frame"
)

# Fit a random forest on `training.set` using 5-fold cross-validation
# repeated 5 times. trainControl() is namespace-qualified like train()
# because caret is never attached with library() in this example; without
# the prefix the call fails with "could not find function". set.seed() makes
# the resampling folds and the forest reproducible.
set.seed(42)
model.rf <- caret::train(
  form = .outcome ~ .,
  data = training.set,
  method = "rf",
  trControl = caret::trainControl(
    method = "repeatedcv", number = 5, repeats = 5
  )
)

# Outcome vector and predictor-only data frame for the explainer. Base
# subsetting is used because this example never attaches dplyr, which
# select() and %>% would require.
target <- training.set$.outcome
features <- training.set[, setdiff(names(training.set), ".outcome")]

# Class whose indicator is used as the explainer's `y`.
TARGET.VALUE <- "1"

DALEX.explainer <- DALEX::explain(
  model = model.rf,
  data = features,
  y = target == TARGET.VALUE,
  # The fitted model in this example is `model.rf`; `model_object` was never
  # defined. paste() already inserts a space, so "model" has no leading one.
  label = paste(model.rf$method, "model"),
  colorize = TRUE
)

DALEX.attribution <- iBreakDown::local_attributions(
  DALEX.explainer,
  random.case
)
hbaniecki commented 4 years ago

Hi @agilebean! Can you please provide the model? It might be a problem with predict_function. Your code works for me with DALEX v0.9.4 and iBreakDown v0.9.9:

library(dplyr)
# One-row data frame holding the observation to explain; its columns match
# the predictor columns of `training.set`.
random.case <- structure(
  list(
    anger = 0.166666666666667,
    anticipation = 0,
    disgust = 0.166666666666667,
    fear = 0.166666666666667,
    joy = 0,
    negative = 0.25,
    positive = 0.0833333333333333,
    sadness = 0.0833333333333333,
    surprise = 0.0833333333333333,
    trust = 0
  ),
  class = "data.frame",
  row.names = c(NA, -1L)
)

# Training data: 25 rows with a factor outcome `.outcome` (levels "1" to "5")
# followed by ten numeric predictor columns.
training.set <- structure(
  list(
    .outcome = structure(
      c(3L, 4L, 5L, 4L, 4L, 5L, 5L, 4L, 3L, 3L, 3L, 5L, 4L, 3L, 3L, 1L, 4L,
        3L, 4L, 5L, 3L, 2L, 5L, 5L, 5L),
      .Label = c("1", "2", "3", "4", "5"),
      class = "factor"
    ),
    anger = c(
      0, 0.0434782608695652, 0, 0, 0, 0.1, 0, 0.037037037037037,
      0.0192307692307692, 0, 0, 0, 0, 0.0673076923076923, 0.181818181818182,
      0.0408163265306122, 0, 0, 0, 0.0285714285714286, 0.0526315789473684,
      0.0952380952380952, 0, 0.0441176470588235, 0
    ),
    anticipation = c(
      0.333333333333333, 0.217391304347826, 0.125, 0.15, 0.2, 0.2,
      0.217391304347826, 0.111111111111111, 0.173076923076923,
      0.166666666666667, 0.111111111111111, 0.157894736842105,
      0.214285714285714, 0.115384615384615, 0.0909090909090909,
      0.0408163265306122, 0, 0.166666666666667, 0, 0.114285714285714,
      0.184210526315789, 0.0476190476190476, 0.133333333333333,
      0.102941176470588, 0.176470588235294
    ),
    disgust = c(
      0, 0, 0, 0, 0, 0, 0, 0.0185185185185185, 0.0192307692307692,
      0.0833333333333333, 0.0740740740740741, 0, 0, 0.0288461538461538, 0,
      0.0204081632653061, 0, 0, 0.111111111111111, 0, 0, 0.0952380952380952, 0,
      0.0294117647058824, 0
    ),
    fear = c(
      0, 0.0434782608695652, 0, 0.05, 0, 0, 0, 0.0185185185185185, 0, 0, 0, 0,
      0, 0.0673076923076923, 0, 0.0408163265306122, 0, 0.0833333333333333,
      0.111111111111111, 0, 0.0263157894736842, 0.0952380952380952, 0,
      0.0294117647058824, 0
    ),
    joy = c(
      0, 0.130434782608696, 0.166666666666667, 0.15, 0.233333333333333, 0.2,
      0.173913043478261, 0.166666666666667, 0.0961538461538462,
      0.166666666666667, 0.037037037037037, 0.210526315789474,
      0.214285714285714, 0.0961538461538462, 0.181818181818182,
      0.0204081632653061, 0.333333333333333, 0.0833333333333333,
      0.222222222222222, 0.2, 0.105263157894737, 0.0952380952380952, 0.2,
      0.147058823529412, 0.176470588235294
    ),
    negative = c(
      0, 0.0869565217391304, 0.0833333333333333, 0.1, 0, 0, 0,
      0.0555555555555556, 0.0769230769230769, 0.166666666666667,
      0.0740740740740741, 0.0526315789473684, 0.0714285714285714,
      0.105769230769231, 0.181818181818182, 0.204081632653061, 0,
      0.166666666666667, 0.222222222222222, 0.0285714285714286,
      0.105263157894737, 0.19047619047619, 0, 0.102941176470588,
      0.0294117647058824
    ),
    positive = c(
      0.333333333333333, 0.217391304347826, 0.291666666666667, 0.4, 0.3, 0.3,
      0.347826086956522, 0.333333333333333, 0.326923076923077, 0.25,
      0.259259259259259, 0.315789473684211, 0.285714285714286,
      0.240384615384615, 0.181818181818182, 0.244897959183673,
      0.333333333333333, 0.25, 0.222222222222222, 0.4, 0.342105263157895,
      0.238095238095238, 0.4, 0.235294117647059, 0.352941176470588
    ),
    sadness = c(
      0.333333333333333, 0.0434782608695652, 0.0416666666666667, 0, 0, 0, 0,
      0.0185185185185185, 0.0576923076923077, 0, 0.0740740740740741, 0, 0,
      0.0480769230769231, 0.0909090909090909, 0.142857142857143, 0, 0,
      0.111111111111111, 0, 0.0526315789473684, 0.0952380952380952, 0,
      0.0441176470588235, 0.0294117647058824
    ),
    surprise = c(
      0, 0.0434782608695652, 0.0833333333333333, 0.05, 0.0666666666666667, 0,
      0.0434782608695652, 0.037037037037037, 0.0192307692307692, 0,
      0.111111111111111, 0.0526315789473684, 0, 0.0865384615384615, 0,
      0.0408163265306122, 0, 0, 0, 0.0285714285714286, 0.0526315789473684, 0,
      0.0666666666666667, 0.0735294117647059, 0.0294117647058824
    ),
    trust = c(
      0, 0.173913043478261, 0.208333333333333, 0.1, 0.2, 0.2,
      0.217391304347826, 0.203703703703704, 0.211538461538462,
      0.166666666666667, 0.259259259259259, 0.210526315789474,
      0.214285714285714, 0.144230769230769, 0.0909090909090909,
      0.204081632653061, 0.333333333333333, 0.25, 0, 0.2, 0.0789473684210526,
      0.0476190476190476, 0.2, 0.191176470588235, 0.205882352941176
    )
  ),
  row.names = c(NA, 25L),
  class = "data.frame"
)
# Outcome vector and predictor-only data frame, taken before the rename below.
target <- training.set$.outcome
features <- training.set %>% select(-.outcome)

# Class whose indicator is modelled and used as the explainer's `y`.
TARGET.VALUE <- "1"

# Rename the first column to `outcome`, then fit lm() on the logical
# indicator `outcome == TARGET.VALUE` using all remaining columns.
colnames(training.set)[1] <- "outcome"
model_object <- lm(outcome == TARGET.VALUE ~ ., data = training.set)

DALEX.explainer <- DALEX::explain(
  model = model_object,
  data = features,
  y = training.set$outcome == TARGET.VALUE,
  # lm objects carry no `$method` element, so the previous
  # paste(model_object$method, " model") evaluated to just " model".
  label = paste(class(model_object)[1], "model"),
  colorize = TRUE
)

DALEX.attribution <- DALEX.explainer %>%
  iBreakDown::local_attributions(random.case)
DALEX.attribution
agilebean commented 4 years ago

Ha, crossing thoughts - I just included the model in the description!

agilebean commented 4 years ago

I just verified that I have iBreakDown 0.9.9 and a DALEX version one patch release lower, i.e. DALEX 0.9.3.

agilebean commented 4 years ago

I just ran it again and got the same error. However, when I run the same analysis with the model trained as a regression instead of a classification, it WORKS! Double-checked just now.

agilebean commented 4 years ago

@hbaniecki Did you try it with the model I specified in the description above?

hbaniecki commented 4 years ago

Yes, it is a strange problem with data.frame/matrix behavior. I believe that this part https://github.com/ModelOriented/iBreakDown/blob/43b6e0bf9789d39b0d740a6faba7641e36fe7868/R/local_attributions.R#L252-L255 could be handled better (it needs a fix).

While running your example, there is a red warning (in the explainer output) saying that predict_function returns probabilities for multiple classes. For now, if you want to use local_attributions for one class (e.g. target = "1"), you can use a custom predict_function and pass it to the explainer.

# Predict function for a DALEX explainer that reduces a multi-class model to
# a single class: it asks the model for class probabilities (type = "prob")
# and returns only the column named by `target` (default "1").
custom_predict_caret_oneclass <- function(model, data, target = "1") {
  class_probs <- predict(model, data, type = "prob")
  class_probs[, target]
}

# Explainer whose predict function returns one numeric vector: the predicted
# probability of TARGET.VALUE. TARGET.VALUE is passed through explicitly so
# the predictions always refer to the same class as `y`; the helper's own
# default is the hard-coded "1".
DALEX.explainer <- DALEX::explain(
  model = model.rf,
  data = features,
  y = target == TARGET.VALUE,
  predict_function = function(model, data) {
    custom_predict_caret_oneclass(model, data, target = TARGET.VALUE)
  },
  # paste() already inserts a space, so "model" needs no leading one.
  label = paste(model.rf$method, "model"),
  colorize = TRUE
)

DALEX.attribution <- DALEX.explainer %>%
  iBreakDown::local_attributions(random.case)

DALEX.attribution
agilebean commented 4 years ago

Great analysis, and thanks for the one-class predict_function code! That's a bummer, though, because I need this for a publication. Speaking of which, the issue about numbers on plots is extremely important for publications.

pbiecek commented 4 years ago

Thanks, the problem was that predict returns a data.frame instead of a matrix. It is solved in the latest DALEX in the ema branch (it will be on master at the beginning of the week and on CRAN within a week).

In the meantime, you can use a user-defined predict_function:

# Interim workaround: coerce caret's data.frame of class probabilities to a
# matrix, because DALEX at that time mishandled the data.frame return value.
DALEX.explainer <- DALEX::explain(
  model = model.rf,
  data = features,
  y = target == TARGET.VALUE,
  predict_function = function(m, x) {
    as.matrix(predict(m, newdata = x, type = "prob"))
  },
  # The fitted caret model in this thread is `model.rf`; `model_object` is
  # not defined here. paste() already inserts a space before "model".
  label = paste(model.rf$method, "model"),
  colorize = TRUE
)
pbiecek commented 4 years ago

This is now fixed in the latest DALEX, starting with version 0.9.8, as in https://github.com/ModelOriented/DALEX/tree/DALEX_1.0_ema_version

agilebean commented 4 years ago

I can confirm it works now: I just ran local_attributions() on a classification dataset. Wonderful! It returns this plot: image