raghakot / keras-vis

Neural network visualization toolkit for keras
https://raghakot.github.io/keras-vis
MIT License

Switching softmax to linear makes the contrast work more difficult #196

Open jiayi9 opened 5 years ago

jiayi9 commented 5 years ago

If we reload the MNIST model trained in the example and draw the saliency charts for a single digit (e.g. '7') with filter indices from zero to nine, it is very difficult to observe the contrast with the linear activation, yet the contrast can indeed be seen with the softmax activation. Moreover, we need to manually set a consistent scale across the charts to see the contrast at all.

It may be important, at least for defect detection applications, to know how and why an image is classified as one thing rather than another by checking for significant contrasts in the saliency visualizations.
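Reproduction script below. It reloads the model saved from the keras-vis MNIST example (assumed here to be model.h5, with the final dense layer named 'dense_2'; both may differ depending on how the model was built and saved):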

from __future__ import print_function
from keras.datasets import mnist
import numpy as np
from matplotlib import pyplot as plt
from vis.visualization import visualize_saliency
from vis.utils import utils
from keras import activations
from keras.models import load_model

(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Load the model twice: raw_model keeps softmax on the last dense layer; model gets the linear swap below

MODEL_PATH = "model.h5"

model = load_model(MODEL_PATH)
raw_model = load_model(MODEL_PATH)

# check layers in the model
NAMES = []
for index, layer in enumerate(model.layers):
    NAMES.append(layer.name)
    print(index, layer.name)
print('====================================================\n\n\n')

# swap softmax
layer_idx = utils.find_layer_idx(model, 'dense_2')
model.layers[layer_idx].activation = activations.linear
model = utils.apply_modifications(model)
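# Note: apply_modifications saves and reloads the model under the hood; simply assigning a new
# activation does not update the compute graph, so this rebuild is what makes the swap take effect.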

# prepare a sample image: x_test[0] is a '7'; scale pixels to [0, 1]
img = x_test[0] / 255.0

seed = img.copy()
seed = np.expand_dims(seed, 2)
seed = np.expand_dims(seed, 0)
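# seed now has shape (1, 28, 28, 1): a single-image batch, channels last
# (assuming the example model's channels-last input)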

# use a fixed absolute scale per activation so the saliency charts are comparable across digits
MAX_PIXEL_softmax = 0.01
MAX_PIXEL_linear = 1

for index in range(10):
    print('----------------------------------------------')
    print('Digit: ', index)
    f, ax = plt.subplots(1, 3)

    grads_softmax = visualize_saliency(raw_model, layer_idx, filter_indices=index,
                               seed_input=seed, backprop_modifier="guided")
    print('total:', round(grads_softmax.sum() * 10000), '  max:', round(grads_softmax.max(), 5), '  min:', round(grads_softmax.min(), 5))
    grads_softmax[0, 0] = MAX_PIXEL_softmax  # pin one pixel so all softmax charts share the same color scale
    ax[0].set_title('Softmax ' + str(index))
    ax[0].imshow(grads_softmax, cmap = 'jet')

    grads_linear = visualize_saliency(model, layer_idx, filter_indices=index,
                               seed_input=seed, backprop_modifier="guided")
    print('total:', round(grads_linear.sum()), '  max:', round(grads_linear.max(), 5), '  min:', round(grads_linear.min(), 5))
    grads_linear[0, 0] = MAX_PIXEL_linear  # pin one pixel so all linear charts share the same color scale
    ax[1].set_title('Linear ' + str(index))
    ax[1].imshow(grads_linear, cmap = 'jet')

    ax[2].set_title('Raw image')
    ax[2].imshow(img)

plt.show()
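For reference, a cleaner way to enforce a consistent color scale than overwriting pixel [0, 0] is to pass an explicit range to imshow. A minimal sketch, reusing the constants from the script above (the bounds are just illustrative, nothing here is computed by keras-vis itself):

# sketch: fix the color scale via imshow's vmin/vmax instead of mutating the gradient map
grads = visualize_saliency(model, layer_idx, filter_indices=7,
                           seed_input=seed, backprop_modifier="guided")
plt.imshow(grads, cmap='jet', vmin=0.0, vmax=MAX_PIXEL_linear)
plt.colorbar()
plt.show()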