marcotcr / lime

Lime: Explaining the predictions of any machine learning classifier
BSD 2-Clause "Simplified" License

LimeImageExplainer, binary classification, KeyError: 'Label not in explanation' #634

Closed: Enantiodromis closed this issue 3 years ago

Enantiodromis commented 3 years ago

Hey!

I have been applying Lime to a number of models and have found it extremely insightful, thank you!

Recently I encountered an error with a binary image classifier to which I am trying to apply Lime's LimeImageExplainer. I am not sure where I am going wrong; the error (KeyError: 'Label not in explanation') arises whenever I use any label other than 0.

DEFINED MODEL

    # Imports assumed for this snippet (tf.keras)
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import Conv2D, Dense, Flatten, MaxPool2D
    from tensorflow.keras.optimizers import RMSprop

    model = Sequential()
    model.add(Conv2D(32, (3, 3), activation='relu', input_shape=(256, 256, 3)))
    model.add(MaxPool2D(2, 2))
    model.add(Conv2D(64, (3, 3), activation='relu'))
    model.add(MaxPool2D(2, 2))
    model.add(Conv2D(128, (3, 3), activation='relu'))
    model.add(MaxPool2D(2, 2))
    model.add(Conv2D(128, (3, 3), activation='relu'))
    model.add(MaxPool2D(2, 2))
    model.add(Flatten())
    model.add(Dense(512, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))  # one output column: a single sigmoid probability

    model.compile(loss='binary_crossentropy',
                  optimizer=RMSprop(learning_rate=1e-06),
                  metrics=['acc'])

    history = model.fit(
        train_generator,
        # len() of a Keras directory iterator is already the number of batches per epoch
        steps_per_epoch=len(train_generator),
        epochs=number_epochs,
        validation_data=test_generator,
        validation_steps=len(test_generator),
    )

TRAIN/TEST/VALID DATASET CREATION

    folder_path_train = 'datasets/image_data/image_data_1/real_vs_fake/train'
    folder_path_test = 'datasets/image_data/image_data_1/real_vs_fake/test'
    folder_path_valid = 'datasets/image_data/image_data_1/real_vs_fake/valid'

    train_generator = binary_dataset_creation(32, 256, 256, False, False, file_path=folder_path_train)
    test_generator = binary_dataset_creation(32, 256, 256, False, False, file_path=folder_path_test)
    valid_generator = binary_dataset_creation(32, 256, 256, False, False, file_path=folder_path_valid)

    X_test, y_test = next(test_generator)

GENERATOR LOGIC

    generator = data_generator.flow_from_directory(
        directory=file_path,
        batch_size=batch_size,
        target_size=(img_height, img_width),
        class_mode='binary',
        shuffle=True)
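For reference: according to the Keras documentation for flow_from_directory, when classes are not passed explicitly they are inferred from the subdirectory names, and the label indices follow alphanumeric order. The sketch below re-implements only that rule in plain Python (it is not the Keras code itself), assuming the class subdirectories are named Cat and Dog as the printed LABELS suggest. With class_mode='binary' the targets are 0.0 and 1.0, so a sigmoid output trained with binary_crossentropy estimates the probability of index 1 (Dog here). `infer_class_indices` is a hypothetical name used only for illustration.

```python
def infer_class_indices(subdir_names):
    """Plain-Python sketch of the documented rule: classes are the
    subdirectory names in alphanumeric order, numbered from 0."""
    classes = sorted(subdir_names)
    return dict(zip(classes, range(len(classes))))


# Hypothetical folder listing; the order on disk does not matter.
class_indices = infer_class_indices(["Dog", "Cat"])
index_to_name = {idx: name for name, idx in class_indices.items()}

print(class_indices)     # {'Cat': 0, 'Dog': 1}
print(index_to_name[1])  # Dog: the class a sigmoid output near 1.0 stands for
```

Because a dict keeps insertion order, `list(generator.class_indices)` yields the names in index order, which is what the explainer code relies on when it indexes `labels[class_pred]`.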

LIME IMAGE EXPLAINER FUNCTION


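        # Imports assumed by this snippet (inc_net as in the LIME image tutorials)
        import random
        import numpy as np
        import matplotlib.pyplot as plt
        from lime import lime_image
        from skimage.segmentation import mark_boundaries
        from tensorflow.keras.applications import inception_v3 as inc_net

        X_test_processed = [inc_net.preprocess_input(img) for img in X_test]
        y_test_processed = [label.astype(np.uint8) for label in y_test]

        def predict_fn(x):
            # For this model the output has shape (N, 1): a single sigmoid column
            return model.predict_proba(x)

        # Create explainer
        explainer = lime_image.LimeImageExplainer(verbose=False)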
        random_indexes = random.sample(range(1,len(X_test)),3)

        for index in random_indexes:
            # Set up the explainer
            explanation = explainer.explain_instance(X_test[index].astype(np.float64), predict_fn, top_labels=2, hide_color=0, num_samples=1000)
            ati(X_test[index])

            labels = list(generator.class_indices)
            preds = model.predict(np.expand_dims(X_test_processed[index],axis=0))
            class_pred = model.predict_classes(np.expand_dims(X_test_processed[index],axis=0))[0][0]
            pct = np.max(preds, axis=-1)[0]

            print("LABELS ", labels)
            print("CLASS PRED ", class_pred)
            print("PREDICTION", preds)
            print(y_test_processed[index])
            print("TOP LABELS: ", explanation.top_labels)
            print("LOCAL PRED: ", explanation.local_pred)
            print("PCT:", pct)

            temp, mask = explanation.get_image_and_mask(0, positive_only=True , num_features=5, hide_rest=False)
            fig, (ax1, ax2, ax3) = plt.subplots(1,3, figsize = (10, 4))

            fig.suptitle('Classifier result: %r%% certainty of %r'%(round(pct,1)*100, labels[class_pred]))
            fig.tight_layout(h_pad=2)

            ax1.imshow(X_test_processed[index])
            ax1.set_title('Original Image')
            ax1.axis('off')

            ax2.imshow(mark_boundaries(temp, mask))
            ax2.set_title('Positive Regions for {}'.format(labels[class_pred]))
            ax2.axis('off')

            temp, mask = explanation.get_image_and_mask(0, positive_only=False, num_features=10, hide_rest=False)
            ax3.imshow(mark_boundaries(temp, mask))
            ax3.set_title('Positive & Negative Regions for {}'.format(labels[class_pred]))
            ax3.axis('off')
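A likely cause, as far as I can tell from lime_image's source: explain_instance chooses `top_labels` by argsort-ing the columns of the classifier output for the original image, and get_image_and_mask raises KeyError('Label not in explanation') for any label it did not explain. Because the model ends in Dense(1, activation='sigmoid'), predict_fn returns a single column, so only label 0 exists regardless of `top_labels`, which matches the TOP LABELS: [0] line in the output. The sketch below mimics that selection step with NumPy and shows a hypothetical helper that stacks [1 - p, p] so that labels 0 and 1 line up with Cat and Dog. `lime_top_labels` and `to_two_columns` are illustrative names, not part of LIME.

```python
import numpy as np


def lime_top_labels(probs, top_labels):
    """Mimic LIME's top-label choice: argsort the columns of the prediction
    for the original image, keep the last `top_labels`, highest first."""
    top = [int(i) for i in np.argsort(probs[0])[-top_labels:]]
    return top[::-1]


def to_two_columns(sigmoid_probs):
    """Hypothetical helper: (N, 1) array of P(class 1) -> (N, 2) array
    [P(class 0), P(class 1)], one column per class."""
    p = np.asarray(sigmoid_probs, dtype=np.float64).reshape(-1, 1)
    return np.hstack([1.0 - p, p])


# A single sigmoid column, like the output of Dense(1, activation='sigmoid').
single = np.array([[0.39451957]])
print(lime_top_labels(single, top_labels=2))  # [0]: label 1 never exists

two = to_two_columns(single)
print(lime_top_labels(two, top_labels=2))     # [0, 1]

# Possible predict_fn for the explainer (assumes `model` from the issue):
# def predict_fn(x):
#     return to_two_columns(model.predict(x))
```

With two columns, asking get_image_and_mask for label 1 should no longer hit the KeyError, and label 0 then refers to Cat rather than to the raw sigmoid output.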

I thought I would also add the output of the print statements in the logic above, in case it gives some insight.

    LABELS  ['Cat', 'Dog']
    CLASS PRED  0
    PREDICTION [[0.39451957]]
    0
    TOP LABELS:  [0]
    LOCAL PRED:  [0.40754287]
    PCT: 0.39451957
    Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
    Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
    Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
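A worked check of these printed values, assuming (as the LABELS line suggests) that index 0 is Cat and index 1 is Dog, so the single sigmoid output is P(Dog). This is not part of the original code. predict_classes thresholds a one-column sigmoid output at 0.5, so CLASS PRED 0 (Cat) is consistent; however, PCT is np.max over one column and therefore reports P(Dog) rather than the confidence in Cat. By the same reasoning, LIME's "label 0" (LOCAL PRED 0.4075) is presumably a local surrogate for that same P(Dog) column, not for Cat.

```python
# Values copied from the printed output above.
p_dog = 0.39451957      # PREDICTION: the single sigmoid output, assumed P(Dog)
p_cat = 1.0 - p_dog     # probability of class index 0 (Cat)

# predict_classes on a one-column sigmoid output thresholds at 0.5.
predicted_index = int(p_dog > 0.5)
labels = ['Cat', 'Dog']

# Confidence in the predicted class vs. what np.max over one column reports.
certainty = p_dog if predicted_index == 1 else p_cat
reported_pct = p_dog

print(labels[predicted_index])        # Cat
print(round(certainty * 100, 1))      # 60.5
print(round(reported_pct * 100, 1))   # 39.5
```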

Forgive the bombardment above. To my knowledge I have looked through all the related issues and tried the solutions discussed there, e.g. https://github.com/marcotcr/lime/issues/165. I have spent a good two days trying to resolve this and am turning here as a last resort for any possible help.

Any help would be greatly appreciated!