marcoancona / DeepExplain

A unified framework of perturbation and gradient-based attribution methods for Deep Neural Networks interpretability. DeepExplain also includes support for Shapley Values sampling. (ICLR 2018)
https://arxiv.org/abs/1711.06104
MIT License

using DeepExplain for time-series classification #19

Open tmazaev opened 6 years ago

tmazaev commented 6 years ago

Hi, I am applying this code to a CNN that classifies a dataset of 1D time series. I use the following code to run DeepExplain on the trained model.

from keras import backend as K
from keras.models import Model
from deepexplain.tensorflow import DeepExplain

def calculate_attributions(X, y_onehot, data_nrs_to_explain, explanation_method='occlusion'):
    # custom_cnn_model is the trained Keras model, defined elsewhere.
    attributions_list = []

    with DeepExplain(session=K.get_session()) as de:  # <-- init DeepExplain context
        # Need to reconstruct the graph in DeepExplain context, using the same weights.
        # With Keras this is very easy:

        # 1. Get the input tensor to the original model
        input_tensor = custom_cnn_model.layers[0].input

        # 2. We now target the output of the last dense layer (pre-softmax)
        # To do so, create a new model sharing the same layers until the last dense
        # NOTE: layers[-1] is the final layer of the model. If softmax is a separate
        # last layer, layers[-2] would be the pre-softmax output.
        fModel = Model(inputs=input_tensor, outputs=custom_cnn_model.layers[-1].output)
        target_tensor = fModel(input_tensor)

        for nr in data_nrs_to_explain:
            xs = X[nr]
            ys = y_onehot[nr]
            xs = xs.reshape(1, 2301, 1)  # batch of one series: (1, time_steps, channels)
            # Multiplying by the one-hot label selects the output of the true class
            attributions = de.explain(explanation_method, target_tensor * ys, input_tensor, xs)
            attributions_list.append(attributions)

    return attributions_list

The returned attributions seem to make sense, but as you mention in the ICLR paper, it is difficult to distinguish errors of the model from errors of the attribution method explaining the model. Are there any caveats when applying your code to time series (instead of images)?

marcoancona commented 6 years ago

I have to say that we did not try any time-series dataset. I would start with the occlusion method, since for 1-D data of shape (2301, 1) it should be fast enough.
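
To illustrate what occlusion does on a 1-D series, here is a minimal NumPy sketch. It is not DeepExplain's implementation: `occlusion_1d`, `predict_fn`, `window`, and `baseline` are hypothetical names chosen for this example, and `toy_predict` is a stand-in for a trained model. The idea is to replace a sliding window of time steps with a baseline value, record how much the target score drops, and average the drops over all windows that cover each time step.

```python
import numpy as np

def occlusion_1d(predict_fn, x, target_class, window=10, baseline=0.0):
    """Occlusion attributions for one series x of shape (time_steps, n_features).

    predict_fn (hypothetical) maps a batch (n, time_steps, n_features)
    to scores of shape (n, n_classes).
    """
    x = np.asarray(x, dtype=float)
    base_score = predict_fn(x[np.newaxis])[0, target_class]
    attributions = np.zeros_like(x)
    counts = np.zeros_like(x)
    # Slide the window over the time axis with stride 1
    for start in range(0, x.shape[0] - window + 1):
        occluded = x.copy()
        occluded[start:start + window] = baseline
        score = predict_fn(occluded[np.newaxis])[0, target_class]
        # A large drop in score means the occluded steps mattered
        attributions[start:start + window] += base_score - score
        counts[start:start + window] += 1
    # Average over all windows that covered each time step
    return attributions / np.maximum(counts, 1)

def toy_predict(batch):
    # Stand-in model: the class-0 score depends only on time steps 40..49
    s = batch[:, 40:50, 0].sum(axis=1)
    return np.stack([s, -s], axis=1)

x = np.ones((100, 1))
attr = occlusion_1d(toy_predict, x, target_class=0, window=5)
```

With a real Keras model, `predict_fn` could be `model.predict`, assuming it returns one score per class. The cost is one forward pass per window position, which is why occlusion stays cheap for a single series of a few thousand steps.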

teimoorbah commented 4 years ago

@marcoancona Hi, I have an ECG dataset with shape (n_samples, time_steps, n_features). How can I apply DeepExplain to my LSTM model and plot the result? I can generate the attribution matrix but I am unable to run it.
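
For reference, one possible way to visualize an attribution matrix of shape (time_steps, n_features) is sketched below. This is not from the thread or from DeepExplain: `summarize_attributions` and `plot_attributions` are hypothetical helpers, and the plot assumes matplotlib is installed. The sketch draws the signal on top and the attributions as a heatmap underneath, with a symmetric color scale so that positive and negative contributions are comparable.

```python
import numpy as np

def summarize_attributions(attributions):
    """Collapse attributions to one importance score per time step.

    Accepts shape (time_steps, n_features) or (1, time_steps, n_features).
    """
    a = np.asarray(attributions, dtype=float)
    if a.ndim == 3 and a.shape[0] == 1:
        a = a[0]  # drop the batch dimension of a single sample
    return np.abs(a).sum(axis=-1)

def plot_attributions(signal, attributions, feature=0):
    """Plot one feature of the signal above a heatmap of all attributions."""
    import matplotlib.pyplot as plt  # imported lazily; only needed for plotting

    signal = np.asarray(signal, dtype=float)
    a = np.asarray(attributions, dtype=float)
    if a.ndim == 3 and a.shape[0] == 1:
        a = a[0]
    lim = max(float(np.abs(a).max()), 1e-12)  # symmetric color range around zero
    fig, (ax_sig, ax_attr) = plt.subplots(2, 1, sharex=True)
    ax_sig.plot(signal[:, feature])
    ax_sig.set_ylabel("signal")
    im = ax_attr.imshow(a.T, aspect="auto", cmap="RdBu_r", vmin=-lim, vmax=lim)
    ax_attr.set_ylabel("feature")
    ax_attr.set_xlabel("time step")
    fig.colorbar(im, ax=ax_attr)
    return fig

# Tiny example: 2 time steps, 2 features
attr_matrix = np.array([[1.0, -2.0], [0.0, 3.0]])
per_step = summarize_attributions(attr_matrix)
```

Calling `plot_attributions(x, attributions)` on a single sample and its attribution matrix returns a matplotlib figure that can be shown or saved.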