marcotcr / lime

Lime: Explaining the predictions of any machine learning classifier
BSD 2-Clause "Simplified" License

Trying to explain a simple Neural Network using LIME #733

Open omaruno opened 7 months ago


I am trying to explain the following neural network, trained on toy data:

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from lime.lime_tabular import LimeTabularExplainer

X = np.array([[(1,2,3,3,1),(3,2,1,3,2),(3,2,2,3,3),(2,2,1,1,2),(2,1,1,1,1)],
              [(4,5,6,4,4),(5,6,4,3,2),(5,5,6,1,3),(3,3,3,2,2),(2,3,3,2,1)],
              [(7,8,9,4,7),(7,7,6,7,8),(5,8,7,8,8),(6,7,6,7,8),(5,7,6,6,6)],
              [(7,8,9,8,6),(6,6,7,8,6),(8,7,8,8,8),(8,6,7,8,7),(8,6,7,8,8)],
              [(4,5,6,5,5),(5,5,5,6,4),(6,5,5,5,6),(4,4,3,3,3),(5,5,4,4,5)],
              [(4,5,6,5,5),(5,5,5,6,4),(6,5,5,5,6),(4,4,3,3,3),(5,5,4,4,5)],
              [(1,2,3,3,1),(3,2,1,3,2),(3,2,2,3,3),(2,2,1,1,2),(2,1,1,1,1)]])
y = np.array([0, 1, 2, 2, 1, 1, 0])

model = keras.Sequential([
    layers.LSTM(64, return_sequences=True, input_shape=(5, 5)),
    layers.Conv1D(64, kernel_size=3, activation='relu'),
    layers.MaxPooling1D(pool_size=2),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dense(3, activation='softmax')  # Adjust the number of output units based on your problem (3 for 3 classes)
])

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(X, y, epochs=10)
```
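The model above takes input of shape `(n_samples, 5, 5)` (5 timesteps by 5 features), whereas `LimeTabularExplainer` works on 2D `(n_samples, n_features)` arrays. A minimal numpy sketch of bridging the two, assuming row-major flattening: flatten the training data to 25 columns, give each column its own name, and wrap the predict function so LIME's 2D samples are reshaped back to 3D before they reach the model. The helper names (`flatten_for_lime`, `flat_feature_names`, `make_flat_predict_fn`) and the `feature_j_tT` naming scheme are hypothetical, not part of LIME.

```python
import numpy as np


def flatten_for_lime(data_3d):
    """Reshape (n_samples, n_timesteps, n_features) to the 2D
    (n_samples, n_timesteps * n_features) layout LimeTabularExplainer uses."""
    n_samples, n_timesteps, n_features = data_3d.shape
    return data_3d.reshape(n_samples, n_timesteps * n_features)


def flat_feature_names(base_names, n_timesteps):
    """Return one name per flattened column (hypothetical naming scheme).
    Row-major flattening puts feature j of timestep t at column
    t * n_features + j, so timestep is the outer loop, feature the inner."""
    return [f"{name}_t{t}" for t in range(n_timesteps) for name in base_names]


def make_flat_predict_fn(predict_fn, n_timesteps, n_features):
    """Wrap a model that expects 3D input so it accepts LIME's 2D samples."""
    def flat_predict(rows_2d):
        # Undo the row-major flattening before handing the batch to the model.
        rows_3d = np.asarray(rows_2d).reshape(-1, n_timesteps, n_features)
        return predict_fn(rows_3d)
    return flat_predict
```

With these helpers, the issue's call would become something like `LimeTabularExplainer(flatten_for_lime(X), mode="classification", feature_names=flat_feature_names([f"feature_{i}" for i in range(5)], 5))` followed by `explain_instance(X[0].flatten(), make_flat_predict_fn(model.predict, 5, 5), num_features=5)`. Because `X[0].flatten()` is also row-major, the instance lines up with the flattened training columns and their names.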

I want to use LIME to identify the most relevant features, and I am using the following code:

```python
# Create a LimeTabularExplainer
explainer = LimeTabularExplainer(training_data=X, mode="classification", feature_names=[f"feature_{i}" for i in range(5)])

# Choose a specific instance for which you want to explain the prediction
instance_to_explain = X[0]

# Generate an explanation for the instance
explanation = explainer.explain_instance(instance_to_explain.flatten(), model.predict, num_features=5)

# Print the explanation
print(explanation.as_list())
```

But it keeps giving me an error that I don't understand.