sicara / tf-explain

Interpretability Methods for tf.keras models with Tensorflow 2.x
https://tf-explain.readthedocs.io
MIT License

GradCAM varying on multiple calls #153

Open · palatos opened this issue 3 years ago

palatos commented 3 years ago

I'm running GradCAM on a pretrained model with the same input image, and every time I generate the heatmap I get a different one. This happens whether or not I use the guided gradients option, and the heatmaps differ substantially from one run to the next.

Is this expected?

RaphaelMeudec commented 3 years ago

This is not expected at all. Could you provide some sample code so I can reproduce the issue? Thanks for raising it.

palatos commented 3 years ago

Sure! I think I've actually traced the problem to the augmentation layers I was using. I created a simple mock version of my problem in this notebook: https://github.com/palatos/mynotes/blob/main/gradcam-with-augmentation-layers-problem.ipynb

If you run it, you'll notice that the augmentation layers mess up the GradCAM output because they are active by default once they are part of the model. Their call takes a training argument that defaults to True and is switched to False during prediction and evaluation, but the operations performed by explainer.explain() don't count as prediction or evaluation, so the augmentation layers stay active.
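A minimal reproduction of this behavior, assuming TensorFlow 2.6 or later (where RandomRotation lives under tf.keras.layers). The toy model, its layer names and the random input are hypothetical stand-ins for the notebook, not its actual contents. Calling the model directly, as a gradient tape inside an explainer typically does, leaves the augmentation active; predict() and an explicit training=False do not.

```python
import numpy as np
import tensorflow as tf

tf.random.set_seed(0)

# Hypothetical toy model standing in for the notebook's pretrained network:
# an augmentation layer baked into a small functional classifier.
inputs = tf.keras.Input(shape=(32, 32, 3))
x = tf.keras.layers.RandomRotation(0.25)(inputs)
x = tf.keras.layers.Conv2D(8, 3, activation="relu", name="last_conv")(x)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
outputs = tf.keras.layers.Dense(2, activation="softmax")(x)
model = tf.keras.Model(inputs, outputs)

image = np.random.default_rng(0).random((1, 32, 32, 3)).astype("float32")

# Direct calls without a training argument: RandomRotation falls back to its
# own call() default (training=True) and rotates the input randomly.
direct_a = model(image).numpy()
direct_b = model(image).numpy()

# predict() runs the model with training=False, as does passing it explicitly,
# so the augmentation layer returns its input unchanged.
predicted = model.predict(image, verbose=0)
inference_a = model(image, training=False).numpy()
inference_b = model(image, training=False).numpy()

print("direct calls agree:", np.allclose(direct_a, direct_b))
print("training=False calls agree:", np.allclose(inference_a, inference_b))
print("predict() matches training=False:", np.allclose(predicted, inference_a, atol=1e-6))
```

The first comparison is expected to report a mismatch (two independent random rotations), while the last two should agree.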

The only workarounds I've found so far are modifying the already trained model to strip out the augmentation layers, or redefining the model by hand and explicitly passing training=False to the relevant layers.
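The first workaround, removing the augmentation layers from the trained model, can be sketched with tf.keras.models.clone_model and a custom clone_function, assuming TF 2.6+ and a functional model. The helper names, the AUGMENTATION_TYPES tuple and the toy model are hypothetical. Each augmentation layer is swapped for a same-named linear Activation (an identity), and the remaining layers receive the original weights, copied by layer name.

```python
import numpy as np
import tensorflow as tf

# Layer classes treated as augmentation. This tuple is an assumption:
# extend it with whatever random layers your model actually contains.
AUGMENTATION_TYPES = (
    tf.keras.layers.RandomFlip,
    tf.keras.layers.RandomRotation,
    tf.keras.layers.RandomZoom,
    tf.keras.layers.RandomTranslation,
    tf.keras.layers.RandomContrast,
)


def _clone_without_augmentation(layer):
    """Hypothetical clone_function: swap augmentation for an identity layer."""
    if isinstance(layer, AUGMENTATION_TYPES):
        # Same name, so the rest of the cloned graph wires up unchanged.
        return tf.keras.layers.Activation("linear", name=layer.name)
    # Fresh copy of any other layer (same config, including its name).
    return layer.__class__.from_config(layer.get_config())


def strip_augmentation(model):
    """Hypothetical helper: rebuild `model` without augmentation, reusing its weights."""
    clean = tf.keras.models.clone_model(model, clone_function=_clone_without_augmentation)
    for layer in clean.layers:
        if layer.weights:  # skip weightless layers (inputs, identities, pooling)
            layer.set_weights(model.get_layer(layer.name).get_weights())
    return clean


# Hypothetical "trained" model with an augmentation layer baked in.
inputs = tf.keras.Input(shape=(32, 32, 3))
x = tf.keras.layers.RandomFlip("horizontal")(inputs)
x = tf.keras.layers.Conv2D(8, 3, activation="relu", name="last_conv")(x)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
outputs = tf.keras.layers.Dense(2, activation="softmax")(x)
model = tf.keras.Model(inputs, outputs)

clean = strip_augmentation(model)
image = np.random.default_rng(0).random((1, 32, 32, 3)).astype("float32")
print(np.allclose(clean(image), model(image, training=False), atol=1e-6))
```

Cloning, rather than reusing the original layer objects in a new model, is deliberate: in tf.keras 2.x a layer shared by two models has several inbound nodes, and reading layer.output (which Grad-CAM style code uses to build its gradient model) then raises an AttributeError.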

Neither solution feels right, though, because both require reconstructing the model in some way. Ideally this would be fixed using only the pretrained model as-is. I suspect Batch Normalization layers could cause a similar problem, since they take the same training argument.
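A quick way to check whether a given layer type behaves like the augmentation layers here (a sketch assuming TF 2.x's bundled Keras; the input shape is arbitrary): call it once with no training argument and once with training=False, and compare. For a freshly created BatchNormalization layer the two agree, since its call() default is training=None, which resolves to inference mode and the moving statistics; only an explicit training=True switches it to per-batch statistics. This suggests BatchNormalization is not what makes repeated heatmaps differ when the flag is simply omitted.

```python
import numpy as np
import tensorflow as tf

bn = tf.keras.layers.BatchNormalization()
x = np.random.default_rng(0).random((1, 4, 4, 3)).astype("float32")

# No training argument: the default (None) resolves to inference mode,
# which normalizes with the moving mean/variance.
default_out = bn(x).numpy()
inference_out = bn(x, training=False).numpy()

# Explicit training=True normalizes with this batch's own statistics
# (and updates the moving statistics as a side effect, so it runs last).
batch_stats_out = bn(x, training=True).numpy()

print("default == training=False:", np.allclose(default_out, inference_out))
print("training=True differs:", not np.allclose(batch_stats_out, inference_out))
```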

Let me know if you find a different solution.

RaphaelMeudec commented 2 years ago

For v1.0.0, let's make sure tf-explain methods pass training=False whenever they call the model.
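A minimal Grad-CAM sketch of what passing training=False inside an explainer amounts to. This is plain Grad-CAM written for illustration, not tf-explain's implementation (no guided gradients, no heatmap overlay); the toy model and the grad_cam helper are hypothetical, and TF 2.6+ is assumed. Because the forward pass inside the tape receives training=False, the nested RandomRotation is a no-op and repeated calls on the same image give the same heatmap.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy model: an augmentation layer inside a functional classifier.
inputs = tf.keras.Input(shape=(32, 32, 3))
x = tf.keras.layers.RandomRotation(0.25)(inputs)
x = tf.keras.layers.Conv2D(8, 3, activation="relu", name="last_conv")(x)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
outputs = tf.keras.layers.Dense(2, activation="softmax")(x)
model = tf.keras.Model(inputs, outputs)


def grad_cam(model, images, class_index, layer_name):
    """Hypothetical plain Grad-CAM helper that always runs with training=False."""
    grad_model = tf.keras.Model(
        model.inputs, [model.get_layer(layer_name).output, model.output]
    )
    images = tf.convert_to_tensor(images, dtype=tf.float32)
    with tf.GradientTape() as tape:
        # training=False propagates to every nested layer, so RandomRotation
        # passes its input through unchanged.
        conv_outputs, predictions = grad_model(images, training=False)
        score = predictions[:, class_index]
    grads = tape.gradient(score, conv_outputs)
    # Channel importance: spatially averaged gradients, shape (batch, channels).
    weights = tf.reduce_mean(grads, axis=(1, 2))
    # Weighted sum of the feature maps, keeping only positive evidence.
    cam = tf.nn.relu(tf.einsum("bhwc,bc->bhw", conv_outputs, weights))
    return cam.numpy()


image = np.random.default_rng(0).random((1, 32, 32, 3)).astype("float32")
cam_a = grad_cam(model, image, class_index=0, layer_name="last_conv")
cam_b = grad_cam(model, image, class_index=0, layer_name="last_conv")
print("heatmaps identical:", np.allclose(cam_a, cam_b))
```

model.predict() is not an option inside the tape: it returns NumPy arrays, so no gradients flow through it, and it takes no training argument. The flag has to go on the direct model call.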