sicara / tf-explain

Interpretability Methods for tf.keras models with Tensorflow 2.x
https://tf-explain.readthedocs.io
MIT License

[NOTE] tf version issue: tf.function-decorated function tried to create variables on non-first #149

vscv closed this issue 3 years ago.

vscv commented 4 years ago

This error occurs when using TF 2.3. Somehow tf-explain works just fine with TF 2.2.

```python
import tensorflow as tf
from tf_explain.core.grad_cam import GradCAM

IMAGE_PATH = './0_in.jpg'

# all default from github

# Load pretrained model or your own
model = tf.keras.applications.vgg16.VGG16(weights="imagenet", include_top=True)

# Load a sample image (or multiple ones)
img = tf.keras.preprocessing.image.load_img(IMAGE_PATH, target_size=(224, 224))
img = tf.keras.preprocessing.image.img_to_array(img)
data = ([img], None)

# Start explainer
explainer = GradCAM()
grid = explainer.explain(data, model, class_index=281)  # 281 is the tabby cat index in ImageNet

explainer.save(grid, ".", "grad_cam.png")
```
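For context, the error in the title comes from a general TensorFlow rule rather than from tf-explain itself: a `tf.function` may not create a new `tf.Variable` every time its body is traced. The minimal sketch below reproduces the same class of error without tf-explain, and also shows the usual fix of creating the variable once, outside the traced function. The names `scale`, `scale_ok` and `v_outer` are hypothetical, and since the exact message text differs between TensorFlow versions, only the `ValueError` type is relied on.

```python
import tensorflow as tf


@tf.function
def scale(x):
    # Hypothetical example: a new tf.Variable is created unconditionally
    # in the traced body. tf.function rejects this with a ValueError
    # (in TF 2.3 the message is the one quoted in this issue's title).
    v = tf.Variable(2.0)
    return v * x


# Usual fix: create the variable once, outside the traced function,
# and only use it inside.
v_outer = tf.Variable(2.0)


@tf.function
def scale_ok(x):
    return v_outer * x


raised = False
try:
    scale(tf.constant(3.0))
except ValueError as err:
    raised = True
    print("ValueError:", err)

print(scale_ok(tf.constant(3.0)).numpy())
```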

Hassanfarooq92 commented 3 years ago

You need to make a small modification to the tf-explain code. There are two ways:

1. Remove `tf.function` from `def get_gradients_and_filters` in `grad_cam.py`.
2. If you want to retain `tf.function`, move the `grad_model` creation from `def get_gradients_and_filters` into `def explain` and pass it in as a function argument.
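A minimal sketch of the second option follows. It assumes a tiny toy Keras model in place of VGG16 so it runs without downloading weights. `build_toy_model`, `explain` and `get_gradients_and_filters` here are hypothetical stand-ins that borrow names from the comment above; they are simplified (no guided gradients, no heatmap) and are not the actual tf-explain source. The point being illustrated is that the two-output `grad_model` is built eagerly in `explain` and passed into the `@tf.function`-decorated function, so tracing never creates anything new.

```python
import tensorflow as tf


def build_toy_model():
    # Hypothetical stand-in for VGG16: a small functional Keras model
    # with a named convolutional layer.
    inputs = tf.keras.Input(shape=(8, 8, 3))
    x = tf.keras.layers.Conv2D(4, 3, activation="relu", name="last_conv")(inputs)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    outputs = tf.keras.layers.Dense(5, activation="softmax")(x)
    return tf.keras.Model(inputs, outputs)


@tf.function
def get_gradients_and_filters(grad_model, images, class_index):
    # grad_model arrives ready-made, so tracing this function only uses
    # existing layers instead of constructing a model inside the graph.
    with tf.GradientTape() as tape:
        inputs = tf.cast(images, tf.float32)
        conv_outputs, predictions = grad_model(inputs)
        loss = predictions[:, class_index]
    # Gradient of the class score w.r.t. the conv layer's feature maps.
    grads = tape.gradient(loss, conv_outputs)
    return conv_outputs, grads


def explain(model, images, layer_name, class_index):
    # Option 2: build the sub-model eagerly, outside tf.function,
    # then pass it in as an argument.
    grad_model = tf.keras.Model(
        model.input, [model.get_layer(layer_name).output, model.output]
    )
    return get_gradients_and_filters(grad_model, images, class_index)


model = build_toy_model()
images = tf.random.uniform((2, 8, 8, 3))
conv_outputs, grads = explain(model, images, "last_conv", class_index=1)
print(conv_outputs.shape, grads.shape)
```

The first option corresponds to deleting the `@tf.function` line so the function runs eagerly; that also avoids the error, usually at some cost in speed.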