dandls / moc

Multi-Objective Counterfactuals

Assertion on 'X' failed: Must be of type 'data.frame', not 'matrix'. #5

Closed: charlie9526 closed this issue 3 years ago

charlie9526 commented 3 years ago

I tried MOC with an xgboost model, but I ran into a problem with the Predictor object.

"Predictor$new" requires a data frame for the "data" parameter, but the xgboost model requires a matrix; otherwise xgboost cannot convert the data to a DMatrix. How can I handle this? If you have a counterfactual explanation example with xgboost, it would be really helpful. Thank you!

dandls commented 3 years ago

Yes, we also used an xgboost model in our benchmark study. To overcome the matrix problem, we used mlr/mlrCPO. Here is a short example with mlr::pid.task.

###-- Setup ----
# To run MOC
# load `iml` and `counterfactuals` like "normal" packages.
# in the future this would just be library("counterfactuals").
devtools::load_all("../iml", export_all = FALSE)
devtools::load_all("../counterfactuals", export_all = FALSE)

set.seed(111)
library("mlr")
library("mlrCPO")
library("ggplot2")

###---- Tune & train model ----
task = mlr::pid.task
df = getTaskData(task)
# xgboost learner with integrated preprocessing:
# dummy-encode factor features, then rescale all features to [0, 1]
lrn = makeLearner("classif.xgboost", predict.type = "prob")
lrn = cpoDummyEncode() %>>% cpoScaleRange() %>>% lrn
# tune nrounds on a log scale (1 to 1000 after the trafo)
param.set = pSS(
  nrounds: numeric[0, log(1000)] [[trafo = function(x) round(exp(x))]]
)

TUNEITERS = 10L
RESAMPLING = cv5
ctrl = makeTuneControlRandom(maxit = TUNEITERS * length(param.set$pars))
# random search with 5-fold CV (accuracy), then refit on the full task
res = tuneParams(lrn, task, RESAMPLING, measures = list(mlr::acc),
  par.set = param.set, control = ctrl, show.info = FALSE)
lrn = setHyperPars(lrn, par.vals = res$x)
model = mlr::train(lrn, task)

###--- Define a predictor as input for the counterfactual method ----
pred = Predictor$new(model = model, data = df, class = "neg",
  conditional = FALSE)
# tree settings for conditional sampling; not used below since conditional = FALSE
ctr = partykit::ctree_control(maxdepth = 5L)

###---- Compute counterfactuals ----
x.interest = df[1,]
pred$predict(x.interest)

# search for counterfactuals whose predicted probability for class "neg"
# lies in [0.5, 1]
cf = Counterfactuals$new(predictor = pred,
  x.interest = x.interest,
  target = c(0.5, 1), epsilon = 0, generations = 20)

# Overview
cf$results$counterfactuals.diff
cf$results$counterfactuals

###---- Plots ----
cf$plot_parallel(plot.x.interest = FALSE) 
cf$plot_hv()

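The key step is to define the xgboost model as an mlr learner with integrated preprocessing steps for the data (here: cpoDummyEncode() and cpoScaleRange()), so the fitted model accepts a data.frame and can be passed to Predictor$new as usual.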
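To make the preprocessing idea concrete, here is a minimal base-R sketch of what dummy encoding followed by range scaling does to a data.frame: it turns it into the numeric matrix that xgboost expects. The helper `to_xgb_matrix` and the toy data are hypothetical and only illustrate the idea; this is not how mlrCPO implements `cpoDummyEncode()` or `cpoScaleRange()`.

```r
# Hypothetical helper (not mlrCPO code): mimics the idea of
# cpoDummyEncode() %>>% cpoScaleRange() in base R.
to_xgb_matrix <- function(df) {
  # Dummy-encode factor columns; "- 1" drops the intercept column.
  # Note: rows with missing values are dropped by model.matrix's
  # default na.action (na.omit).
  m <- model.matrix(~ . - 1, data = df)
  # Rescale every column to [0, 1]; constant columns are divided by 1
  # (so they become 0) to avoid division by zero.
  mins <- apply(m, 2, min)
  rng <- apply(m, 2, max) - mins
  sweep(sweep(m, 2, mins, "-"), 2, ifelse(rng == 0, 1, rng), "/")
}

# Toy data with one numeric and one factor feature (assumption)
toy <- data.frame(age = c(21, 35, 50), sex = factor(c("f", "m", "f")))
X_toy <- to_xgb_matrix(toy)
X_toy
```

For a real task, drop the target column before calling the helper. The advantage of the mlrCPO pipeline in the example above is that the same preprocessing is stored inside the model and reapplied automatically whenever MOC asks for predictions on new candidate points.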

Please increase the number of generations (`generations` in Counterfactuals$new) and tuning iterations (TUNEITERS) in your own experiments to get more reliable results. Let me know if that works for you.
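An alternative that is not the approach used in this thread is to keep a plain xgboost booster and give Predictor$new a custom `predict.function` that converts the incoming data.frame to a matrix before predicting. The sketch below assumes the CRAN versions of `iml` and `xgboost` and uses `mtcars` as a hypothetical binary task; the wrapper name `predict_xgb` is illustrative. Whether the forked iml loaded via devtools in the example above accepts the same arguments is an assumption.

```r
library("xgboost")
library("iml")

# Toy binary task (assumption: mtcars, predicting the transmission type `am`)
X <- mtcars[, setdiff(names(mtcars), "am")]
y <- mtcars$am
feature_names <- colnames(X)

# Train a plain booster on a numeric matrix
dtrain <- xgb.DMatrix(data = as.matrix(X), label = y)
booster <- xgb.train(params = list(objective = "binary:logistic"),
                     data = dtrain, nrounds = 20)

# Hypothetical wrapper: iml calls predict.function(model, newdata) with a
# data.frame, so convert it to a matrix before calling xgboost
predict_xgb <- function(model, newdata) {
  m <- as.matrix(newdata[, feature_names, drop = FALSE])
  p <- predict(model, xgb.DMatrix(m))
  # one probability column per class so that `class` can select one
  data.frame(`0` = 1 - p, `1` = p, check.names = FALSE)
}

pred <- Predictor$new(model = booster, data = X, y = y,
                      predict.function = predict_xgb, class = "1")
pred$predict(X[1:3, ])
```

This avoids the mlr dependency, but any preprocessing (for example dummy encoding of factors) then has to be done inside the wrapper as well.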

charlie9526 commented 3 years ago

I really appreciate your quick reply. It worked for me!