ModelOriented / treeshap

Compute SHAP values for your tree-based models using the TreeSHAP algorithm
https://modeloriented.github.io/treeshap/
GNU General Public License v3.0

Error in ranger.unify() #39

Open Novampe opened 7 months ago

Novampe commented 7 months ago

I tried to unify a binary classification random forest trained with ranger, and ranger.unify() produced this warning:

Warning message: In Ops.factor(get("Prediction"), n) : ‘/’ not meaningful for factors

I don't know how to get rid of it and make my code run. Any ideas about what causes it would be appreciated.

mayer79 commented 7 months ago

Please add a minimal reproducible example.

nlebovits commented 3 months ago

@mayer79 I've been running into the same issue, plus an error where treeshap doesn't recognize a model_unified object:

library(treeshap)
library(ranger)
library(dplyr)
library(mlr)

set.seed(1)
data(iris)

dat_prepared <- iris
dat_prepared$Species <- as.factor(dat_prepared$Species)
predictors <- dat_prepared %>% select(-Species)
predictors_encoded <- createDummyFeatures(predictors)
dat_encoded <- cbind(predictors_encoded, Species = dat_prepared$Species)

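# Species is a factor, so ranger fits a classification forest that predicts class labels
rf <- ranger(Species ~ ., data = dat_encoded)
# Emits the warning: '/' not meaningful for factors
unified_model <- ranger.unify(rf, dat_encoded)
# Fails: unified_model is not recognized as a model_unified object
shaps <- treeshap(unified_model, dat_encoded[1:50, ])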

The ranger.unify call emits "Warning: '/' not meaningful for factors". The treeshap call that follows then fails with "Error in treeshap.model_unified(unified_model, dat_encoded[1:50, ]) : unified_model parameter has to of class model_unified. Produce it using *.unify function."

I'm using treeshap version 0.3.1 and ranger version 0.16.0.

mayer79 commented 3 months ago

@nlebovits I am not a maintainer.

SHAP values can't be calculated for non-probabilistic classification models, i.e. models that predict class labels. TreeSHAP requires numeric predictions, e.g., probabilities or logits.
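To illustrate where the reported warning comes from, here is a minimal base-R sketch. It assumes, from the warning text alone, that unification divides a `Prediction` column by some number `n`; the vectors and the `n_trees` name below are hypothetical stand-ins, not treeshap internals.

```r
# Hypothetical stand-ins for per-node predictions; not taken from treeshap.
label_predictions <- factor(c("setosa", "virginica", "setosa"))
numeric_predictions <- c(0.2, 0.9, 0.4)
n_trees <- 2

# Arithmetic on a factor is not defined: R warns that '/' is not
# meaningful for factors and returns NA for every element.
from_labels <- suppressWarnings(label_predictions / n_trees)

# The same division on numeric predictions is well defined.
from_numbers <- numeric_predictions / n_trees

print(from_labels)
print(from_numbers)
```

Class labels stored as a factor carry no numeric value to average or attribute across features, which is consistent with the requirement for numeric predictions.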

There is an open PR that allows probabilistic classification for ranger objects:

https://github.com/ModelOriented/treeshap/pull/43

Maybe you can test this?

library(remotes)

install_github("ModelOriented/treeshap", ref = github_pull("43"))
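For a binary problem, one possible workaround that does not depend on the PR (and that nobody in this thread proposed) is to encode the target as numeric 0/1, so that ranger fits a regression forest with numeric node predictions. The sketch below is untested against the versions reported above. The `is_virginica` target and the `train` data frame are made up for illustration, and it assumes the installed treeshap version accepts regression forests from ranger.

```r
library(ranger)
library(treeshap)

set.seed(1)

# Illustrative binary target built from iris: virginica (1) or not (0).
predictors <- iris[, setdiff(names(iris), "Species")]
train <- cbind(predictors,
               is_virginica = as.numeric(iris$Species == "virginica"))

# A numeric target makes ranger fit a regression forest, so each node
# prediction is a number (an average of the 0/1 target).
rf <- ranger(is_virginica ~ ., data = train, num.trees = 50)

# Pass only the predictor columns as reference data.
unified_model <- ranger.unify(rf, predictors)
shaps <- treeshap(unified_model, predictors[1:50, ])

# One row of SHAP values per explained observation.
head(shaps$shaps)
```

The resulting SHAP values explain the predicted 0/1 average, not a class label, so they should be read on that scale.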