sicara / tf-explain

Interpretability Methods for tf.keras models with Tensorflow 2.x
https://tf-explain.readthedocs.io
MIT License

Gradient is of type None #102

Closed: JulesScholler closed this issue 3 years ago

JulesScholler commented 4 years ago

I'm trying to run GradCAM on my own model (a modified VGG16 with input shape (1, 100, 100, 3) and 5 classes).

When I try to run :

# Instantiation of the explainer
explainer = GradCAM()
# Call to explain() method
output = explainer.explain(validation_data, model, 'block5_conv3', 2)

Then, inside the function get_gradients_and_filters(model, images, layer_name, class_index), the line

grads = tape.gradient(loss, conv_outputs)

returns None for grads, so execution cannot continue.
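For context, `tape.gradient(target, source)` returns None when the tape has no recorded path from `source` to `target`. A minimal standalone sketch with toy tensors (not tf-explain code; the names `x`, `unrelated`, `connected`, and `disconnected` are purely illustrative) reproduces that behaviour:

```python
import tensorflow as tf

x = tf.constant(3.0)
unrelated = tf.constant(5.0)

# persistent=True so the same tape can be queried for two gradients
with tf.GradientTape(persistent=True) as tape:
    tape.watch(x)  # plain tensors are not tracked unless explicitly watched
    y = x * x      # y is computed from the watched tensor x

connected = tape.gradient(y, x)             # dy/dx = 2 * 3 = 6.0
disconnected = tape.gradient(y, unrelated)  # y does not depend on `unrelated`
del tape  # release the resources held by the persistent tape

print(float(connected))  # 6.0
print(disconnected)      # None
```

In the Grad-CAM case, the same None appears whenever the tape did not record a path from `conv_outputs` to `loss`.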

I'm not sure why I get this error.

RaphaelMeudec commented 4 years ago

Can you provide a sample code that triggers the error?

JulesScholler commented 4 years ago

Yes:

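import numpy as np
from skimage.io import imread  # assumption: the original does not say which imread was used
from tensorflow.keras.applications.vgg16 import preprocess_input
from tensorflow.keras.models import Sequential, load_model
from tf_explain.core.grad_cam import GradCAM

model = load_model(r'.\data_cumsum\model_vgg_cumsum.h5')
# Unwrap layers: the 1st layer is VGG16 without the top, the rest is the classifier head
tmp = Sequential()
for l in model.layers[0].layers:
    tmp.add(l)
for l in model.layers[1:]:
    tmp.add(l)
model = tmp
model.get_layer('block5_conv3').trainable = True
# Load a single image for testing and add a batch dimension -> (1, 100, 100, 3)
u = imread(r'.\data_cumsum\train\3\257.tif')
u = preprocess_input(np.expand_dims(u, 0).astype('float64'))
validation_data = (u, 2)
# Instantiation of the explainer
explainer = GradCAM()
# Call to explain() method
output = explainer.explain(validation_data, model, 'block5_conv3', 2)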

Then I get this error:

TypeError: in converted code:

C:\Users\Jules Scholler\Anaconda3\envs\keras-gpu\lib\site-packages\tf_explain\core\grad_cam.py:88 get_gradients_and_filters  *
    guided_grads = (

TypeError: '>' not supported between instances of 'NoneType' and 'int'

This happens because grads is None, i.e. the gradients were never computed.
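To see why the failing `guided_grads` line needs real gradients, here is a NumPy sketch of the guided-gradient weighting it performs. It is assumed to follow the usual Grad-CAM recipe and is not copied from tf-explain; `guided_grad_cam` is a hypothetical helper, and the final ReLU follows the Grad-CAM paper, so tf-explain's own post-processing may differ.

```python
import numpy as np


def guided_grad_cam(conv_outputs, grads):
    """Grad-CAM map for one image (hypothetical helper, illustration only).

    conv_outputs, grads: arrays of shape (H, W, C), i.e. the feature maps of the
    target conv layer and d(class score)/d(feature maps).
    """
    if grads is None:
        # The traceback's situation: `None > 0` raises a TypeError, so fail clearly instead.
        raise ValueError("gradients are None: the conv output is not connected to the loss")
    # Guided gradients: keep the gradient only where activation and gradient are both positive
    guided = (conv_outputs > 0).astype(np.float32) * (grads > 0).astype(np.float32) * grads
    # One weight per channel: the spatial mean of the guided gradients
    weights = guided.mean(axis=(0, 1))
    # Weighted sum of the feature maps over channels gives an (H, W) map
    cam = (weights * conv_outputs).sum(axis=-1)
    # ReLU keeps only regions that increase the class score
    return np.maximum(cam, 0.0)


cam = guided_grad_cam(
    np.ones((2, 2, 1), dtype=np.float32),
    np.full((2, 2, 1), 0.5, dtype=np.float32),
)
print(cam.shape)  # (2, 2)
```

With all-positive activations and gradients of 0.5, every channel weight is 0.5, so each pixel of the map is 0.5.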

gabarlacchi commented 4 years ago

Any solution for this?

JulesScholler commented 4 years ago

I switched to PyTorch and found more useful implementations. My feeling is that TensorFlow is a pain in the ass to debug across versions.

craymichael commented 3 years ago

I ran into this issue; it is fixed simply by ensuring that the input tensor passed to grad_model is watched by the GradientTape.

Quick fix until the PR is merged; see the one-liner in: https://github.com/sicara/tf-explain/pull/167/files
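A sketch of a gradient function with that fix applied, assuming the one-liner is a `tape.watch` call on the cast input (not verified against the actual PR diff). The rest of the body follows the common Grad-CAM pattern and is likewise an assumption, not tf-explain's exact source.

```python
import tensorflow as tf


def get_gradients_and_filters(model, images, layer_name, class_index):
    """Return (conv_outputs, grads) for Grad-CAM, watching the input explicitly."""
    # Model that exposes both the target conv layer's output and the predictions
    grad_model = tf.keras.models.Model(
        model.inputs, [model.get_layer(layer_name).output, model.output]
    )
    with tf.GradientTape() as tape:
        inputs = tf.cast(images, tf.float32)
        # Assumed fix: make the tape record the forward pass even if no
        # trainable variable is involved (e.g. a frozen convolutional base)
        tape.watch(inputs)
        conv_outputs, predictions = grad_model(inputs)
        loss = predictions[:, class_index]
    grads = tape.gradient(loss, conv_outputs)
    return conv_outputs, grads
```

Why this can help: GradientTape automatically watches only trainable variables. If the layers are frozen (`trainable=False`) and the input is a plain tensor, the forward pass may not be recorded at all, and `tape.gradient` then returns None.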