shap / shap

A game theoretic approach to explain the output of any machine learning model.
https://shap.readthedocs.io
MIT License

SHAP not working with a BERT fine-tuned multi-class text classification model #2629

Open loni9164 opened 2 years ago

loni9164 commented 2 years ago

Hi, I am trying to run the SHAP explainer over a BERT fine-tuned multi-class text classifier to debug the model, and it is throwing an error. Could you please look into the notebook https://github.com/loni9164/Text-explainer-issues/blob/main/Bert-shap.ipynb and help me fix the issue?
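For reference, a minimal sketch of the pipeline-based pattern the SHAP documentation shows for multi-class text classification; the model name below is a public emotion classifier used as a stand-in, not the model from the linked notebook:

import shap
import transformers

# a public fine-tuned BERT classifier stands in for the notebook's model (assumption)
pipe = transformers.pipeline(
    "text-classification",
    model="nateraw/bert-base-uncased-emotion",
    return_all_scores=True,  # return one probability per class
)
explainer = shap.Explainer(pipe)
shap_values = explainer(["this is a lovely movie"])
shap.plots.text(shap_values)

Passing the pipeline directly lets SHAP handle tokenization and output shapes itself, which avoids the row-count mismatch discussed below.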

Achinth04 commented 1 year ago

> Hi, I am trying to run the SHAP explainer over a BERT fine-tuned multi-class text classifier to debug the model, and it is throwing an error. Could you please look into the notebook https://github.com/loni9164/Text-explainer-issues/blob/main/Bert-shap.ipynb and help me fix the issue?

I am facing the same issue as you. I am using a LayoutLM model for token classification, which is architecturally quite similar to BERT. My code also fails with 'AssertionError: The model produced 52 output rows when given 1 input rows! Check the implementation of the model you provided for errors.' This is my code; `model` is a LayoutLM token classification model and `tokenizer` is the matching LayoutLM tokenizer.

import numpy as np
import shap
import torch.nn.functional as F

# `model` is a LayoutLM token-classification model and `tokenizer` its
# matching LayoutLM tokenizer, both loaded elsewhere

def predictor(x):
    outputs = model(**tokenizer(x, return_tensors="pt", padding=True))
    # for token classification the logits have shape (batch, seq_len, num_labels)
    probas = F.softmax(outputs.logits, dim=1).detach().numpy()
    print(probas)
    return probas

def f_batch(x):
    val = np.array([])
    for i in x:
        val = np.append(val, predictor(i))  # flattens all per-token scores into one 1-D array
    return val

s = ['i am']
explainer_bert = shap.Explainer(f_batch, tokenizer)
shap_values = explainer_bert(s)

These are the inputs and outputs that trigger the error:

inputs: (array(['[MASK]'], dtype='<U6'),)
outputs: [0.30509463 0.38688242 0.35780382 0.34003907 0.25436756 0.35160705
 0.23038684 0.31761795 0.22902177 0.2792168  0.33950776 0.28346184
 0.37673852 0.38966191 0.22607958 0.28429452 0.32000664 0.49130982
 0.29692563 0.53920555 0.36467054 0.54198849 0.4416059  0.32091084
 0.43298629 0.24658103 0.30524346 0.38703796 0.3579016  0.33995429
 0.25432256 0.35146728 0.23040767 0.31771147 0.22898977 0.27917728
 0.3395814  0.2835519  0.37668049]

Traceback (most recent call last):
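The assertion fires because SHAP requires the wrapped function to return exactly one output row per input row, while `f_batch` flattens the per-token probabilities of a token-classification model into a single 1-D array (52 values for 1 input). A minimal sketch of a wrapper that satisfies that shape contract, assuming `model` and `tokenizer` are the LayoutLM token-classification model and tokenizer from above; mean-pooling over the token axis is just one possible aggregation:

import numpy as np
import shap
import torch
import torch.nn.functional as F

def predict_one_row_per_input(texts):
    # SHAP's check: len(texts) input rows must yield len(texts) output rows
    enc = tokenizer(list(texts), return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        logits = model(**enc).logits    # (batch, seq_len, num_labels)
    probas = F.softmax(logits, dim=-1)  # softmax over the label axis
    # collapse the token axis so each input string maps to exactly one row;
    # mean-pooling is an assumption here, pick an aggregation that fits your task
    return probas.mean(dim=1).numpy()   # (batch, num_labels)

explainer = shap.Explainer(predict_one_row_per_input, tokenizer)
shap_values = explainer(['i am'])

If you need per-token explanations rather than per-input ones, this pooling loses that detail, but it at least makes the function's output shape consistent with what shap.Explainer expects.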