I tested the conservation property for some LRP rules, and all results seem okay, except those for the alpha beta rule. The maximum relative errors are:
LRP alpha=2 beta=1, ignore bias: 0.16589043
LRP alpha=2 beta=1, with bias: 0.8877983
If my understanding is correct, the rule is not supposed to be conservative when biases are not ignored, because the biases absorb some of the relevance, which is then missing from the heatmap. However, when biases are ignored, I think the rule should be conservative.
Qualitatively, the alpha-beta heatmaps also look quite different from those on heatmapping.org.
This is the code I used:
%matplotlib inline
%load_ext autoreload
%autoreload 2
import keras
from keras.models import Sequential, Model
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras.datasets import mnist
from matplotlib import pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
import numpy as np
import pdb
import innvestigate
def get_cnn_no_dropout():
    """Build and compile a small convolutional classifier without dropout.

    The input and output sizes are taken from the module-level ``X_train``
    and ``Y_train`` arrays: samples are reshaped to carry an explicit
    trailing channel axis when ``X_train`` is not already 4-dimensional, and
    the final dense layer has ``Y_train.shape[1]`` units followed by softmax.

    Returns:
        The compiled (untrained) Keras ``Sequential`` model.
    """
    sample_shape = X_train.shape[1:]
    if len(X_train.shape) == 4:
        # Data already has a channel axis; the Reshape layer is a no-op.
        target_shape = sample_shape
    else:
        # Append a single channel axis so Conv2D receives image tensors.
        target_shape = list(X_train.shape)[1:] + [1]

    layer_stack = [
        keras.layers.Reshape(target_shape, input_shape=sample_shape),
        Conv2D(32, (3, 3), padding='same', activation='relu'),
        MaxPooling2D(pool_size=(2, 2)),
        Conv2D(32, (3, 3), padding='same', activation='relu'),
        MaxPooling2D(pool_size=(2, 2)),
        Flatten(),
        Dense(512, activation='relu'),
        # Softmax is kept as a separate layer after the final Dense layer.
        Dense(Y_train.shape[1]),
        Activation('softmax'),
    ]

    model = Sequential()
    for layer in layer_stack:
        model.add(layer)

    model.compile(
        loss='categorical_crossentropy',
        optimizer=keras.optimizers.Adam(),
        metrics=['accuracy'],
    )
    return model
def get_trained_cnn_no_dropout():
    """Return the dropout-free CNN with trained weights.

    Weights are loaded from ``cnn_no_dr.hdf5`` when possible. If loading
    fails for any reason, the failure is printed, the model is trained on the
    module-level ``X_train`` / ``Y_train`` and then saved to that same file.
    In both cases the model is evaluated on ``X_test`` / ``Y_test`` and the
    evaluation result is printed before the model is returned.
    """
    weights_path = "cnn_no_dr.hdf5"
    model = get_cnn_no_dropout()

    try:
        model.load_weights(weights_path)
    except Exception as load_error:
        # No usable checkpoint: report why, then train from scratch and cache.
        print("Loading failed: ", type(load_error), load_error)
        fit_options = dict(batch_size=100, validation_split=1/6, epochs=20)
        model.fit(X_train, Y_train, **fit_options)
        model.save(weights_path)

    test_result = model.evaluate(X_test, Y_test)
    print(test_result)
    return model
# Trained classifier shared by all analyzers below (loaded from disk or
# trained on first use by get_trained_cnn_no_dropout).
cnn = get_trained_cnn_no_dropout()
def cut_last_layer(model, n=1):
    """Return a model that shares ``model``'s inputs but stops ``n`` layers early.

    Args:
        model: Keras model whose ``layers`` list is indexed from the end.
        n: Number of trailing layers to drop (default 1).

    Returns:
        A ``Model`` whose output is the output of the layer that precedes the
        dropped ones.
    """
    new_last_layer = model.layers[-(n + 1)]
    return Model(inputs=model.inputs, outputs=new_last_layer.output)
def heatmaps(X, Y=None):
    """Plot a grid of 28x28 heatmaps, red for positive and blue for negative values.

    Args:
        X: Array that can be reshaped to ``(-1, 28, 28)``; one heatmap per row.
        Y: Optional per-sample score vectors. When given, each subplot is
            titled with ``np.argmax`` of the corresponding entry.

    The grid has ``ceil(sqrt(n))`` columns and as many rows as are needed to
    hold all ``n`` heatmaps; cells without a heatmap are hidden. Nothing is
    drawn for empty input.
    """
    X = X.reshape([-1, 28, 28])
    count = len(X)
    if count == 0:
        return

    # Derive rows from cols so that rows * cols >= count for every count
    # (floor(sqrt(n)) * ceil(sqrt(n)) is too small for e.g. n = 3, 7, 8, 13).
    cols = int(np.ceil(np.sqrt(count)))
    rows = (count + cols - 1) // cols
    fig, axes = plt.subplots(nrows=rows, ncols=cols, squeeze=False,
                             figsize=(cols, rows))

    for i, ax in enumerate(axes.flat):
        if i >= count:
            # Surplus cell in the last row: keep it blank.
            ax.axis("off")
            continue

        x = X[i]
        if Y is not None:
            ax.set_title(np.argmax(Y[i]))
        ax.tick_params(
            which='both',       # both major and minor ticks are affected
            bottom=False,       # ticks along the bottom edge are off
            top=False,          # ticks along the top edge are off
            left=False,
            right=False,
            labelbottom=False,  # labels along the bottom edge are off
            labelleft=False)

        # Start from a white image and subtract colour channels so that
        # positive values turn red and negative values turn blue.
        img = np.ones([28, 28, 3])
        red = np.where(x > 0, x, 0)
        blue = np.where(x > 0, 0, -x)
        img[:, :, 0] -= 4 * blue
        img[:, :, 1] -= 4 * (blue + red)
        img[:, :, 2] -= 4 * red
        # Keep the RGB values in the range imshow accepts for float images.
        ax.imshow(np.clip(img, 0.0, 1.0))

    plt.tight_layout()
    plt.show()
# Analyze the first 16 test samples with three methods. Every analyzer is
# built on the network with its last layer (the softmax Activation) removed.
# NOTE(review): the "_IB" suffix presumably selects the ignore-bias variant
# of the rule -- confirm against the iNNvestigate documentation.
lrpz = innvestigate.create_analyzer("lrp.z_IB", cut_last_layer(cnn))
h_lrpz = lrpz.analyze(X_test[:16])
lrpe = innvestigate.create_analyzer("lrp.epsilon_IB", cut_last_layer(cnn), epsilon=1)
h_lrpe = lrpe.analyze(X_test[:16])
# Input-times-gradient heatmaps.
itg = innvestigate.create_analyzer("input_t_gradient", cut_last_layer(cnn))
h_itg = itg.analyze(X_test[:16])
def relu(a):
    """Element-wise rectifier: keep entries greater than zero, replace the rest by 0."""
    is_positive = a > 0
    rectified = np.where(is_positive, a, 0)
    return rectified
def conservation(H, Y):
    """Return the largest relative gap between total relevance and the target value.

    Args:
        H: Relevance maps with the batch on axis 0, e.g. ``(N, 28, 28)`` or
            ``(N, 28, 28, 1)``.
        Y: Target values, one per sample (shape ``(N,)``).

    Returns:
        ``max_i |(sum(H[i]) - Y[i]) / (Y[i] + 1e-9)|``.
    """
    H = np.asarray(H)
    Y = np.asarray(Y)
    # Sum over every non-batch axis so R always has shape (N,). Summing only
    # axes (1, 2) leaves a trailing channel axis on 4-D maps, and (N, 1) - (N,)
    # would silently broadcast to an (N, N) matrix of cross-sample errors.
    R = H.sum(axis=tuple(range(1, H.ndim)))
    # The small constant guards against division by an exactly-zero target.
    return np.abs((R - Y) / (Y + 1e-9)).max()
# Reference value for the conservation check: for each of the first 16 test
# samples, the largest output of the network without its last layer.
m = cut_last_layer(cnn)
Y = m.predict(X_test[:16]).max(axis=1)

# Alpha-beta LRP in both variants; the "_IB" analyzer is the one reported as
# "ignore bias" below.
ab = innvestigate.create_analyzer("lrp.alpha_2_beta_1_IB", cut_last_layer(cnn))
h_ab = ab.analyze(X_test[:16])
abb = innvestigate.create_analyzer("lrp.alpha_2_beta_1", cut_last_layer(cnn))
h_abb = abb.analyze(X_test[:16])

# Bounded Deep Taylor heatmaps: h_dt is reported below but was never computed.
# NOTE(review): the bounded rule needs the input value range; it is estimated
# here from X_test -- replace with the true preprocessing bounds if known.
dt = innvestigate.create_analyzer("deep_taylor.bounded", cut_last_layer(cnn),
                                  low=X_test.min(), high=X_test.max())
h_dt = dt.analyze(X_test[:16])

print("LRP alpha=2 beta=1, ignore bias: ", conservation(h_ab, Y))
print("LRP alpha=2 beta=1, with bias: ", conservation(h_abb, Y))
# The Deep Taylor heatmaps are compared against the rectified output.
print("LRP Deeptaylor bounded: ", conservation(h_dt, relu(Y)))
print("LRP Epsilon=1", conservation(h_lrpe, Y))
print("LRP Z", conservation(h_lrpz, Y))
I tested the conservation property for some LRP rules, and all results seem okay, except those for the alpha beta rule. The maximum relative errors are:
This is the code I used:
And the output:
This is the plot of the LRP.alpha_2_beta_1 heatmaps