ModelOriented / kernelshap

Different SHAP algorithms
https://modeloriented.github.io/kernelshap/
GNU General Public License v2.0

Does kernelshap support long short-term memory (LSTM)? #106

Closed: Yelong-Z closed this issue 10 months ago

Yelong-Z commented 10 months ago

I am trying to explain a long short-term memory (LSTM) model built with the keras package, but it does not work. The code is:

```r
library(keras)
library(kernelshap)

set.seed(123)
x <- matrix(runif(100), ncol = 5)
y <- runif(20)

n_samples <- 20
n_timesteps <- 2
n_features <- 5

# 3D input for the LSTM: samples x timesteps x features
x_samples <- array(0, dim = c(n_samples, n_timesteps, n_features))
y_samples <- array(0, dim = c(n_samples, 1))

# Fill each sample with a sliding window over the rows of x
for (i in 1:(n_samples - n_timesteps + 1)) {
  x_samples[i, , ] <- x[i:(i + n_timesteps - 1), ]
  y_samples[i, ] <- y[i + n_timesteps - 1]
}

model <- keras_model_sequential()
model %>%
  layer_lstm(units = 32, input_shape = c(n_timesteps, n_features)) %>%
  layer_dense(units = 1) %>%
  compile(loss = 'mean_squared_error', optimizer = optimizer_adam())

model %>% fit(x_samples, y_samples, epochs = 20, batch_size = 1, verbose = 0)

ks <- kernelshap(model, X = x_samples, bg_X = x_samples, verbose = F)
```

Error in kernelshap.default(model, X = x_samples, bg_X = x_samples, verbose = F) : is.matrix(X) || is.data.frame(X) is not TRUE
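The failing condition in the message can be reproduced in base R without keras or kernelshap. A minimal sketch, using a zero-filled array of the same shape as a stand-in for the real data:

```r
# Stand-in with the same shape as the LSTM input: samples x timesteps x features
x_samples <- array(0, dim = c(20, 2, 5))

# An array with three dimensions is neither a matrix nor a data.frame
n_dims <- length(dim(x_samples))
passes_check <- is.matrix(x_samples) || is.data.frame(x_samples)

print(n_dims)
print(passes_check)
```

So the explanation data has to be reshaped into a two-dimensional object before it is passed as `X` and `bg_X`.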

mayer79 commented 10 months ago

You would need to pass a prediction function that takes a matrix/data.frame and converts it to your higher-dimensional array before predicting. Since data.frames are flexible (e.g., a cell can hold multiple values), it might actually work!
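One possible reading of this suggestion, sketched in base R: flatten the array into an ordinary numeric matrix with one column per (timestep, feature) pair, and let the prediction function restore the array shape. This is only an illustration under assumptions. `flat2array()`, `toy_predict()` and the column names are hypothetical, and `toy_predict()` merely stands in for the fitted keras model so that the sketch runs without keras.

```r
set.seed(123)
n_samples <- 20
n_timesteps <- 2
n_features <- 5
x_samples <- array(runif(n_samples * n_timesteps * n_features),
                   dim = c(n_samples, n_timesteps, n_features))

# Flatten: one row per sample, one column per (timestep, feature) pair.
# Arrays are stored column-major, so the timestep index varies fastest.
X_flat <- matrix(x_samples, nrow = n_samples)
colnames(X_flat) <- paste0(
  "t", rep(seq_len(n_timesteps), times = n_features),
  "_f", rep(seq_len(n_features), each = n_timesteps)
)

# Hypothetical helper: restore the 3D shape from any subset of rows
flat2array <- function(X) {
  X <- as.matrix(X)
  array(X, dim = c(nrow(X), n_timesteps, n_features))
}

# Hypothetical stand-in for the model: sums the features of the last timestep
toy_predict <- function(arr) rowSums(arr[, n_timesteps, , drop = FALSE])

# Prediction function with the (object, data) signature used via pred_fun
pf <- function(m, X) m(flat2array(X))

preds <- pf(toy_predict, X_flat)

# With a real keras model, the call might look like this (untested assumption):
# pf_keras <- function(m, X) predict(m, flat2array(X), verbose = FALSE)
# ks <- kernelshap(model, X = X_flat, bg_X = X_flat, pred_fun = pf_keras)
```

With this flattening, each (timestep, feature) pair is treated as a separate feature, so the SHAP values would be per column of `X_flat` rather than per original feature.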

mayer79 commented 10 months ago

Like this:

# Turn the 3D array into a 2D list matrix: each cell holds one feature's time series
M <- apply(x_samples, c(1, 3), FUN = list)
colnames(M) <- paste0("X", 1:ncol(M))

# Revert the list matrix to a 3D array (samples x timesteps x features)
mat2array <- function(M) {
  aperm(apply(M, 1:2, unlist), c(2, 1, 3))
}

# Custom predict function
pf <- function(m, X) {
  predict(m, mat2array(X), verbose = FALSE, batch_size = 10)
}

ks <- kernelshap(model, X=M, bg_X = M, pred_fun = pf)
ks

# SHAP values of first 2 observations:
#                X1         X2           X3          X4          X5
# [1,] -0.009122247 0.05454801 -0.015948904 -0.01223268 0.004208181
# [2,]  0.015485218 0.02646689  0.004267152 -0.04263529 0.003057882

library(shapviz)

sv <- shapviz(ks)
sv_importance(sv)
sv_dependence(sv, "X1")  # Fails. Need to pass a smart X into shapviz() for visualization

For dependence plots, the x values would need to be atomic. Maybe there is a way to describe the time series behind a cell, e.g., using the mean?
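The mean idea could be sketched as follows: collapse the timestep dimension so that every cell becomes a single number, and pass the result as `X` to `shapviz()`. The data here is regenerated so the block runs on its own, the name `X_mean` is hypothetical, and the commented `shapviz()` call is an untested assumption about how the summarized matrix would be used.

```r
set.seed(123)
n_samples <- 20
n_timesteps <- 2
n_features <- 5
x_samples <- array(runif(n_samples * n_timesteps * n_features),
                   dim = c(n_samples, n_timesteps, n_features))

# Average over the timestep dimension: one atomic value per sample and feature
X_mean <- apply(x_samples, c(1, 3), mean)
colnames(X_mean) <- paste0("X", seq_len(ncol(X_mean)))
X_mean <- as.data.frame(X_mean)

head(X_mean, 2)

# With `ks` from above, the plots might then be built like this (assumption):
# sv <- shapviz(ks, X = X_mean)
# sv_dependence(sv, "X1")
```

The mean is only one possible summary. The last value of each series or its range would follow the same pattern by swapping the function passed to `apply()`.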