marcoancona / DeepExplain

A unified framework of perturbation and gradient-based attribution methods for Deep Neural Networks interpretability. DeepExplain also includes support for Shapley Values sampling. (ICLR 2018)
https://arxiv.org/abs/1711.06104
MIT License

Error with MNIST example for DeepLift attributions #11

Closed (donigian closed this issue 6 years ago)

donigian commented 6 years ago

Hi, when attempting to run the MNIST example for DeepLift, I get the following error:

InvalidArgumentError: You must feed a value for placeholder tensor 'conv2d_1_input' with dtype float and shape [?,28,28,1]
     [[Node: conv2d_1_input = Placeholder[dtype=DT_FLOAT, shape=[?,28,28,1], _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]

During handling of the above exception, another exception occurred:

Here is the snippet of code I'm running:

with DeepExplain(session=K.get_session()) as de:  # <-- init DeepExplain context
    # Need to reconstruct the graph in DeepExplain context, using the same weights.
    # With Keras this is very easy:
    # 1. Get the input tensor to the original model
    input_tensor = model.layers[0].input

    # 2. We now target the output of the last dense layer (pre-softmax)
    # To do so, create a new model sharing the same layers until the last dense (index -2)
    fModel = Model(inputs=input_tensor, outputs=model.layers[-2].output)
    target_tensor = fModel(input_tensor)

    xs = x_test[0:10]
    ys = y_test[0:10]

    # attributions = de.explain('grad*input', target_tensor * ys, input_tensor, xs)
    # attributions = de.explain('saliency', target_tensor * ys, input_tensor, xs)
    # attributions = de.explain('intgrad', target_tensor * ys, input_tensor, xs)
    attributions = de.explain('deeplift', target_tensor * ys, input_tensor, xs)
    # attributions = de.explain('elrp', target_tensor * ys, input_tensor, xs)
    # attributions = de.explain('occlusion', target_tensor * ys, input_tensor, xs)

# Plot attributions
from utils import plot, plt
%matplotlib inline

n_cols = 4
n_rows = int(len(attributions) / 2)
fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(3*n_cols, 3*n_rows))

for i, a in enumerate(attributions):
    row, col = divmod(i, 2)
    plot(xs[i].reshape(28, 28), cmap='Greys', axis=axes[row, col*2]).set_title('Original')
    plot(a.reshape(28, 28), xi=xs[i], axis=axes[row, col*2+1]).set_title('Attributions')

marcoancona commented 6 years ago

Thanks for the notice. I could reproduce it and will look into it.

marcoancona commented 6 years ago

We found that the code works correctly, but the problem appears if the notebook cell that creates and trains the network is run more than once. This is because each run creates a new graph in the active TensorFlow session. To solve the problem, add K.clear_session() at the beginning of the notebook cell where the model is first created.
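
For anyone landing here with the same error, the fix above might look roughly like the following. This is only a sketch: the function name build_model and the layer sizes are hypothetical and are not taken from the notebook. The only parts that come from this thread are the call to K.clear_session() before the model is created and the pre-softmax dense layer sitting at index -2.

```python
def build_model(num_classes=10):
    """Build a small MNIST-style CNN after resetting Keras state (illustrative only)."""
    # Imports are kept local so the sketch fits in a single notebook cell.
    from keras import backend as K
    from keras.models import Sequential
    from keras.layers import Activation, Conv2D, Dense, Flatten

    # Reset Keras' global state first, so re-running this cell starts clean
    # instead of adding another copy of the model to the active session.
    K.clear_session()

    # Hypothetical architecture; the notebook's real model may differ.
    model = Sequential([
        Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)),
        Flatten(),
        Dense(num_classes),        # pre-softmax layer, i.e. model.layers[-2]
        Activation('softmax'),
    ])
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model
```

Calling model = build_model() and then training in the same cell means the reset happens on every re-run. Printing model.layers[0].input.name afterwards is one way to check which input placeholder the model is wired to before entering the DeepExplain context.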