raghakot / keras-vis

Neural network visualization toolkit for keras
https://raghakot.github.io/keras-vis
MIT License
2.98k stars, 660 forks

AttributeError: module 'tensorflow' has no attribute 'get_default_graph' #226

Open franksacco opened 4 years ago

franksacco commented 4 years ago

Hi, I am trying to use your toolkit in Google Colaboratory, starting from the example "Attention on ResNet50 (Saliency and grad-CAM)".

Python version: 3.6.9
TensorFlow version: 2.2.0

Because the keras-vis release on pip is out of date, I installed the package with:

pip install git+git://github.com/raghakot/keras-vis.git --upgrade --no-deps

However, when I try to execute this part:

grads = visualize_saliency(model, layer_idx, filter_indices=20, 
                           seed_input=img1, backprop_modifier=modifier)

with modifier = 'guided', I get this error:

/usr/local/lib/python3.6/dist-packages/vis/visualization/saliency.py in visualize_saliency(model, layer_idx, filter_indices, seed_input, wrt_tensor, backprop_modifier, grad_modifier, keepdims)
    125     if backprop_modifier is not None:
    126         modifier_fn = get(backprop_modifier)
--> 127         model = modifier_fn(model)
    128 
    129     # `ActivationMaximization` loss reduces as outputs get large, hence negative gradients indicate the direction

/usr/local/lib/python3.6/dist-packages/vis/backprop_modifiers.py in guided(model)
     15         (https://arxiv.org/pdf/1412.6806.pdf)
     16     """
---> 17     return backend.modify_model_backprop(model, 'guided')
     18 
     19 

/usr/local/lib/python3.6/dist-packages/vis/backend/tensorflow_backend.py in modify_model_backprop(model, backprop_modifier)
     93 
     94         # 3. Create graph under custom context manager.
---> 95         with tf.get_default_graph().gradient_override_map({'Relu': backprop_modifier}):
     96             #  This should rebuild graph with modifications.
     97             modified_model = load_model(model_path)

AttributeError: module 'tensorflow' has no attribute 'get_default_graph'

The only way I found to make this code work is to modify the keras-vis source, replacing tf.get_default_graph() with tf.compat.v1.get_default_graph() at backend/tensorflow_backend.py:95.
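For anyone hitting the same error, here is a minimal sketch of that replacement written as a small helper that works with either module layout. The helper name get_default_graph_compat is hypothetical (it is not part of keras-vis or TensorFlow), and the two stand-in objects only mimic the TF 1.x and TF 2.x attribute layout so the snippet runs without TensorFlow installed. It only addresses the AttributeError; whether the rest of modify_model_backprop behaves correctly under TensorFlow 2 is not verified here.

```python
from types import SimpleNamespace


def get_default_graph_compat(tf_module):
    """Return the default graph using whichever API the given module exposes.

    `tf_module` is expected to be the imported `tensorflow` module.
    Hypothetical helper: not part of keras-vis or TensorFlow.
    """
    # TensorFlow 1.x exposes get_default_graph at the top level.
    if hasattr(tf_module, "get_default_graph"):
        return tf_module.get_default_graph()
    # TensorFlow 2.x only keeps it under the compat.v1 namespace.
    return tf_module.compat.v1.get_default_graph()


# Stand-ins that mimic the two module layouts (assumption: illustration only).
fake_tf1 = SimpleNamespace(get_default_graph=lambda: "tf1-graph")
fake_tf2 = SimpleNamespace(
    compat=SimpleNamespace(
        v1=SimpleNamespace(get_default_graph=lambda: "tf2-compat-graph")
    )
)

print(get_default_graph_compat(fake_tf1))
print(get_default_graph_compat(fake_tf2))
```

With the real library you would pass the imported tensorflow module instead of the stand-ins, and use the returned graph where line 95 currently calls tf.get_default_graph().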

Is it a bug or am I doing something wrong?

offchan42 commented 4 years ago

I have this issue too when trying to run the MNIST example. I think it's because this library does not support TensorFlow 2.