giuseppec / iml

iml: interpretable machine learning R package
https://giuseppec.github.io/iml/

custom predictor produces different results for feature importance #174

Closed. AleBitetto closed this issue 3 years ago.

AleBitetto commented 3 years ago

Hi,

I'm trying to use a custom predictor, so I'm running some basic tests to check reproducibility and prediction validity. I fitted a very simple linear model, wrapped it in my own custom predictor, and compared the results with the same linear model wrapped in a predictor that uses the default predict function. The predictions from the two predictors on the entire dataset match, but when I run FeatureImp I get very different results. Any idea?

library(iml)
library(dplyr)

set.seed(42)
data("Boston", package = "MASS")
mod <- lm(medv ~ ., data = Boston)
X <- Boston[which(names(Boston) != "medv")]
predictor <- Predictor$new(mod, data = X, y = Boston$medv, type = NULL, class = NULL)

#  Create custom predict function
custom_pred_fun <- function(model, newdata){
  results <- predict(model, data = newdata) %>% as.numeric()
  return(results)
}

# example of prediction output
custom_pred_fun(mod, X) %>% head()

# define custom predictor
custom_predictor <- Predictor$new(
  model = mod, 
  data = X, 
  y = Boston$medv, 
  predict.fun = custom_pred_fun,
  type = NULL,
  class = NULL
)

# check predictions
check_prediction <- custom_predictor$predict(X) %>%
  rename(custom = pred) %>%
  bind_cols(predictor$predict(X)) %>%
  mutate(difference = abs(custom-pred))
range(check_prediction$difference)

# test feature importance
set.seed(42)
imp <- FeatureImp$new(predictor, loss = "mae", n.repetitions = 5, compare = "difference") 
imp$results
set.seed(42)
imp2 <- FeatureImp$new(custom_predictor, loss = "mae", n.repetitions = 5, compare = "difference") 
imp2$results

The first output is:

> imp$results
   feature importance.05    importance importance.95 permutation.error
1    lstat  2.1493884252  2.1984705726  2.3359516132          5.469333
2      dis  1.4179569361  1.5727117378  1.6880731883          4.843575
3      rad  0.9891360131  1.1130780880  1.2701662081          4.383941
4       rm  0.8710001468  0.9639893200  1.1299396198          4.234852
5      nox  0.7931062607  0.8750023005  0.9870623982          4.145865
6      tax  0.8147508309  0.8417232862  1.0065056619          4.112586
7  ptratio  0.7251740844  0.8358012462  0.8561883645          4.106664
8       zn  0.1673241230  0.1999924241  0.2861123450          3.470855
9    black  0.1036532944  0.1777958771  0.2262358726          3.448659
10    crim  0.1551785270  0.1753026006  0.2175427460          3.446165
11    chas  0.0227621927  0.0573480436  0.0644136832          3.328211
12   indus -0.0007655180  0.0052707037  0.0102879482          3.276134
13     age -0.0008699595 -0.0006881205  0.0002567683          3.270175

The second output is:

> imp2$results
   feature importance.05 importance importance.95 permutation.error
1       rm      5.740759   6.312130      6.491682          9.582993
2      rad      5.866348   6.216290      6.515221          9.487153
3     chas      6.032097   6.178665      6.312786          9.449528
4      dis      6.082575   6.172672      6.383786          9.443535
5  ptratio      6.044840   6.157038      6.326485          9.427901
6       zn      6.049599   6.135749      6.301742          9.406611
7    lstat      5.934761   6.133459      6.561748          9.404321
8     crim      6.061094   6.107696      6.241122          9.378559
9      age      5.562145   6.057008      6.312441          9.327871
10     nox      5.814618   6.010638      6.298920          9.281501
11     tax      5.940079   5.976696      6.357833          9.247559
12   black      5.861750   5.897589      6.191795          9.168452
13   indus      5.849306   5.872035      6.094271          9.142898

Thanks a lot!

christophM commented 3 years ago

Hi, thanks for reporting this. The error is in the line results <- predict(model, data = newdata) %>% as.numeric(). Calling predict() on a linear model dispatches to predict.lm(), whose argument for the new dataset is called newdata, not data. Because data is not one of its arguments, predict.lm() ignores it and always returns the fitted values for the training data. If you replace the line with results <- predict(model, newdata = newdata) %>% as.numeric(), it works.
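A minimal base-R sketch of the point above. It uses the built-in mtcars data instead of MASS::Boston (an assumption, so it runs without extra packages), and check_pred_fun() is a hypothetical helper, not part of iml. It shows that predict.lm() ignores data = and returns one fitted value per training row. It also shows that comparing the number of predictions with nrow() of a small subset catches the mistake. The prediction check in the original report passed only because it predicted on the training data X itself.

```r
# Base R only; mtcars stands in for MASS::Boston (assumption for portability).
fit <- lm(mpg ~ ., data = mtcars)
new_rows <- mtcars[1:5, names(mtcars) != "mpg"]

# `data` is not an argument of predict.lm(); it ends up in `...` and is
# ignored, so the fitted values for all training rows are returned.
p_wrong <- predict(fit, data = new_rows)

# `newdata` is the documented argument, so only the 5 new rows are predicted.
p_right <- predict(fit, newdata = new_rows)

length(p_wrong)  # 32, one per training row
length(p_right)  # 5

# Corrected custom prediction function (written without the dplyr pipe).
custom_pred_fun <- function(model, newdata) {
  as.numeric(predict(model, newdata = newdata))
}

# Hypothetical helper: check that a prediction function actually uses the
# rows it is given, by predicting on a small subset and counting the output.
check_pred_fun <- function(pred_fun, model, data) {
  subset_rows <- data[seq_len(min(5, nrow(data))), , drop = FALSE]
  n_pred <- length(pred_fun(model, subset_rows))
  if (n_pred != nrow(subset_rows)) {
    stop(sprintf("prediction function returned %d values for %d rows",
                 n_pred, nrow(subset_rows)))
  }
  invisible(TRUE)
}

# The corrected function passes the check.
check_pred_fun(custom_pred_fun, fit, new_rows)

# The version from the issue (data = instead of newdata =) fails it.
buggy_pred_fun <- function(model, newdata) {
  as.numeric(predict(model, data = newdata))
}
tryCatch(check_pred_fun(buggy_pred_fun, fit, new_rows),
         error = function(e) message("check failed: ", conditionMessage(e)))
```

Checking on a subset rather than on the full training data is the design choice here: any prediction function that ignores its input returns the training-set fitted values. Those values only coincide with correct predictions when the input happens to be the training data.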