linchundan88 closed this issue 5 years ago.
I modified method.py, starting at line 196, as follows:
shape_baseline = list(self.baseline.shape)
shape_model_input = self.X.get_shape().as_list()[1:]  # drop the batch dimension
# The baseline must have the same rank as a single model input
if len(shape_baseline) != len(shape_model_input):
    raise RuntimeError('Baseline shape %s does not match expected shape %s'
                       % (self.baseline.shape, shape_model_input))
# Every dimension must match, except where the model leaves it undefined (None)
for dim_baseline, dim_model in zip(shape_baseline, shape_model_input):
    if dim_model is not None and dim_baseline != dim_model:
        raise RuntimeError('Baseline shape %s does not match expected shape %s'
                           % (self.baseline.shape, shape_model_input))
# Add a batch dimension so the baseline can be fed to the graph
self.baseline = np.expand_dims(self.baseline, 0)
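To try the same logic outside DeepExplain, it can be written as a standalone function. This is only a sketch: check_baseline is a hypothetical name, and it takes the per-sample model input shape as a plain list (batch dimension already removed) instead of reading it from a TensorFlow tensor.

```python
import numpy as np


def check_baseline(baseline, model_input_shape):
    """Validate a baseline against a per-sample input shape and add a batch axis.

    model_input_shape excludes the batch dimension; None entries accept any size.
    """
    baseline = np.asarray(baseline)
    # Rank must match exactly
    if baseline.ndim != len(model_input_shape):
        raise RuntimeError('Baseline shape %s does not match expected shape %s'
                           % (baseline.shape, model_input_shape))
    # Defined dimensions must match; undefined (None) dimensions are skipped
    for dim_baseline, dim_model in zip(baseline.shape, model_input_shape):
        if dim_model is not None and dim_baseline != dim_model:
            raise RuntimeError('Baseline shape %s does not match expected shape %s'
                               % (baseline.shape, model_input_shape))
    # Prepend the batch axis, as the patch does with np.expand_dims
    return np.expand_dims(baseline, 0)


# A 299x299 RGB baseline for a model whose spatial size is undefined
checked = check_baseline(np.zeros((299, 299, 3)), [None, None, 3])
print(checked.shape)  # (1, 299, 299, 3)
```

A baseline with the wrong number of channels, e.g. np.zeros((299, 299, 1)), or the wrong rank still raises RuntimeError, so the relaxed check only skips dimensions the model itself leaves open.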
Makes sense to me. Does it work with your model using such a fix?
Yes, it works.
Then I will push the changes in the next few days.
I am using DeepLIFT. If a model's input shape is (None, None, None, 3), DeepLIFT raises an exception.
Even after setting a baseline, it still raises an exception: "RuntimeError: Baseline shape (299, 299, 3) does not match expected shape [None, None, 3]". For example:
baseline = np.zeros((image_size, image_size, 3))
explainer = de.get_explainer('deeplift', target_tensor, input_tensor, baseline=baseline)
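The error message suggests that the check compares the baseline shape with the model's per-sample input shape for exact equality. That is an assumption, since the original body of _set_check_baseline is not quoted here. Under that assumption, the check can never pass for a model with undefined dimensions, because None never compares equal to an integer, as this small illustration shows:

```python
import numpy as np

image_size = 299  # example size taken from the error message
baseline = np.zeros((image_size, image_size, 3))
# Per-sample input shape of a model declared as (None, None, None, 3), batch axis dropped
model_input_shape = [None, None, None, 3][1:]

# Compare dimension by dimension: 299 == None is False, 3 == 3 is True
per_dim = [b == m for b, m in zip(baseline.shape, model_input_shape)]
# An exact comparison of the whole shape therefore fails as well
exact_match = list(baseline.shape) == model_input_shape
print(per_dim, exact_match)  # [False, False, True] False
```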
I modified the method "_set_check_baseline" in method.py, starting at line 196, and now it runs well.