predict-idlab / powershap

A power-full Shapley feature selection method.

Missing XGBoost support #28

Closed (mattharrison closed this issue 1 year ago)

mattharrison commented 1 year ago

Here is my code to get XGBoost working (from a notebook):

from powershap import PowerShap
import powershap.shap_wrappers.shap_explainer as se
import powershap.shap_wrappers.shap_explainer_factory as sef
import numpy as np  # needed for the np.array return annotation below
import shap
import xgboost as xgb
from typing import Callable

class XGBoostExplainer(se.ShapExplainer):
    @staticmethod
    def supports_model(model) -> bool:
        supported_models = [xgb.XGBRegressor, xgb.XGBClassifier]
        return isinstance(model, tuple(supported_models))

    def _validate_data(self, validate_data: Callable, X, y, **kwargs):
        kwargs["force_all_finite"] = False  # xgboost allows NaNs and infs in X
        kwargs["dtype"] = None  # allow non-numeric data
        return super()._validate_data(validate_data, X, y, **kwargs)

    def _fit_get_shap(
        self, X_train, Y_train, X_val, Y_val, random_seed, **kwargs
    ) -> np.array:
        # Refit a fresh copy of the model with this iteration's seed.
        # XGBoost's scikit-learn API names the seed `random_state`;
        # `random_seed` is the CatBoost name and is not an XGBoost parameter.
        params = self.model.get_params()
        PowerShap_model = self.model.__class__(**{**params, "random_state": random_seed})
        PowerShap_model.fit(X_train, Y_train)
        # Calculate the SHAP values on the validation fold
        C_explainer = shap.TreeExplainer(PowerShap_model)
        return C_explainer.shap_values(X_val)

    def _get_more_tags(self):
        return {"allow_nan": True}

# Register the explainer so PowerShap's factory can select it for XGBoost models
sef.ShapExplainerFactory._explainer_models.append(XGBoostExplainer)

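
The last line works because the factory keeps a list of explainer classes and, as this snippet relies on, picks one whose `supports_model` accepts the given model. The sketch below reproduces that dispatch-and-register pattern in plain Python so it can run without powershap installed. `Explainer`, `Factory`, `FakeBoostedModel`, and `FakeBoostedExplainer` are hypothetical stand-ins loosely modeled on `ShapExplainer` and `ShapExplainerFactory`, not the library's actual classes, and the "first match wins" loop is an assumption about how the real factory resolves explainers.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable, List, Type


class Explainer(ABC):
    """Hypothetical stand-in for powershap's ShapExplainer base class."""

    def __init__(self, model: Any):
        assert self.supports_model(model)
        self.model = model

    @staticmethod
    @abstractmethod
    def supports_model(model: Any) -> bool: ...

    def _validate_data(self, validate_data: Callable, X, y, **kwargs):
        # Default behaviour: forward the kwargs unchanged.
        return validate_data(X, y, **kwargs)


class Factory:
    """Hypothetical stand-in for ShapExplainerFactory (assumed: first match wins)."""

    _explainer_models: List[Type[Explainer]] = []

    @classmethod
    def get_explainer(cls, model: Any) -> Explainer:
        for explainer_cls in cls._explainer_models:
            if explainer_cls.supports_model(model):
                return explainer_cls(model)
        raise ValueError(f"no explainer supports {type(model).__name__}")


class FakeBoostedModel:
    """Hypothetical model class standing in for xgb.XGBRegressor."""


class FakeBoostedExplainer(Explainer):
    @staticmethod
    def supports_model(model: Any) -> bool:
        return isinstance(model, FakeBoostedModel)

    def _validate_data(self, validate_data: Callable, X, y, **kwargs):
        # Relax validation the same way the XGBoost snippet does, then delegate.
        kwargs["force_all_finite"] = False
        kwargs["dtype"] = None
        return super()._validate_data(validate_data, X, y, **kwargs)


# Registration: appending the class makes the factory consider it.
Factory._explainer_models.append(FakeBoostedExplainer)

explainer = Factory.get_explainer(FakeBoostedModel())
print(type(explainer).__name__)  # FakeBoostedExplainer


def fake_validate(X, y, **kwargs):
    # Echo the kwargs to show what the explainer passed through.
    return kwargs


print(explainer._validate_data(fake_validate, [[1.0]], [0], dtype="numeric"))
# {'dtype': None, 'force_all_finite': False}
```

Because the subclass only rewrites the keyword arguments and then calls `super()`, the base class stays in charge of the actual validation; the override just loosens the NaN and dtype checks for models that tolerate them.
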
jvdd commented 1 year ago

Hey @mattharrison,

It would be great to see support for XGBoost added to this library :+1: If you'd like to contribute the code for this, I'd be happy to review a pull request. If not, I can submit a pull request myself.

Thanks, Jeroen

mattharrison commented 1 year ago

Thanks so much for the new version!

jvdd commented 1 year ago

You're welcome! Thanks for providing the code snippet ;)

powershap v0.9.0 should now support XGBoost models :tada: