giuseppec / iml

iml: interpretable machine learning R package
https://giuseppec.github.io/iml/

Incompatible with xgboost models trained on xgb.DMatrix #29

Closed (jmpanfil closed this issue 5 years ago)

jmpanfil commented 5 years ago

Is there a workaround for using the Shapley function (or other iml methods) with xgboost models trained on xgb.DMatrix objects?

For example,

predictor <- Predictor$new(xgb.model, data = x, y = y,
                           predict.fun = function(object, newdata){
                             predict(object, newdata )})
shapley   <- Shapley$new(predictor, x.interest = x[1,], sample.size = 10, run = TRUE)

results in an error:

Error in xgb.DMatrix(newdata, missing = missing) : 
  xgb.DMatrix: does not support to construct from  list

keithhurley commented 5 years ago

The problem is that xgboost's prediction function expects an xgb.DMatrix (or a numeric matrix), not a data.frame.

Try this:

predict.fun = function(object, newdata) {
  newData_x = xgb.DMatrix(data.matrix(newdata), missing = NA)
  results <- predict(model, newData_x)
  return(results)
}
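The workaround relies on data.matrix() to turn the data.frame that iml passes in into a numeric matrix before the xgb.DMatrix is built. A small base-R sketch (the columns `age` and `colour` are made up for illustration) shows what that conversion does: numeric columns pass through unchanged and factors become their integer level codes, so the model needs to have been trained on the same encoding.

```r
# Hypothetical example data; iml hands predict.fun a data.frame like this
x <- data.frame(age = c(23, 41, 35),
                colour = factor(c("red", "blue", "red")))

# data.matrix() keeps numeric columns and replaces each factor with its
# integer level codes (levels sort alphabetically: "blue" = 1, "red" = 2)
m <- data.matrix(x)

str(m)         # a numeric matrix, which xgb.DMatrix() accepts
m[, "colour"]  # 2 1 2
```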

christophM commented 5 years ago

Thanks @keithhurley

jmpanfil commented 5 years ago

Apologies for the late response. @keithhurley's solution appears to work, except that I am getting a large number of warnings:

lime <- LocalModel$new(predictor, x.interest = x[1,], k = 20)
There were 50 or more warnings (use warnings() to see the first 50)
warnings()
Warning messages:
1: In gower_work(x = x, y = y, pair_x = pair_x, pair_y = pair_y,  ... :
  skipping variable with zero or non-finite range.

The same warning is repeated many times.
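The message comes from the gower package (note gower_work in the trace), which LocalModel uses to weight data points by their distance to x.interest. gower skips a numeric variable when its range is zero or not finite, for example a constant column or one that is entirely missing. Our assumption, not confirmed in this thread, is that x contains such columns. A quick base-R check such as the hypothetical helper zero_range_cols() below can list them:

```r
# Hypothetical helper: flag numeric columns whose finite values span a zero
# range (constant or entirely missing), one way gower's warning arises
zero_range_cols <- function(df) {
  vapply(df, function(col) {
    if (!is.numeric(col)) return(FALSE)  # gower compares non-numeric columns by equality, not range
    v <- col[is.finite(col)]             # drop NA, NaN and Inf
    length(v) == 0 || max(v) == min(v)   # all-missing or constant
  }, logical(1))
}

# Made-up data: `b` is constant and `c` is entirely missing
df <- data.frame(a = c(1, 2, 3),
                 b = c(5, 5, 5),
                 c = c(NA_real_, NA_real_, NA_real_),
                 d = factor(c("x", "y", "x")))

names(df)[zero_range_cols(df)]  # "b" "c"
```

A variable with zero range contributes nothing to the distance anyway, so the warnings are likely benign; wrapping the LocalModel$new() call in suppressWarnings() would silence them.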

dwhdai commented 5 years ago


@keithhurley, thanks for providing this work-around. One minor thing in predict.fun(): I think the first parameter should be named model, i.e. function(model, newdata), to match the predict(model, newData_x) call inside the function.
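Putting the thread together, here is a minimal end-to-end sketch, assuming the xgboost and iml packages are installed. The synthetic data, the names xgb.model and predict.fun, and the training parameters are illustrative choices, not taken from the original report. Instead of renaming the parameter, it uses predict.fun's own first argument (object) inside the body, which fixes the same problem: the function no longer depends on a global variable called model.

```r
library(xgboost)
library(iml)

set.seed(1)
# Made-up data: three numeric features and a numeric target
x <- data.frame(a = rnorm(200), b = rnorm(200), c = runif(200))
y <- 2 * x$a - x$b + rnorm(200, sd = 0.1)

# Train on an xgb.DMatrix, as in the original question
dtrain <- xgb.DMatrix(data.matrix(x), label = y)
xgb.model <- xgb.train(params = list(objective = "reg:squarederror"),
                       data = dtrain, nrounds = 20)

# Convert iml's data.frame to an xgb.DMatrix and predict with the model
# passed in as `object` (not a global variable)
predict.fun <- function(object, newdata) {
  predict(object, xgb.DMatrix(data.matrix(newdata), missing = NA))
}

# Recent iml versions name this argument predict.function; predict.fun
# still matches it through R's partial argument matching
predictor <- Predictor$new(xgb.model, data = x, y = y, predict.fun = predict.fun)
shapley <- Shapley$new(predictor, x.interest = x[1, ], sample.size = 10)
shapley$results
```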