ModelOriented / DALEX

moDel Agnostic Language for Exploration and eXplanation
https://dalex.drwhy.ai
GNU General Public License v3.0

Does DALEX work with PyTorch models? #565

Closed. yanghuikang closed this issue 5 months ago.

yanghuikang commented 5 months ago

I've been using DALEX to create ALE profiles for XGBoost models and recently wanted to test neural networks. I see that TensorFlow is supported, and I wonder whether there is a way to use DALEX with PyTorch.

hbaniecki commented 5 months ago

Hi, dalex will work with any model for which you can pass a predict_function to dx.Explainer. This function takes the model and the data (as a pd.DataFrame) and returns a 1-dimensional np.ndarray with the model's predictions. Examples of such predict functions: https://github.com/ModelOriented/DALEX/blob/master/python/dalex/dalex/_explainer/yhat.py
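
A minimal sketch of that contract, using only numpy and pandas. `ToyLinearModel`, `yhat_toy` and `check_predict_function` are hypothetical names for illustration and are not part of dalex; the point is only the shape of the inputs and output: (model, pd.DataFrame) in, 1-D np.ndarray with one value per row out.

```python
import numpy as np
import pandas as pd


class ToyLinearModel:
    """Hypothetical stand-in for any model object (not part of dalex)."""

    def __init__(self, weights, bias=0.0):
        self.weights = np.asarray(weights, dtype=float)
        self.bias = float(bias)

    def score(self, x):
        # x: 2-D array of shape (n_rows, n_features)
        return x @ self.weights + self.bias


def yhat_toy(m, d):
    """predict_function contract: (model, pd.DataFrame) -> 1-D np.ndarray."""
    x = d.to_numpy(dtype=float)                 # DataFrame -> 2-D float array
    return np.asarray(m.score(x)).reshape(-1)   # flatten to shape (n_rows,)


def check_predict_function(predict_function, model, data):
    """Hypothetical helper: check the output has the shape described above."""
    y_hat = predict_function(model, data)
    assert isinstance(y_hat, np.ndarray), "must return an np.ndarray"
    assert y_hat.ndim == 1, "must be 1-dimensional"
    assert y_hat.shape[0] == len(data), "one prediction per row"
    return y_hat


X = pd.DataFrame({"a": [1.0, 2.0, 3.0], "b": [0.0, 1.0, 0.0]})
preds = check_predict_function(yhat_toy, ToyLinearModel([2.0, -1.0], bias=0.5), X)
print(preds)  # → [2.5 3.5 6.5]
```

Any function with this shape, whatever framework the model comes from, is what gets passed as `predict_function`.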

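For TensorFlow (Keras) models, for example, they look like this:

```python
import numpy as np

def yhat_tf_regression(m, d):
    # Keras predict returns shape (n, 1); flatten to a 1-D array of length n
    return m.predict(np.array(d), verbose=0).reshape(-1, )

def yhat_tf_classification(m, d):
    # Assumes a 2-column probability output; column 1 is P(positive class)
    return m.predict(np.array(d), verbose=0)[:, 1]
```
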
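
A PyTorch analogue of the two functions above might look like the sketch below. It is not part of dalex, and it assumes: the model is a `torch.nn.Module` on the CPU that accepts a float32 tensor of shape (n, p); the regression model outputs shape (n, 1) or (n,); and the classifier outputs raw logits of shape (n, 2). The function names `yhat_torch_regression` and `yhat_torch_classification` are hypothetical.

```python
import numpy as np
import pandas as pd
import torch


def yhat_torch_regression(m, d):
    """Assumes m is a CPU torch.nn.Module mapping (n, p) float32 -> (n, 1) or (n,)."""
    m.eval()                                     # inference mode for dropout / batch norm
    x = torch.as_tensor(np.asarray(d, dtype=np.float32))
    with torch.no_grad():                        # no autograd graph needed for predictions
        out = m(x)
    return out.cpu().numpy().reshape(-1)         # 1-D array of length n


def yhat_torch_classification(m, d):
    """Assumes m outputs raw logits of shape (n, 2); returns P(class 1)."""
    m.eval()
    x = torch.as_tensor(np.asarray(d, dtype=np.float32))
    with torch.no_grad():
        logits = m(x)
    probs = torch.softmax(logits, dim=1)         # logits -> class probabilities
    return probs[:, 1].cpu().numpy()             # 1-D array of length n


# Usage with tiny untrained models, only to check the output shapes
torch.manual_seed(0)
X = pd.DataFrame({"a": [1.0, 2.0, 3.0], "b": [0.5, -1.0, 0.0]})
reg = torch.nn.Linear(2, 1)
clf = torch.nn.Linear(2, 2)
print(yhat_torch_regression(reg, X).shape)       # → (3,)
print(yhat_torch_classification(clf, X).shape)   # → (3,)
```

Either function can then be passed as `dx.Explainer(model, X, y, predict_function=yhat_torch_regression)`, after which ALE profiles should be available through `explainer.model_profile(type="accumulated")`, as with the xgboost models. For a binary classifier with a single sigmoid output, `torch.sigmoid(logits).reshape(-1)` would replace the softmax step; a model on a GPU would also need the input tensor moved to its device.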
yanghuikang commented 5 months ago

Thank you very much! I will test it with PyTorch.

hbaniecki commented 5 months ago

Reopen if needed.