mayer79 / flashlight

Machine learning explanations
https://mayer79.github.io/flashlight/
GNU General Public License v2.0

H2O Models support #29

Closed coforfe closed 4 years ago

coforfe commented 4 years ago

Hello,

Firstly, thank you very much for your very comprehensive package and your very clear vignette, which I see as a concise but good reference for model explainability.

I am a little unsure about how to use flashlight's functionality with h2o models. I have not seen it covered in your vignette.

With some kind of guideline I can test it and share results.

Thanks again! Carlos.

mayer79 commented 4 years ago

Hello Carlos

Thanks for looking into flashlight!

A good question. In theory, it works because flashlight only relies on predictions. In practice, it will only do the job for smallish data. Why? Most explainers modify the data multiple times. This happens in R, not in Java (h2o), so the data set has to be sent to the h2o https service multiple times through a JSON blob, which is too inefficient.

I'd say for data up to 10'000 rows it is not such an issue. But nowadays, most data sets are much, much larger. Then, I'd recommend using h2o's built-in tools.

However, I cannot rule out the possibility that an h2o.flashlight package will be written, doing all data manipulations in the h2o world.
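To make that overhead concrete, here is a small sketch (not part of the original reply; the frame size is chosen arbitrarily) that times a single R -> h2o -> R round trip. Model-agnostic explainers trigger many such transfers on perturbed copies of the data:

library(h2o)
h2o.init()

# One upload/download cycle between R and the h2o cluster
X <- data.frame(matrix(rnorm(2e5 * 20), ncol = 20))
system.time({
  hX <- as.h2o(X)              # push the frame from R to h2o
  X_back <- as.data.frame(hX)  # pull it back into R
})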

Here is an example that works smoothly:

library(h2o)
library(dplyr)
library(MetricsWeighted)
library(flashlight)
library(caret)  # used only for createDataPartition()

h2o.init()
h2o.no_progress()

data(cars)
str(cars)

# Collapse a set of 0/1 dummy columns back into a single factor
undo_dummies <- function(df, cols) {
  factor(data.matrix(df[, cols]) %*% seq_along(cols), labels = cols)
}

# Turn a 0/1 variable into a "no"/"yes" factor
no_yes <- function(x) {
  factor(x, 0:1, c("no", "yes"))
}

# Prepare data
cars <- cars %>% 
  mutate(Price = log(Price),
         Mileage = log(Mileage),
         Type = undo_dummies(., c("convertible", "coupe", "hatchback", "sedan", "wagon")),
         Made = undo_dummies(., c("Buick", "Cadillac", "Chevy", "Pontiac", "Saab", "Saturn"))) %>% 
  mutate_at(c("Cruise", "Sound", "Leather"), no_yes)

# Response and covariables
y <- "Price"
x <- c("Cylinder", "Doors", "Cruise", "Sound", "Leather", "Mileage", "Made")

# Data split
set.seed(1)
idx <- c(createDataPartition(cars[[y]], p = 0.7, list = FALSE))

tr <- cars[idx, c(y, x)]
te <- cars[-idx, c(y, x)]

# Fit the models with h2o (in reality, they would need some tuning)
fit_lm <- h2o.glm(x, y, as.h2o(tr))
fit_rf <- h2o.randomForest(x, y, as.h2o(tr))

# flashlights
# Prediction function: send the data to h2o, predict, and pull the result back into R
pred_fun <- function(mod, X) as.vector(unlist(h2o.predict(mod, as.h2o(X))))
fl_lm <- flashlight(model = fit_lm, label = "lm", predict_function = pred_fun)
fl_rf <- flashlight(model = fit_rf, label = "rf", predict_function = pred_fun)

fls <- multiflashlight(list(fl_lm, fl_rf), y = y, data = te, 
                       metrics = list(RMSE = rmse, `R-Squared` = r_squared))

# Explaining the models

# Performance
light_performance(fls) %>% 
  plot(fill = "darkred")

# Importance
imp <- light_importance(fls, m_repetitions = 4) 
plot(imp, fill = "darkred")

# Effects
# Individual conditional expectations (ICE). Using a seed guarantees the same observations across models
light_ice(fls, v = "Cylinder", n_max = 100, seed = 54) %>% 
  plot(alpha = 0.1)

# Partial dependence profiles
light_profile(fls, v = "Cylinder") %>% 
  plot()

light_profile(fls, v = "Cylinder", by = "Leather") %>% 
  plot()

# Accumulated local effects
light_profile(fls, v = "Cylinder", type = "ale") %>% 
  plot()

# M-Plots
light_profile(fls, v = "Mileage", type = "predicted") %>% 
  plot()

# Response profiles, prediction profiles, partial dependence in one
eff <- light_effects(fls, v = "Cylinder") 
eff %>% 
  plot() %>% 
  plot_counts(eff, alpha = 0.3)

# Interaction strength
inter <- light_interaction(fls, v = most_important(imp, 4), pairwise = TRUE, n_max = 50, seed = 63) 
plot(inter, fill = "darkred")

# Variable contribution breakdown for single observation
light_breakdown(fls, new_obs = te[1, ]) %>% 
  plot(size = 3, facet_ncol = 2)

# Global surrogate
light_global_surrogate(fls) %>% 
  plot()

coforfe commented 4 years ago

Hello Michael,

Thanks a lot for your clear and detailed response! I appreciate it. The example is very thorough.

Well, it would be nice to have flashlight functionality within H2O. What is currently available there is quite limited, and with other packages exactly what you describe happens: at around 100K rows, the back-and-forth data exchange makes them impractical.

In any case, I will use your package despite this limitation.

Thanks again and I will continue following flashlight improvements. Carlos.

mayer79 commented 4 years ago

I will have a look at the data query tools of h2o and think about it :-)

coforfe commented 4 years ago

Thanks, that sounds pretty promising! :-)

Some pointers for your reference...

In DALEX (through DALEXtra) there is an explainer for h2o models, explain_h2o(), but my impression is that it is not optimized in the way you suggest. At least, that is what I have experienced.
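For reference, such an explainer is typically set up roughly like this (a sketch only, reusing fit_rf, te, x and y from the example above; argument names are from memory and may differ between DALEX/DALEXtra versions):

library(DALEX)
library(DALEXtra)

# Wrap the h2o model in a DALEX explainer; predictions still go through
# h2o.predict() round trips under the hood
expl_rf <- explain_h2o(fit_rf, data = te[, x], y = te[[y]], label = "rf")
model_parts(expl_rf)  # permutation importance, comparable to light_importance()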

And regarding what is already available in H2O, AFAIK they have a function to get SHAP values for GBM and XGBoost models, h2o.predict_contributions(), and another one for partial dependence plots, h2o.partialPlot().
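A rough sketch of those two built-ins, assuming the tr/te split from the example above (h2o.predict_contributions() is restricted to tree-based models, so a GBM is fitted here; details may vary between h2o versions):

# SHAP contributions and partial dependence computed inside h2o, no R round trips
fit_gbm <- h2o.gbm(x, y, as.h2o(tr))
contrib <- h2o.predict_contributions(fit_gbm, as.h2o(te))  # per-row SHAP values plus a bias term
head(as.data.frame(contrib))
h2o.partialPlot(fit_gbm, as.h2o(te), cols = "Cylinder")    # partial dependence for one feature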

Another way to accelerate the calculations could be what iml has recently added in terms of parallel computation. I do not know if your package can take any advantage of it.
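If it helps: as far as I understand, iml's parallel mode is switched on by registering a future plan before calling its methods; a minimal, purely illustrative sketch:

library(future)
library(future.apply)

# iml methods such as FeatureImp then spread their permutations over these workers
plan("multisession", workers = 4)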

Thanks again! Carlos.