giuseppec / iml

iml: interpretable machine learning R package
https://giuseppec.github.io/iml/

LocalModel `$predict()` does not allow single row `data.frame` as input #204

Closed: dandls closed this issue 2 months ago

dandls commented 1 year ago

Hi, after fitting a local surrogate model with LocalModel to explain the prediction for a point of interest, I wanted to compare the local model's prediction with the underlying model's prediction for that point. Unfortunately, the $predict() method throws an error if the input is a data.frame with only one row. Reproducible example:

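```r
library("randomForest")
library("iml")
set.seed(123L)

# Black-box model wrapped in an iml Predictor
rf = randomForest(Species ~ ., data = iris)
mod = Predictor$new(rf, data = iris)

# Local surrogate model for the first observation
# (LocalModel also needs the glmnet and gower packages installed)
x.interest = iris[1, ]
x.interest$Species = NULL
locmod = LocalModel$new(mod, x.interest)
locmod$results

# Predicting on the single-row data.frame fails:
locmod$predict(x.interest)
# Error in names(x) <- value :
#   'names' attribute [3] must be the same length as the vector [1]
```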
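The message is R's standard complaint when more names are assigned than there are elements. One way to reproduce exactly this message in base R, which may or may not be what happens inside iml, is default dropping: indexing a one-row matrix with `[` without `drop = FALSE` returns a plain vector, `data.frame()` turns that vector into a single column, and assigning the three class names with `colnames<-` then fails. The names `classes`, `scores_for`, `to_df_drop` and `to_df_keep` are hypothetical and this is not iml's code:

```r
# Illustration only, not iml internals: hypothetical class labels.
classes <- c("setosa", "versicolor", "virginica")

# n x 3 matrix of random scores, one row per observation.
scores_for <- function(n) matrix(runif(n * 3), nrow = n, ncol = 3)

# Relies on default dropping: breaks when m has a single row.
to_df_drop <- function(m) {
  m <- m[seq_len(nrow(m)), ]      # a 1-row matrix collapses to a length-3 vector
  out <- data.frame(m)            # that vector becomes 1 column with 3 rows
  colnames(out) <- classes        # 3 names for 1 column -> error
  out
}

# Same steps, but drop = FALSE keeps the 1 x 3 shape.
to_df_keep <- function(m) {
  m <- m[seq_len(nrow(m)), , drop = FALSE]
  out <- data.frame(m)
  colnames(out) <- classes
  out
}

to_df_keep(scores_for(1))   # one row, three class columns
try(to_df_drop(scores_for(1)))
# Error in names(x) <- value :
#   'names' attribute [3] must be the same length as the vector [1]
```

With two or more rows both variants return a data frame with three class columns, which matches the behaviour reported in this issue.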

The error does not occur if the input is a data.frame with more than one row.

```r
locmod$predict(rbind(x.interest, x.interest))
#      setosa  versicolor    virginica
# 1 0.9938849 0.006115104 5.256403e-10
# 2 0.9938849 0.006115104 5.256403e-10
```
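Because two-row input works, a possible workaround until a fixed version is installed is to stack the row of interest twice and keep only the first prediction. A minimal sketch, where `predict_one_row` is a hypothetical helper and not part of iml:

```r
# Hypothetical helper: work around a predict function that rejects one-row
# data frames by stacking the row twice and keeping only the first result.
predict_one_row <- function(predict_fun, newdata) {
  stopifnot(is.data.frame(newdata))
  if (nrow(newdata) != 1L) {
    return(predict_fun(newdata))               # nothing to work around
  }
  pred <- predict_fun(rbind(newdata, newdata)) # 2-row input avoids the bug
  pred[1, , drop = FALSE]                      # keep one row, all columns
}
```

For the example above it would be called as `predict_one_row(locmod$predict, x.interest)`; the R6 method can be passed as a plain function because it stays bound to `locmod`.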
giuseppec commented 2 months ago

should work now
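With single-row input supported, the comparison the issue set out to make can be done directly. A minimal sketch, assuming both `locmod$predict()` and `mod$predict()` return one-row data frames with one column per class; `prediction_gap` is a hypothetical helper:

```r
# Hypothetical helper: absolute per-class gap between the local surrogate's
# prediction and the underlying model's prediction for one observation.
prediction_gap <- function(local_pred, model_pred) {
  stopifnot(nrow(local_pred) == 1L, nrow(model_pred) == 1L)
  classes <- intersect(names(local_pred), names(model_pred))  # shared classes
  gap <- abs(unlist(local_pred[1, classes]) - unlist(model_pred[1, classes]))
  names(gap) <- classes
  gap
}
```

For the example in this issue the call would be `prediction_gap(locmod$predict(x.interest), mod$predict(x.interest))`.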