sicara / tf-explain

Interpretability Methods for tf.keras models with Tensorflow 2.x
https://tf-explain.readthedocs.io
MIT License

Add option to use raw gradients instead of guided gradients on GradCam #131

Status: closed by Tauranis 4 years ago.

Tauranis commented 4 years ago

The current GradCAM implementation uses guided grads to generate the heatmaps.

The purpose of this PR is to offer the option to apply GradCAM according to equation (11) of the paper Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization, that is, using the raw gradients instead of the guided grads. This would be controlled by a new parameter, `use_guided_grads`.
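For reference, the standard Grad-CAM weighting from the paper averages the raw gradients of the class score $y^c$ over all $Z = H \cdot W$ spatial positions of each feature map $A^k$, then combines the feature maps with those weights. The guided variant on the last line is our reading of the current tf-explain code (a gradient is kept only where both the activation and the gradient are positive); treat it as an assumption, not a quote from the library.

```latex
\alpha_k^c = \frac{1}{Z}\sum_{i}\sum_{j} \frac{\partial y^c}{\partial A^k_{ij}},
\qquad
L^c_{\text{Grad-CAM}} = \operatorname{ReLU}\Big(\sum_k \alpha_k^c A^k\Big)

% Guided variant (assumed): mask each gradient before averaging
\tilde{\alpha}_k^c = \frac{1}{Z}\sum_{i}\sum_{j}
  \mathbb{1}\big[A^k_{ij} > 0\big]\,
  \mathbb{1}\Big[\frac{\partial y^c}{\partial A^k_{ij}} > 0\Big]\,
  \frac{\partial y^c}{\partial A^k_{ij}}
```

With `use_guided_grads=False`, the weights $\alpha_k^c$ would be computed from the unmasked gradients, as in the first line.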

```python
import tf_explain

gradcam = tf_explain.core.GradCAM()

# GradCAM with raw gradients, as in the paper
grid_gradcam = gradcam.explain(
    validation_data=EVAL_DATA,
    model=MODEL,
    class_index=CLASS_INDEX,
    layer_name=LAYER_NAME,
    colormap=COLOR_MAP,
    use_guided_grads=False)  # <----

# GradCAM with guided gradients (current behavior)
grid_guided_gradcam = gradcam.explain(
    validation_data=EVAL_DATA,
    model=MODEL,
    class_index=CLASS_INDEX,
    layer_name=LAYER_NAME,
    colormap=COLOR_MAP,
    use_guided_grads=True)  # <----
```
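To make concrete what the flag would toggle, here is a minimal NumPy sketch of the heatmap computation for a single image. It is not tf-explain's code: `gradcam_heatmap` is a hypothetical helper that assumes you already have the target layer's activations and their gradients as arrays of shape (H, W, K). The guided mask is modelled on our reading of the current implementation, and the final ReLU follows the paper rather than tf-explain's own post-processing.

```python
import numpy as np


def gradcam_heatmap(conv_outputs, grads, use_guided_grads=True):
    """Hypothetical helper: Grad-CAM heatmap for one image.

    conv_outputs: activations A^k of the target layer, shape (H, W, K).
    grads: gradients of the class score w.r.t. those activations, same shape.
    """
    if use_guided_grads:
        # Assumption: mirrors tf-explain's guided grads by keeping a gradient
        # only where both the activation and the gradient are positive.
        mask = (conv_outputs > 0) & (grads > 0)
        grads = np.where(mask, grads, 0.0)
    # Per-channel weights: average of the gradients over H and W.
    weights = grads.mean(axis=(0, 1))  # shape (K,)
    # Weighted sum of the feature maps, then ReLU as in the Grad-CAM paper.
    return np.maximum(conv_outputs @ weights, 0.0)  # shape (H, W)


# Toy 2x2 feature maps with K=2 channels.
conv_outputs = np.array([[[1.0, 2.0], [0.0, 1.0]],
                         [[3.0, 0.0], [1.0, 1.0]]])
grads = np.array([[[0.5, -1.0], [2.0, 0.5]],
                  [[-0.5, 1.0], [0.5, -0.5]]])

raw_cam = gradcam_heatmap(conv_outputs, grads, use_guided_grads=False)
guided_cam = gradcam_heatmap(conv_outputs, grads, use_guided_grads=True)
# raw_cam    -> [[0.625, 0.0], [1.875, 0.625]]
# guided_cam -> [[0.5, 0.125], [0.75, 0.375]]
```

On this toy input, the positive and negative gradients of channel 1 cancel under raw averaging, so its weight is zero and only channel 0 shapes the raw heatmap; the guided mask discards the negative gradients, so channel 1 contributes as well.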

## Using Keras Callbacks

```python
from tf_explain.callbacks.grad_cam import GradCAMCallback

# Callback using guided gradients (current behavior)
callback_guided_gradcam = GradCAMCallback(
    validation_data=EVAL_DATA,
    class_index=CLASS_INDEX,
    layer_name=LAYER_NAME,
    colormap=COLOR_MAP,
    use_guided_grads=True)  # <----

# Callback using raw gradients, as in the paper
callback_gradcam = GradCAMCallback(
    validation_data=EVAL_DATA,
    class_index=CLASS_INDEX,
    layer_name=LAYER_NAME,
    colormap=COLOR_MAP,
    use_guided_grads=False)  # <----
```

## Why this feature?

First, it lets users choose the variant that best fits their use case.

Second, in my own research on a problem I'm studying privately, GradCAM without guided grads produced better results than the current guided-grads version.

To illustrate how GradCAM performs better without guided grads, see this Colab, which uses the Beans dataset for image classification.

With guided gradients, the heatmap is messy (third heatmap):

[Image: Guided GradCAM]

Without guided gradients, however, the heatmap looks noticeably cleaner:

[Image: GradCAM]