thomasp85 / lime

Local Interpretable Model-Agnostic Explanations (R port of original Python package)
https://lime.data-imaginist.com/

Reconstructing local prediction #146

Make42 closed this issue 5 years ago.

Make42 commented 5 years ago

I am using lime in R with the following code:

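# caret provides train(), dplyr/tidyr provide the pipe and reshaping verbs,
# MASS provides the biopsy data. MASS is attached after dplyr and masks
# dplyr::select, which is why select is namespaced below.
library(caret)
library(dplyr)
library(tidyr)
library(MASS)
library(lime)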
data(biopsy)

# First we'll clean up the data a bit
biopsy$ID <- NULL
biopsy <- na.omit(biopsy)
names(biopsy) <- c('clump thickness', 'uniformity of cell size', 
                   'uniformity of cell shape', 'marginal adhesion',
                   'single epithelial cell size', 'bare nuclei', 
                   'bland chromatin', 'normal nucleoli', 'mitoses',
                   'class')

set.seed(4)
test_set <- sample(seq_len(nrow(biopsy)), 4)
data_train <- biopsy[-test_set, ] %>% dplyr::select(-class)
class_train <- biopsy[-test_set, ] %>% .[["class"]] %>% factor()
data_test <- biopsy[test_set, ] %>% dplyr::select(-class)
class_test <- biopsy[test_set, ] %>% .[["class"]] %>% factor()
model <- train(data_train, class_train, method = "rf") # Random Forest via caret

explainer <- lime(data_train, model, bin_continuous = TRUE, quantile_bins = FALSE)
explanation <- explain(data_test[1,], explainer, n_labels = 1, n_features = 4)

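# Model-level statistics for the first case, in long key/value format
# (mutate_all already converts the numeric columns to character)
model_expl <- explanation %>%
  dplyr::select(-starts_with("feature")) %>%
  filter(case == .$case[1]) %>%
  unique() %>%
  mutate_all(as.character) %>%
  gather(key, value)

# Selected features with their values and weights for the first case
feature_expl <- explanation %>%
  dplyr::select(case, starts_with("feature")) %>%
  filter(case == .$case[1])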

The explanation reports a model_prediction of 0.715906270331288 and a model_intercept of 0.114219195393416. I tried to reconstruct the local approximation with:

sum(feature_expl$feature_value * feature_expl$feature_weight) + 0.114219195393416

but got 2.155599 instead of 0.715906270331288. I read that some scaling is required, but I could not work out how to apply it. What do I need to do to reconstruct the local prediction?
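For illustration, the base-R simulation below shows why that sum need not recover model_prediction. It rests on an assumption not stated in this thread: with bin_continuous = TRUE, the local model sees each feature as a 0/1 indicator of "falls in the same bin as the explained case" rather than the raw feature_value, so the explained case is a row of ones. Weighted lm() stands in for the glmnet fit, and the data, similarity weights and feature values are all made up.

```r
# Sketch only: simulated data, and weighted lm() standing in for lime's glmnet fit.
set.seed(42)
n <- 1000

# Hypothetical 0/1 indicators: 1 means the permuted sample falls in the
# same bin as the explained case for that feature.
X <- matrix(rbinom(n * 4, 1, 0.5), ncol = 4,
            dimnames = list(NULL, c("f1", "f2", "f3", "f4")))
X[1, ] <- 1  # row 1 is the explained case, which always shares its own bins

# Stand-in for the black-box class probabilities of the permuted samples.
y <- drop(plogis(-2 + X %*% c(1.2, 0.8, 0.5, 0.3))) + rnorm(n, sd = 0.02)

# Hypothetical similarity weights: samples sharing fewer bins count less.
sim <- exp(-rowSums(1 - X))

perm <- data.frame(X, y = y, sim = sim)
fit <- lm(y ~ f1 + f2 + f3 + f4, data = perm, weights = sim)

intercept <- unname(coef(fit)[1])
feature_weight <- unname(coef(fit)[-1])
model_prediction <- unname(predict(fit, newdata = perm[1, ]))

# Every indicator is 1 for the explained case, so each weight counts once.
reconstructed <- intercept + sum(feature_weight)

# Multiplying by raw feature values (hypothetical values on the 1-10
# biopsy scale) mixes up the two representations and does not match.
feature_value <- c(5, 1, 1, 3)
wrong <- intercept + sum(feature_value * feature_weight)

print(c(model_prediction = model_prediction,
        reconstructed = reconstructed,
        wrong = wrong))
```

If that assumption holds for the lime version in use, the reconstruction for the real explanation would drop feature_value entirely: `0.114219195393416 + sum(feature_expl$feature_weight)`.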

thomasp85 commented 5 years ago

The local model is fitted with glmnet. You can see how the model statistics are extracted here:

https://github.com/thomasp85/lime/blob/efdc1ddba8ca05b9386ffd50f5354a109b4fd3da/R/lime.R#L58-L63

Hope this helps.
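A hypothetical helper (not part of lime) can check the reconstruction for every case and label in an explanation at once. It assumes the explanation data frame has the columns case, label, model_intercept, model_prediction and feature_weight, and that every selected feature entered the glmnet fit as a 0/1 bin or category indicator equal to 1 for the explained case (an assumption based on bin_continuous = TRUE, not confirmed in this thread).

```r
# Hypothetical helper (not part of lime): for each case/label pair, compare
# the reported model_prediction with model_intercept + sum(feature_weight).
# Assumes each selected feature is a 0/1 indicator that is 1 for the case.
check_local_reconstruction <- function(expl, tol = 1e-8) {
  groups <- split(expl, list(expl$case, expl$label), drop = TRUE)
  rows <- lapply(groups, function(g) {
    reconstructed <- g$model_intercept[1] + sum(g$feature_weight)
    data.frame(
      case = g$case[1],
      label = g$label[1],
      model_prediction = g$model_prediction[1],
      reconstructed = reconstructed,
      matches = abs(reconstructed - g$model_prediction[1]) < tol,
      stringsAsFactors = FALSE
    )
  })
  out <- do.call(rbind, rows)
  rownames(out) <- NULL
  out
}

# Usage with the objects from the question:
# check_local_reconstruction(explanation)
```

A `matches` value of FALSE would suggest that some selected feature was not entered as an indicator (for example with bin_continuous = FALSE), in which case the extraction code linked above is the place to check how the inputs were built.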