marcoancona / DeepExplain

A unified framework of perturbation and gradient-based attribution methods for Deep Neural Network interpretability. DeepExplain also includes support for Shapley Values sampling. (ICLR 2018)
https://arxiv.org/abs/1711.06104
MIT License

3 channel images #41

Closed: sherlock42 closed this issue 5 years ago

sherlock42 commented 5 years ago

How can we use DeepExplain for 3-channel images with the Keras backend?

marcoancona commented 5 years ago

Exactly as for 1-channel images, except that the attributions for the three channels should then be summed together. See https://github.com/marcoancona/DeepExplain/blob/master/examples/mint_cnn_keras.ipynb
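
For context (not from the original thread), a minimal sketch of what "summed together" means in practice, assuming a channels-last Keras model so attributions come back with shape `(batch, height, width, 3)`; the `attributions` array here is an illustrative placeholder standing in for the output of `de.explain(...)`:

```python
import numpy as np

# Placeholder with an RGB-like attribution shape; in practice this would be
# the array returned by de.explain(...) on 3-channel inputs.
attributions = np.random.rand(1, 224, 224, 3)

# Collapse the channel axis so each pixel gets a single attribution score.
pixel_attributions = attributions.sum(axis=-1)   # shape (1, 224, 224)
print(pixel_attributions.shape)
```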

sherlock42 commented 5 years ago

I was trying to use it on a dataset with only 2 classes. I tried to run DeepExplain on a single image of my dataset (the label of the image is 0):

```python
with DeepExplain(session=K.get_session()) as de:  # <-- init DeepExplain context
    # Need to reconstruct the graph in DeepExplain context, using the same weights.
    # With Keras this is very easy:
    # 1. Get the input tensor to the original model
    input_tensor = model3.layers[0].input

    # 2. We now target the output of the last dense layer (pre-softmax)
    #    To do so, create a new model sharing the same layers until the last dense (index -2)
    fModel = Model(inputs=input_tensor, outputs=model3.layers[-2].output)
    target_tensor = fModel(input_tensor)

    xs = img_tensor
    y_test = np.array([0])
    ys = y_test

    attributions_gradin = de.explain('grad*input', target_tensor, input_tensor, xs, ys=ys)
    # attributions_sal   = de.explain('saliency', target_tensor, input_tensor, xs, ys=ys)
    # attributions_ig    = de.explain('intgrad', target_tensor, input_tensor, xs, ys=ys)
    # attributions_dl    = de.explain('deeplift', target_tensor, input_tensor, xs, ys=ys)
    # attributions_elrp  = de.explain('elrp', target_tensor, input_tensor, xs, ys=ys)
    # attributions_occ   = de.explain('occlusion', target_tensor, input_tensor, xs, ys=ys)
```

But I keep getting the following error:

```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input> in <module>()
     13 ys = y_test
     14
---> 15 attributions_gradin = de.explain('grad*input', target_tensor, input_tensor, xs, ys=ys)
     16 #attributions_sal = de.explain('saliency', target_tensor, input_tensor, xs, ys=ys)
     17 #attributions_ig = de.explain('intgrad', target_tensor, input_tensor, xs, ys=ys)

/content/src/deepexplain/deepexplain/tensorflow/methods.py in explain(self, method, T, X, xs, ys, batch_size, **kwargs)
    619     def explain(self, method, T, X, xs, ys=None, batch_size=None, **kwargs):
    620         explainer = self.get_explainer(method, T, X, **kwargs)
--> 621         return explainer.run(xs, ys, batch_size)
    622
    623     @staticmethod

/content/src/deepexplain/deepexplain/tensorflow/methods.py in run(self, xs, ys, batch_size)
    212     def run(self, xs, ys=None, batch_size=None):
    213         self._check_input_compatibility(xs, ys, batch_size)
--> 214         results = self._session_run(self.explain_symbolic(), xs, ys, batch_size)
    215         return results[0] if not self.has_multiple_inputs else results
    216

/content/src/deepexplain/deepexplain/tensorflow/methods.py in _session_run(self, T, xs, ys, batch_size)
    149
    150         if batch_size is None or batch_size <= 0 or num_samples <= batch_size:
--> 151             return self._session_run_batch(T, xs, ys)
    152         else:
    153             outs = []

/content/src/deepexplain/deepexplain/tensorflow/methods.py in _session_run_batch(self, T, xs, ys)
    133         if self.keras_learning_phase is not None:
    134             feed_dict[self.keras_learning_phase] = 0
--> 135         return self.session.run(T, feed_dict)
    136
    137     def _session_run(self, T, xs, ys=None, batch_size=None):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
    927     try:
    928       result = self._run(None, fetches, feed_dict, options_ptr,
--> 929                          run_metadata_ptr)
    930       if run_metadata:
    931         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
   1126                              'which has shape %r' %
   1127                              (np_val.shape, subfeed_t.name,
-> 1128                               str(subfeed_t.get_shape())))
   1129         if not self.graph.is_feedable(subfeed_t):
   1130             raise ValueError('Tensor %s may not be fed.' % subfeed_t)

ValueError: Cannot feed value of shape (1,) for Tensor 'Placeholder_31:0', which has shape '(?, 512)'
```
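
For context (not part of the original thread): the error indicates a shape mismatch between `ys` (shape `(1,)`) and the placeholder DeepExplain creates to match the width of `target_tensor` (shape `(?, 512)`), which suggests `model3.layers[-2]` is a 512-unit dense layer rather than the 2-unit pre-softmax layer of this binary classifier. A minimal sketch of a `ys` that matches the targeted layer, assuming the variables `de`, `target_tensor`, `input_tensor`, and `xs` from the snippet above are still in scope (which layer to target is model-specific and an assumption here):

```python
import numpy as np
from keras import backend as K
from keras.utils import to_categorical

# Width of the targeted layer: 512 in the traceback above, or 2 if the
# pre-softmax layer of the binary classifier is targeted instead.
n_out = K.int_shape(target_tensor)[-1]

label = 0                              # class index of the single example
ys = to_categorical([label], n_out)    # one-hot, shape (1, n_out), matches the placeholder

attributions = de.explain('grad*input', target_tensor, input_tensor, xs, ys=ys)
```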