dreamquark-ai / tabnet

PyTorch implementation of TabNet paper : https://arxiv.org/pdf/1908.07442.pdf
https://dreamquark-ai.github.io/tabnet/
MIT License
2.55k stars 470 forks

TabNetClassifier explainability #500

Closed ranellout closed 8 months ago

ranellout commented 11 months ago

Hello! First of all, I want to thank you for the awesome package; I've had some really nice results with the TabNetClassifier!

I wanted to ask: is there a way to interpret the TabNetClassifier feature importance with SHAP values or some other nice visualization package? I was able to visualize the feature importance, and it's nice, but it doesn't tell the direction of each feature's contribution to the prediction. The local explainability with the masks didn't work for me at all. Thanks!!

The code I've tried for SHAP visualization:

```python
background_adult = shap.maskers.Independent(X_valid, max_samples=100)
explainer = shap.Explainer(clf.predict_proba, background_adult)
shap_values = explainer(X_valid[:100])
shap.plots.beeswarm(shap_values)
```

The above code raised the following error: "The passed model is not callable and cannot be analyzed directly with the given masker!"

Other SHAP visualizations didn't work either.

I'm not looking for a SHAP implementation specifically, just some sort of visualization of the feature importance.

Thank you!

Optimox commented 11 months ago

what happens if you change the explainer to this: `explainer = shap.Explainer(clf, background_adult)`?

ranellout commented 11 months ago

Unfortunately, it still raises the same error.

mahimairaja commented 11 months ago

Does the issue still persist?

Optimox commented 11 months ago

As it's independent of this repo, it will probably persist for quite a while!