marcotcr / lime

Lime: Explaining the predictions of any machine learning classifier
BSD 2-Clause "Simplified" License

Explanation of Input-Text in character-CNN implemented in Keras #162

Closed: Goschjann closed this issue 6 years ago

Goschjann commented 6 years ago

Dear Lime-Team,

For a university project, I am working on text classification with the famous character-level CNN (charCNN), which I implemented in Keras. I classify restaurant reviews as having positive or negative sentiment and built a small app that accepts any user-written review. This currently yields reasonable results.
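For readers unfamiliar with the charCNN input format: the model never sees words, only a fixed-length sequence of one-hot character vectors. The sketch below is an assumed, simplified version of what a helper like `pl.generate_one_hot` (from the poster's own library, whose code is not shown) might do. The function name `one_hot_chars` and the `(max_chars, alphabet)` layout are assumptions; the 1014 default mirrors `maxChars` in the question's code.

```python
import numpy as np


def one_hot_chars(text, alphabet, max_chars=1014):
    """Hypothetical character quantization for a char-CNN input layer.

    Returns a float32 array of shape (max_chars, len(alphabet)): row i is the
    one-hot code of the i-th character of `text`. Characters missing from the
    alphabet become all-zero rows; text longer than max_chars is truncated and
    shorter text is zero-padded at the end.
    """
    index = {ch: i for i, ch in enumerate(alphabet)}
    encoded = np.zeros((max_chars, len(alphabet)), dtype=np.float32)
    # Assumes a lowercase alphabet, so the input is lowercased first.
    for pos, ch in enumerate(text.lower()[:max_chars]):
        j = index.get(ch)
        if j is not None:
            encoded[pos, j] = 1.0
    return encoded


x = one_hot_chars("I did not like the food!", "abcdefghijklmnopqrstuvwxyz !", max_chars=32)
print(x.shape)  # (32, 28)
```

Because every character occupies a fixed slot, deleting or replacing a word shifts all later characters, which matters when choosing LIME's perturbation mode for such a model.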

Now I would like to understand why the charCNN makes the predictions it does - a perfect use case for LIME, imho.

To do this, I use the following Python 3 code:

import lime.lime_text as lt
import keras
# my self written library
import projectlib as pl
import numpy as np

# This is the text that I want to explain
inputText = "I did not like the food and the drinks were expensive!"

# read alphabet for character-vectorization of the input text
alphabetPath = "/home/jgucci/Desktop/uni/text_mining/tm_data/alphabet.txt"
with open(alphabetPath) as f:
    alphabet = f.read()
maxChars = 1014

# encode the input text as one-hot character vectors
recodeText = pl.generate_one_hot(text=inputText, alphabet=alphabet, maxChars=maxChars)

model = keras.models.load_model("charCnn_6_polarity.h5")

# pipeline-like function
# takes raw text as input, converts it to one-hot character vectors via the alphabet
# feeds it into keras' model.predict() function to receive predictions
# works for lists of text as well as for single strings
def predictFromText(textInputList):

    # catch single string inputs and convert them to a list
    if isinstance(textInputList, str):
        textInputList = [textInputList]
        print("caught single string")
    # list for predictions
    predStorage = []
    # loop through input list and predict
    for textInput in textInputList:

        recodeText = pl.generate_one_hot(text=textInput, alphabet=alphabet, maxChars=maxChars)
        pred = model.predict(recodeText.transpose())
        # control output of function
        print(str(textInput), "\n", pred)
        predStorage.append(pred)

    return(np.asarray(predStorage))

# this works, yields an array with probabilities for both classes
listTexts = [inputText]  # any list of raw review strings
print(predictFromText(textInputList=listTexts))
print(predictFromText(textInputList=inputText))

# Lime Explainer
# bow controls whether perturbed words are removed entirely (True)
# or overwritten in place with UNKWORDZ (False);
# False makes sense if the location of words matters, as in this classifier
explainer = lt.LimeTextExplainer(kernel_width=25, verbose=True, class_names=[0, 1],
                                 feature_selection="auto", split_expression=" ", bow=False)

exp = explainer.explain_instance(text_instance=inputText, labels=[0, 1],
                                 classifier_fn=predictFromText, num_features=5, num_samples=100)
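To make the `bow` flag concrete, here is a hand-rolled illustration (not LIME's own code) of the two ways a perturbed sample can be built from the review, using plain space splitting as with `split_expression=" "`. The helper names `drop_words` and `mask_positions` are hypothetical; LIME's real implementation tracks token indices and separators internally.

```python
MASK = "UNKWORDZ"  # default placeholder LIME uses for removed words when bow=False


def drop_words(text, words, split=" "):
    """bow=True style: features are distinct words; removing one deletes every occurrence."""
    return split.join(t for t in text.split(split) if t not in words)


def mask_positions(text, positions, split=" "):
    """bow=False style: features are word positions; a removed position is replaced
    by MASK, so every other word stays at its original word index."""
    return split.join(MASK if i in positions else t
                      for i, t in enumerate(text.split(split)))


text = "I did not like the food and the drinks were expensive!"
print(drop_words(text, {"the"}))
# I did not like food and drinks were expensive!
print(mask_positions(text, {4}))
# I did not like UNKWORDZ food and the drinks were expensive!
```

Masking keeps word positions, but for a character-level model the character offsets still shift whenever the mask string has a different length than the word it replaces.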

Running explain_instance() with verbose=True prints the following output from explain_instance_with_data():

Intercept [ 0.31570154  0.68429844]
Prediction_local [[-0.0186873  1.0186873]]
Right: [ 0.00305071  0.99694926]

followed by this error:

Traceback (most recent call last):
  File "/usr/local/lib/python3.5/dist-packages/IPython/core/interactiveshell.py", line 2862, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-94-c405d2d113a4>", line 2, in <module>
    classifier_fn=predictFromText, num_features=5, num_samples=100)
  File "/usr/local/lib/python3.5/dist-packages/lime/lime_text.py", line 281, in explain_instance
    feature_selection=self.feature_selection)
  File "/usr/local/lib/python3.5/dist-packages/lime/lime_base.py", line 177, in explain_instance_with_data
    key=lambda x: np.abs(x[1]), reverse=True),
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
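Reading the verbose output next to the traceback suggests where this comes from (an interpretation, not confirmed in the thread): Intercept, Prediction_local and Right each hold two values for a single label, which is what happens when the regression target LIME slices from the classifier output is 2-D instead of 1-D. The snippet below reproduces that mechanism with NumPy alone. `fake_predict` is a hypothetical stand-in for `model.predict()` on one text, and the 2×3 coefficient matrix is made-up data illustrating what a linear model fitted on a 2-D target returns (one row per target).

```python
import numpy as np


def fake_predict(_text):
    """Stand-in for model.predict() on one encoded text: shape (1, n_classes)."""
    return np.array([[0.3, 0.7]])


preds = np.asarray([fake_predict(t) for t in ["a", "b", "c"]])
print(preds.shape)            # (3, 1, 2): one axis too many

labels_column = preds[:, 0]   # the column slice LIME takes for the first label (0)...
print(labels_column.shape)    # (3, 2): ...is 2-D instead of the expected (3,)

# With a 2-D target the fitted coefficients are 2-D too, so sorting
# (feature, coefficient) pairs by np.abs(coefficient) compares whole arrays.
coef = np.array([[0.1, -0.3, 0.2], [0.5, 0.0, -0.4]])
try:
    sorted(zip([0, 1, 2], coef), key=lambda x: np.abs(x[1]), reverse=True)
    message = "no error"
except ValueError as err:
    message = str(err)
print(message)  # The truth value of an array with more than one element is ambiguous...
```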

Is it a problem...

  1. ... that I use a self-made pipeline (predictFromText()) instead of sklearn's make_pipeline() function?
  2. ... that my model embeds the text at the character level rather than the word level?

I would greatly appreciate any help or hints!
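On both questions: LIME's text explainer documents `classifier_fn` as a function that takes a list of d strings and returns a (d, k) NumPy array of prediction probabilities, so any Python callable with that contract works; it does not have to be a scikit-learn pipeline. LIME perturbs the raw text and passes the perturbed strings to the callable, so character-level encoding can happen inside it. The sketch below uses a toy, hypothetical character-based scorer (`toy_char_classifier_fn`, not a real sentiment model) and a hypothetical `check_classifier_fn` helper that fails early if the output shape is wrong.

```python
import numpy as np


def toy_char_classifier_fn(texts):
    """Toy character-level scorer standing in for the charCNN pipeline.

    Looks only at characters (the share of '!' among non-space characters) and
    returns one row [P(class 0), P(class 1)] per input string.
    """
    rows = []
    for text in texts:
        chars = [c for c in text if not c.isspace()]
        p0 = min(1.0, 10.0 * chars.count("!") / max(len(chars), 1))  # arbitrary toy rule
        rows.append([p0, 1.0 - p0])
    return np.array(rows)


def check_classifier_fn(fn, texts, n_classes):
    """Raise if fn does not return the (len(texts), n_classes) ndarray LIME expects."""
    out = fn(list(texts))
    if not isinstance(out, np.ndarray) or out.shape != (len(texts), n_classes):
        raise ValueError("expected ndarray of shape {}, got {} with shape {}".format(
            (len(texts), n_classes), type(out).__name__, getattr(out, "shape", None)))
    return out


probs = check_classifier_fn(toy_char_classifier_fn,
                            ["I did not like the food!", "great pizza"], n_classes=2)
print(probs.shape)  # (2, 2)
```

Running such a check on a couple of strings before calling explain_instance() surfaces shape problems with a clearer message than the error raised deep inside LIME.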

Goschjann commented 6 years ago

Found the error: my pipeline function did not output a proper ndarray. Everything works fine now - awesome package!
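The thread does not show the fix. Assuming each `model.predict()` call in `predictFromText` returns shape (1, 2), `np.asarray(predStorage)` has shape (n, 1, 2), while LIME expects (n, 2). One way to get the expected shape is to flatten each prediction and stack the rows. `make_classifier_fn` and `predict_one` are hypothetical names used only for this sketch.

```python
import numpy as np


def make_classifier_fn(predict_one):
    """Wrap a per-text predictor into a LIME-compatible classifier_fn.

    predict_one is a hypothetical stand-in for the per-text step in the question,
    i.e. model.predict(pl.generate_one_hot(...).transpose()); it is assumed to
    return an array of shape (1, n_classes) for a single string.
    """
    def classifier_fn(texts):
        if isinstance(texts, str):  # also accept a single string
            texts = [texts]
        # Flatten each (1, n_classes) prediction to (n_classes,) and stack the rows,
        # giving (len(texts), n_classes) instead of (len(texts), 1, n_classes).
        return np.vstack([np.asarray(predict_one(t)).reshape(-1) for t in texts])
    return classifier_fn


# With the objects from the question this would be used roughly as:
# classifier_fn = make_classifier_fn(lambda t: model.predict(
#     pl.generate_one_hot(text=t, alphabet=alphabet, maxChars=maxChars).transpose()))
# exp = explainer.explain_instance(inputText, classifier_fn, labels=[0, 1], num_samples=100)
fn = make_classifier_fn(lambda t: np.array([[0.25, 0.75]]))
print(fn(["a", "b", "c"]).shape)  # (3, 2)
```

Encoding all perturbed texts and calling `model.predict()` once on the whole batch would likely be faster, but that depends on the array layout produced by `generate_one_hot`, which the thread does not show.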

marcotcr commented 6 years ago

: )