giuseppec / iml

iml: interpretable machine learning R package
https://giuseppec.github.io/iml/
Other
492 stars 87 forks source link

Bug: Features with more than one R class cause overlong feature.type vector #185

Closed christophM closed 2 years ago

christophM commented 2 years ago

Reproducible example:

library(mlr3)
library(mlr3learners)
library(mlr3pipelines)
library(iml)
# Reproduction script: fit a logistic-regression GraphLearner on the
# german_credit task, wrap it in an iml Predictor, and request PDP feature
# effects, which triggers the assertion error quoted further down.

# using pre-defined mlr3-tasks
task <- tsk("german_credit")

# create initial graph
graph <- Graph$new()
# adding learner PipeOp (probability predictions)
graph$add_pipeop(lrn("classif.log_reg", predict_type = "prob"))

# optional threshold step, left disabled:
# graph <- graph %>>% po("threshold")
# graph$param_set$values$threshold.thresholds <- 0.5
# saving the graph as a GraphLearner
learner <- as_learner(graph)

# creating split for test and training data
# 80/20 split: 80% of the row ids are sampled for training, the rest is test.
# NOTE: no set.seed() call appears in this snippet, so the split differs
# between runs.
train_ids <- sample(task$row_ids, task$nrow * 0.8)
test_ids <- setdiff(task$row_ids, train_ids)
# training the model
learner$train(task, row_ids = train_ids)
# predicting on the training and test data
train_pred <- learner$predict(task, row_ids = train_ids)
test_pred <- learner$predict(task, row_ids = test_ids)

# using iml for explaining
# the Predictor receives the full task data and the task's target name(s) as y
model <- Predictor$new(learner, data = task$data(), y = task$target_names)
feat_imp <- FeatureImp$new(model, loss = "ce", compare = "ratio")

# this one fails
feat_effect <- FeatureEffects$new(model, method = "pdp")
# Error in get.grid(private$getData()[, self$feature.name, with = FALSE],  :
# Assertion on 'length(features) == length(feature.type)' failed: Must be TRUE.

### this code was extracted from iml to understand why the assertion fails
# `features` holds one name per column of the task data
dat <- task$data()
features <- colnames(dat)
length(features)

# Map R class names to iml feature types.
#
# feature.class: character vector of class names. Every element must be one
#   of "numeric", "integer", "character", "factor" or "ordered".
#
# Returns a named character vector with ONE entry PER ELEMENT of
# `feature.class` ("numerical" or "categorical"), named by the class it was
# looked up from. Passing the full class vector of an ordered factor,
# c("ordered", "factor"), therefore yields two entries, not one.
#
# Stops with an error if any element is not a supported class name.
get.feature.type <- function(feature.class) {
  assertCharacter(feature.class)

  # Classes grouped by the feature type they translate to.
  numerical.classes <- c("numeric", "integer")
  categorical.classes <- c("character", "factor", "ordered")

  # Lookup table: names are class names, values are feature types.
  feature.types <- c(
    rep("numerical", length(numerical.classes)),
    rep("categorical", length(categorical.classes))
  )
  names(feature.types) <- c(numerical.classes, categorical.classes)

  # Reject anything that has no entry in the lookup table.
  stopifnot(all(feature.class %in% names(feature.types)))

  # Name-based indexing keeps the vectorised, named result.
  feature.types[feature.class]
}

# Derive one feature type per column. class(x) may return more than one class
# for a single column (e.g. c("ordered", "factor")); every class is mapped, so
# after unlist() the result can hold more entries than there are columns.
feature.type <- unlist(lapply(dat, function(x) {
    get.feature.type(class(x))
}))

### end of extract

# debugging commands

# assert in iml (this is the check that fails)
assert_true(length(features) == length(feature.type))

# number of columns
length(features)
# 21

# number of derived feature types: larger than the number of columns
length(feature.type)
# 24

# showing classes of features

sapply(dat, function(x) class(x))

# excerpt of the output: columns that report two classes
# $present_residence
# [1] "ordered" "factor"

# $number_credits
# [1] "ordered" "factor"
pat-s commented 2 years ago

Most likely fixed in #189