sicara / tf-explain

Interpretability Methods for tf.keras models with Tensorflow 2.x
https://tf-explain.readthedocs.io
MIT License

Occlusion sensitivity doesn't give output #135

Open rao208 opened 4 years ago

rao208 commented 4 years ago

Hello folks,

It is not really a bug in tf_explain, but I am reaching out to the tf_explain community here because I don't know who else to ask.

I am working on the Synthetic MNIST dataset with image size (64, 64, 3). The images were downloaded from Kaggle. All images (train, test and validation) were brightened and then sharpened before being normalized (i.e. divided by 255).

[image: original images]

[image: final output after sharpening]

Since the dataset doesn't follow a Gaussian distribution (I used np.histogram to inspect the pixel distribution), I avoided standardization, i.e. subtracting the mean and dividing by the standard deviation.
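For reference, this is roughly how I checked the distribution (a minimal sketch; train_x is the image array from my script below):

import numpy as np
import matplotlib.pyplot as plt

# Histogram of raw pixel intensities, to see whether they look Gaussian
counts, bin_edges = np.histogram(train_x.ravel(), bins=50)
plt.bar(bin_edges[:-1], counts, width=np.diff(bin_edges))
plt.xlabel('Pixel intensity')
plt.ylabel('Count')
plt.show()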

My CNN looks like this:

[image: Sequential model summary]

and the results are:

[image: confusion matrix and test accuracy]

[image: accuracy/loss curves]

However, when I apply Occlusion Sensitivity on the data with patch size 20, I am not getting the expected output. What I mean is: when I apply Occlusion Sensitivity to, say, 3 samples from class 0, the heatmaps are in the same location.

import cv2
from tf_explain.core.occlusion_sensitivity import OcclusionSensitivity

test_data = ([x_test[sampleid]], None)

# Instantiation of the explainer
explainer = OcclusionSensitivity()

# Call to the explain() method
output = explainer.explain(test_data, model, class_index=classidx, patch_size=20, colormap=cv2.COLORMAP_JET)

[images: occlusion heatmaps for three class-0 samples (class_0_1, class_0_2, class_0_3)]

This is true for all the classes. Does this mean that my CNN is not learning anything? That cannot be right, because when I apply GradCAM I get a sensible output, i.e. a different heatmap location for different images from the same class. Or does this mean that this is the correct output? If so, does it make sense to get heatmaps in the same location for different samples of the same class?

Any help would be appreciated. Please help me, because this is very important for my thesis; I have spent almost a month trying to figure this out.

If you need any further information, let me know.

Best regards.

RaphaelMeudec commented 4 years ago

@rao208 You might want to reduce the patch size: a patch size of 20 means you apply only 3 patches along the x-axis. Among the 9 patches, the bottom-right one might be giving less information, hence the globally red colormap. You might want to use a patch size of 5, or try several patch size values (e.g. [2, 5, 10]) to compare multiple attribution maps.
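For example, something along these lines (a minimal sketch, reusing model, x_test, sampleid and classidx from your snippet above):

import cv2
import matplotlib.pyplot as plt
from tf_explain.core.occlusion_sensitivity import OcclusionSensitivity

explainer = OcclusionSensitivity()
sample = ([x_test[sampleid]], None)

# One attribution map per patch size, displayed side by side for comparison
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
for ax, patch_size in zip(axes, [2, 5, 10]):
    grid = explainer.explain(sample, model, class_index=classidx,
                             patch_size=patch_size, colormap=cv2.COLORMAP_JET)
    ax.imshow(grid)
    ax.set_title(f'patch_size={patch_size}')
    ax.axis('off')
plt.show()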

rao208 commented 4 years ago

@rao208 You might want to reduce the patch size: a patch size of 20 means you apply only 3 patches along the x-axis. Among the 9 patches, the bottom-right one might be giving less information, hence the globally red colormap. You might want to use a patch size of 5, or try several patch size values (e.g. [2, 5, 10]) to compare multiple attribution maps.

Thank you for the quick response @RaphaelMeudec. Even with different patch sizes, the heatmap location stays the same. Here are the results:

Patch size 5

[images: class-0 heatmaps with patch size 5]

Patch Size 10

[images: class-0 heatmaps with patch size 10]

Is there any problem with how test_data is passed to Occlusion Sensitivity? (The code is attached in the question above.)

I just can't figure out what the cause could be. I worked with CIFAR-10 as well and I see a similar pattern there too, i.e. different heatmap locations for different images from the same class.

RaphaelMeudec commented 4 years ago

Could you provide the link to the dataset and the training script? (in particular the preprocessing you apply to the images)

rao208 commented 4 years ago

@RaphaelMeudec The link to the dataset on Kaggle is https://www.kaggle.com/prasunroy/synthetic-digits?

There are two folders, imgs_train and imgs_valid, and each of them contains 10 sub-folders. I first put all the training digits in a 'train' folder and the testing digits in a 'test' folder, and later converted them into .npy files. I tried to compress the folders into a .zip file, but it was too big to upload here. Nevertheless, you can access the .npy files from my drive (https://drive.google.com/drive/u/2/folders/1rjQ0CjaiiNcuHhXJptlvs9u-Fog6LAoW).
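Roughly, the conversion looked like this (a sketch; the 64x64 resize and the paths match my setup, but the helper name is just for illustration):

import os
import numpy as np
from PIL import Image

def folder_to_npy(root):
    # root contains one sub-folder per digit (0-9)
    images, labels = [], []
    for label in sorted(os.listdir(root)):
        class_dir = os.path.join(root, label)
        for fname in os.listdir(class_dir):
            img = Image.open(os.path.join(class_dir, fname)).convert('RGB').resize((64, 64))
            images.append(np.asarray(img))
            labels.append(int(label))
    return np.array(images), np.array(labels)

train_x, train_y = folder_to_npy('./imgs_train')
np.save('./data/train_x_64.npy', train_x)
np.save('./data/train_y_64.npy', train_y)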

Please let me know if you are unable to open the link or download the files.

The code is:

# -*- coding: utf-8 -*-
"""
Created on Mon May 11 16:31:50 2020

@author: Vanditha Rao
"""

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, BatchNormalization
from sklearn.metrics import classification_report, confusion_matrix
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import optimizers
from sklearn.model_selection import train_test_split
from skimage import exposure
from PIL import Image
from PIL import ImageEnhance
from tensorflow.keras import regularizers
tf.keras.backend.clear_session()

with tf.device('/device:GPU:0'):

    class synthetic_mnist_model:

        def __init__(self):

            self.x_shape = [64,64,3]
            self.batch_size = 64
            self.maxepoches = 50
            self.num_classes = 10
            self.weight_decay = 0.0005
            self.model = self.build_model()

        def plot_confusion_matrix(self, y_true, y_pred, classes, title = None, cmap=plt.cm.Blues):

            """
            This function prints and plots the confusion matrix.
            Normalization can be applied by setting `normalize=True`.
            """

            # Compute confusion matrix

            cm = confusion_matrix(y_true, y_pred)

            # Only use the labels that appear in the data

            fig, ax = plt.subplots(figsize=(10,10))

            im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
            ax.figure.colorbar(im, ax=ax)

            # We want to show all ticks...

            ax.set(xticks=np.arange(cm.shape[1]),
                    yticks=np.arange(cm.shape[0]),
                    # ... and label them with the respective list entries
                    xticklabels=classes, yticklabels=classes,
                    title=title,
                    ylabel='True label',
                    xlabel='Predicted label')

            # Rotate the tick labels and set their alignment.
            plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
                      rotation_mode="anchor")

            # Loop over data dimensions and create text annotations.
            fmt = 'd'     
            thresh = cm.max() / 2.0

            for i in range(cm.shape[0]):
                for j in range(cm.shape[1]):
                    ax.text(j, i, format(cm[i, j], fmt),
                            ha="center", va="center",
                            color="white" if cm[i, j] > thresh else "black")

            ax.set_ylim(len(cm)-0.5, -0.5)

            fig.tight_layout()
            return ax

        def load_data(self):

            test_x = np.load('./data/test_x_64.npy')
            test_y = np.load('./data/test_y_64.npy')

            train_x = np.load('./data/train_x_64.npy')
            train_y = np.load('./data/train_y_64.npy')

            train_x = train_x.astype('float32')
            test_x = test_x.astype('float32')

            return test_x, test_y, train_x, train_y

        def change_brightness(self,img, brightness = 1.2):

            enh_bri = ImageEnhance.Brightness(img)
            image_brightened = enh_bri.enhance(brightness)

            return image_brightened

        def brightness(self, X_train, X_test, X_val):

            bright_train = np.zeros(X_train.shape)
            bright_test = np.zeros(X_test.shape)
            bright_val = np.zeros(X_val.shape)

            for i in range(X_train.shape[0]):
                image = Image.fromarray(X_train[i, :, :, :].astype(np.uint8))
                bright_train[i, :, :, :] = self.change_brightness(image)

            bright_train = bright_train.astype('float32')

            for i in range(X_test.shape[0]):
                image = Image.fromarray(X_test[i, :, :, :].astype(np.uint8))
                bright_test[i, :, :, :] = self.change_brightness(image)

            bright_test = bright_test.astype('float32')

            for i in range(X_val.shape[0]):
                image = Image.fromarray(X_val[i, :, :, :].astype(np.uint8))
                bright_val[i, :, :, :] = self.change_brightness(image)

            bright_val = bright_val.astype('float32')

            return bright_train, bright_test, bright_val

        def change_sharpness(self,img, sharpness = 2.0):
            enh_sha = ImageEnhance.Sharpness(img)
            image_sharped = enh_sha.enhance(sharpness)

            return image_sharped

        def sharpness(self, X_train, X_test, X_val):

            sharp_train = np.zeros(X_train.shape)
            sharp_test = np.zeros(X_test.shape)
            sharp_val = np.zeros(X_val.shape)

            for i in range(X_train.shape[0]):
                image = Image.fromarray(X_train[i, :, :, :].astype(np.uint8))
                sharp_train[i, :, :, :] = self.change_sharpness(image)

            sharp_train = sharp_train.astype('float32')

            for i in range(X_test.shape[0]):
                image = Image.fromarray(X_test[i, :, :, :].astype(np.uint8))
                sharp_test[i, :, :, :] = self.change_sharpness(image)

            sharp_test = sharp_test.astype('float32')

            for i in range(X_val.shape[0]):
                image = Image.fromarray(X_val[i, :, :, :].astype(np.uint8))
                sharp_val[i, :, :, :] = self.change_sharpness(image)

            sharp_val = sharp_val.astype('float32')

            return sharp_train, sharp_test, sharp_val

        def one_hot_encode(self, Y_train, Y_test, Y_val):

            Y_train = tf.keras.utils.to_categorical(Y_train, self.num_classes)
            Y_test = tf.keras.utils.to_categorical(Y_test, self.num_classes)
            Y_val = tf.keras.utils.to_categorical(Y_val, self.num_classes)

            return Y_train, Y_test, Y_val

        def predict(self, x):
            return self.model.predict(x, self.batch_size)

        def evaluate(self,x,y):
            return self.model.evaluate(x,y, self.batch_size, verbose=2)

        def build_model(self):

            model = Sequential()

            weight_decay = 0.001

            model.add(Conv2D(16, kernel_size= (4,4),
                             input_shape = self.x_shape,
                             padding = 'same',
                             activation = "relu",
                             # kernel_initializer='he_normal'
                             kernel_regularizer=regularizers.l2(weight_decay)
                             ))

            model.add(BatchNormalization())

            model.add(Conv2D(16, kernel_size= (4,4),
                             padding = 'same',
                             # activation= tf.nn.leaky_relu,))
                             activation = "relu",
                             # kernel_initializer='he_normal'
                             kernel_regularizer=regularizers.l2(weight_decay)
                             ))

            model.add(BatchNormalization())
            model.add(MaxPooling2D(pool_size=(2,2)))
            # model.add(Dropout(0.2))

            model.add(Conv2D(32, kernel_size=(3,3),
                             padding = 'same',
                             activation = "relu",
                             # kernel_initializer='he_normal'
                             kernel_regularizer=regularizers.l2(weight_decay)
                             ))
            model.add(BatchNormalization())

            model.add(Conv2D(32, kernel_size=(3,3),
                             padding = 'same',
                             activation = "relu",
                             # kernel_initializer='he_normal'))
                             kernel_regularizer=regularizers.l2(weight_decay)))
            model.add(BatchNormalization())
            model.add(MaxPooling2D(pool_size=(2,2)))
            # model.add(Dropout(0.3))

            model.add(Conv2D(64, kernel_size=(3,3),
                              padding = 'same',
                              activation = "relu",
                              kernel_regularizer=regularizers.l2(weight_decay)
                              ))
            model.add(BatchNormalization())

            model.add(Conv2D(64, kernel_size=(3,3),
                              padding = 'same',
                              activation = "relu",
                              kernel_regularizer=regularizers.l2(weight_decay)))

            model.add(BatchNormalization())
            model.add(MaxPooling2D(pool_size=(2,2)))

            model.add(Conv2D(128, kernel_size=(3,3),
                              padding = 'same',
                              activation = "relu",
                              kernel_regularizer=regularizers.l2(weight_decay)
                              ))
            model.add(BatchNormalization())

            model.add(Conv2D(128, kernel_size=(3,3),
                              padding = 'same',
                              activation = "relu",
                              kernel_regularizer=regularizers.l2(weight_decay)))

            model.add(BatchNormalization())
            model.add(MaxPooling2D(pool_size=(2,2)))

            model.add(Flatten())

            model.add(Dense(1152, activation = "relu",
                            # kernel_regularizer=regularizers.l2(weight_decay)
                            ))

            model.add(Dropout(0.5))
            model.add(Dense(1152, activation = "relu",
                            # kernel_regularizer=regularizers.l2(weight_decay)
                            ))

            model.add(Dropout(0.5))

            model.add(Dense(10, activation= "softmax"))
            model.summary()

            return model

        def train(self, train_x, train_y, val_x, val_y):

            batch_size = 64

            datagen = ImageDataGenerator(rotation_range=15,
                                         width_shift_range=0.1,
                                         height_shift_range=0.1,
                                         horizontal_flip=True,
                                         )

            datagen.fit(train_x)

            self.model.compile(loss='categorical_crossentropy',
                               optimizer=optimizers.Adadelta(learning_rate=1.0),
                               metrics=['accuracy'])

            history = self.model.fit(datagen.flow(train_x, train_y, batch_size=batch_size),
                                steps_per_epoch=train_x.shape[0] // batch_size,
                                epochs=self.maxepoches,
                                validation_data=(val_x, val_y),
                                verbose=2)

            return history    

        def save(self):

            model_json = self.model.to_json()

            self.model.save_weights('./model/model_weights_synthetic_mnist_os_bright_sharp_gap.h5')
            self.model.save('./model/model_synthetic_mnist_os_bright_sharp_gap.h5')

            with open('./model/model_synthetic_mnist_os_bright_sharp_gap.json', 'w') as json_file:
                json_file.write(model_json)

        def plot(self, history):

            train_acc = history.history['accuracy']
            validation_acc = history.history['val_accuracy']

            train_loss = history.history['loss']
            validation_loss = history.history['val_loss']

            plt.figure(figsize=(10,10))

            plt.subplot(1,2,1)          
            plt.plot(train_acc,'r',label='Training Accuracy')
            plt.plot(validation_acc,'b',label='Validation Accuracy')

            plt.title('Training and Validation Accuracy')
            plt.legend()

            plt.subplot(1,2,2)

            plt.plot(train_loss,'r',label='Training loss')
            plt.plot(validation_loss,'b',label='Validation loss')

            plt.title('Training and Validation loss')
            plt.legend()
            plt.show()

    if __name__ == '__main__':

        sm = synthetic_mnist_model()

        # load the dataset

        test_x, test_y, train_x, train_y = sm.load_data()

        # split the training set into training and validation

        train_x, val_x, train_y, val_y = train_test_split(train_x, train_y,
                                                          test_size=0.2,
                                                          random_state=1234,
                                                          shuffle = True,
                                                          stratify=train_y
                                                          )
        print(train_x.shape)
        print(test_x.shape)
        print(val_x.shape)

        # Plot the training images

        fig = plt.figure(figsize=(5,4))

        for i in range(3):
            for j in range(3):
                ax = fig.add_subplot(3, 3, i * 3 + j + 1)
                ax.imshow(train_x[i * 3 + j]/255)

        plt.show()

        # brighten the images

        train_x, test_x, val_x = sm.brightness(train_x, test_x, val_x)

        # plot the brightened images

        fig1 = plt.figure(figsize=(5,4))
        for i in range(3):
            for j in range(3):
                ax1 = fig1.add_subplot(3, 3, i * 3 + j + 1)
                ax1.imshow(train_x[i * 3 + j]/255)

        plt.show()    

        # sharpen the images 

        train_x, test_x, val_x = sm.sharpness(train_x, test_x, val_x)

        # normalize the images

        train_x /=255
        test_x /=255
        val_x /=255

        # view the sharpened images
        fig2 = plt.figure(figsize=(5,4))
        for i in range(3):
            for j in range(3):

                ax2 = fig2.add_subplot(3, 3, i * 3 + j + 1)
                ax2.imshow(train_x[i * 3 + j])

        plt.show()

        # one hot encode

        train_y, test_y, val_y = sm.one_hot_encode(train_y, test_y, val_y)

        # train the model
        history = sm.train(train_x, train_y, val_x, val_y)

        # save the model

        sm.save()

        # Plot accuracy and loss

        sm.plot(history)

        # predict

        predicted_x = sm.predict(test_x)
        residuals = np.argmax(predicted_x, 1) != np.argmax(test_y, 1)

        loss = sum(residuals) / len(residuals)
        print("the test 0/1 loss is: ", loss)

        # evaluate on test dataset

        loss, acc = sm.evaluate(test_x, test_y)
        print('Test Accuracy: %.3f' % (acc * 100))

        # plot confusion matrix and print classification report
        classes = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
        labels = [0,1,2,3,4,5,6,7,8,9]

        test_y =  test_y.argmax(axis=1)
        y_pred = predicted_x.argmax(axis=1)

        print(classification_report(test_y, y_pred, target_names= classes))

        np.set_printoptions(precision=2)

        ## Plot non-normalized confusion matrix

        sm.plot_confusion_matrix(test_y, y_pred, classes=classes, title='Confusion matrix', cmap=plt.cm.Blues)
        plt.show()

rao208 commented 4 years ago

@RaphaelMeudec

Update:

I have observed a similar pattern when I use the Albumentations image augmentation library.

rao208 commented 4 years ago

@RaphaelMeudec Please help me. Do you think there is a bug in my code? I have tried different preprocessing techniques (Albumentations data augmentation, standardization, normalization, brightening and sharpening the images).

rao208 commented 4 years ago

@RaphaelMeudec What is the use of the grid_display function? I went through your code for occlusion sensitivity, and I understand everything except the grid_display function. I was wondering what its significance is. What happens if we do not use that function?

matheushent commented 4 years ago

@RaphaelMeudec What is the use of the grid_display function? I went through your code for occlusion sensitivity, and I understand everything except the grid_display function. I was wondering what its significance is. What happens if we do not use that function?

@rao208 I explain a little bit about grid_display here and how you can avoid using it. Hope it helps you.
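In short, grid_display only tiles the individual heatmaps into a single image for visualization; the sensitivity computation happens before it. If you want the raw map for a single image, you can call the lower-level method that explain() uses internally (a sketch; get_sensitivity_map is the method name in the tf-explain source, so check it against your installed version):

from tf_explain.core.occlusion_sensitivity import OcclusionSensitivity

explainer = OcclusionSensitivity()
# Raw (un-colormapped, un-tiled) sensitivity map for one image
sensitivity_map = explainer.get_sensitivity_map(model, x_test[sampleid],
                                                class_index=classidx,
                                                patch_size=5)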