marcoancona / DeepExplain

A unified framework of perturbation and gradient-based attribution methods for Deep Neural Networks interpretability. DeepExplain also includes support for Shapley Values sampling. (ICLR 2018)
https://arxiv.org/abs/1711.06104
MIT License
720 stars 133 forks source link

Fetch argument None has invalid type <class 'NoneType'> #66

Open ahof1704 opened 3 years ago

ahof1704 commented 3 years ago

Hi,

I am trying to use DeepExplain with Concept Saliency Maps. However, I run into the following issue:

# Reproduction snippet for DeepExplain issue #66: run a DeepExplain attribution
# (TF/Keras session) on concept scores whose embeddings come from a PyTorch model.
import keras
# NOTE(review): `K` is not imported in this snippet; presumably
# `from keras import backend as K` appears earlier in the notebook cell — confirm.
sess = K.get_session()
print('sess: ',sess)
from ConceptSaliencyMaps.deepexplain.tensorflow import DeepExplain
from ConceptSaliencyMaps.deepexplain.utils import preprocess

# Collect every train/test file path that contains one of the names in files_max.
list_files = []
all_files = train_files + test_files
for file_name in files_max:
    for file_name2 in all_files:
        if file_name in file_name2:
            list_files.append(file_name2)

# Build a PyTorch dataset and loader over the selected files
# (one sample per batch, no shuffling).
test_set2 = zfish_age(list_files, path_to_save = path_to_augmented, test=True, transform = True, new_channel=new_channel, new_size_frame=size_frame, 
                     verbose=False)
test_generator2 = data.DataLoader(test_set2,batch_size=1,
                                       shuffle=False,
                                       num_workers=20)

# A brand-new Keras input placeholder. Nothing in this snippet computes the
# concept scores from it, so it is not connected to the tensors later passed as T.
input_img = keras.Input(shape=(50, 128, 128)) 

with DeepExplain(session=sess, graph=sess.graph) as de:
    with torch.no_grad():
        for i, d in enumerate(test_generator2): 
            xis, _, _, labels_name = d
            print('labels_name: {}'.format(labels_name))

            # X for DeepExplain: the unconnected placeholder created above.
            input_tensor = input_img
            # NOTE(review): xis comes from a torch DataLoader, so this is
            # presumably a torch tensor; a TF feed_dict may need `.numpy()` — confirm.
            img_array = xis.reshape([1,50,128,128])
            # The forward pass runs in PyTorch, outside the TF graph owned by `sess`.
            ris, zis = model(xis.to(device))
            print('zis.shape: ',zis.shape) # torch.Size([1, 256])
            # .detach() removes the embedding from PyTorch's autograd graph; the
            # printed shape (1, 2) is a plain tuple, suggesting reducer.transform
            # returns a NumPy array — TODO confirm.
            latents = reducer.transform(zis.cpu().detach())
            print('latents.shape: ',latents.shape) # (1, 2)
            method = 'guidedbp'

            # One concept score per concept vector. The comprehension's `i` is a
            # concept vector, not the batch index from enumerate above. As far as
            # this snippet shows, the scores are built from `latents` only, never
            # from `input_tensor`.
            concept_score = [K.sum(latents*i) for i in concept_vectors[attr]]
            # NOTE(review): gradient-based methods differentiate T with respect to X.
            # Here T (a concept score) does not depend on X, so the gradient is
            # likely None. That would explain the traceback's "Fetch argument None"
            # raised from session.run. For attributions to exist, the concept score
            # must be computed from `input_tensor` inside the same TF graph
            # (for example, by a Keras/TF version of the model).
            # NOTE(review): the traceback's run() loops over `alpha` and `self.steps`,
            # which looks like an integrated-gradients-style path rather than
            # 'guidedbp'. Confirm which method actually raised the error.
            attributions_guided = [de.explain(method, i, input_tensor, img_array) for i in concept_score]```

Error:

TypeError                                 Traceback (most recent call last)
<ipython-input-169-177871cfe4fc> in <module>
     73 
     74             concept_score = [K.sum(latents*i) for i in concept_vectors[attr]]
---> 75             attributions_guided = [de.explain(method, i, input_tensor, img_array) for i in concept_score]

<ipython-input-169-177871cfe4fc> in <listcomp>(.0)
     73 
     74             concept_score = [K.sum(latents*i) for i in concept_vectors[attr]]
---> 75             attributions_guided = [de.explain(method, i, input_tensor, img_array) for i in concept_score]

../ConceptSaliencyMaps/deepexplain/tensorflow/methods.py in explain(self, method, T, X, xs, **kwargs)
    733         _ENABLED_METHOD_CLASS = method_class
    734         method = _ENABLED_METHOD_CLASS(T, X, xs, self.session, self.keras_phase_placeholder, **kwargs)
--> 735         result = method.run()
    736         if issubclass(_ENABLED_METHOD_CLASS, GradientBasedMethod) and _GRAD_OVERRIDE_CHECKFLAG == 0:
    737             warnings.warn('DeepExplain detected you are trying to use an attribution method that requires '

../ConceptSaliencyMaps/deepexplain/tensorflow/methods.py in run(self)
    463         for alpha in list(np.linspace(1. / self.steps, 1.0, self.steps)):
    464             xs_mod = [xs * alpha for xs in self.xs] if self.has_multiple_inputs else self.xs * alpha
--> 465             _attr = self.session_run(attributions, xs_mod)
    466             if gradient is None: gradient = _attr
    467             else: gradient = [g + a for g, a in zip(gradient, _attr)]

../ConceptSaliencyMaps/deepexplain/tensorflow/methods.py in session_run(self, T, xs)
     94         if self.keras_learning_phase is not None:
     95             feed_dict[self.keras_learning_phase] = 0
---> 96         return self.session.run(T, feed_dict)
     97 
     98     def _set_check_baseline(self):

../lib/python3.7/site-packages/tensorflow_core/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
    954     try:
    955       result = self._run(None, fetches, feed_dict, options_ptr,
--> 956                          run_metadata_ptr)
    957       if run_metadata:
    958         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

../lib/python3.7/site-packages/tensorflow_core/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
   1163     # Create a fetch handler to take care of the structure of fetches.
   1164     fetch_handler = _FetchHandler(
-> 1165         self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
   1166 
   1167     # Run request and get response.

../lib/python3.7/site-packages/tensorflow_core/python/client/session.py in __init__(self, graph, fetches, feeds, feed_handles)
    472     """
    473     with graph.as_default():
--> 474       self._fetch_mapper = _FetchMapper.for_fetch(fetches)
    475     self._fetches = []
    476     self._targets = []

../lib/python3.7/site-packages/tensorflow_core/python/client/session.py in for_fetch(fetch)
    264     elif isinstance(fetch, (list, tuple)):
    265       # NOTE(touts): This is also the code path for namedtuples.
--> 266       return _ListFetchMapper(fetch)
    267     elif isinstance(fetch, collections_abc.Mapping):
    268       return _DictFetchMapper(fetch)

../lib/python3.7/site-packages/tensorflow_core/python/client/session.py in __init__(self, fetches)
    373     """
    374     self._fetch_type = type(fetches)
--> 375     self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
    376     self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
    377 

../lib/python3.7/site-packages/tensorflow_core/python/client/session.py in <listcomp>(.0)
    373     """
    374     self._fetch_type = type(fetches)
--> 375     self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
    376     self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
    377 

../lib/python3.7/site-packages/tensorflow_core/python/client/session.py in for_fetch(fetch)
    261     if fetch is None:
    262       raise TypeError('Fetch argument %r has invalid type %r' %
--> 263                       (fetch, type(fetch)))
    264     elif isinstance(fetch, (list, tuple)):
    265       # NOTE(touts): This is also the code path for namedtuples.

TypeError: Fetch argument None has invalid type <class 'NoneType'>

I have also reported this issue in the Concept Saliency Maps GitHub repository, but both its developer and I believe the problem lies in DeepExplain. Any insight into this problem would be appreciated.

Please let me know if you need any further information about the problem. Thanks!