sicara / tf-explain

Interpretability Methods for tf.keras models with TensorFlow 2.x
https://tf-explain.readthedocs.io
MIT License

VanillaGradients computes gradients with regards to the likelihood, not the score #159

Closed · CosmicGans closed this issue 3 years ago

CosmicGans commented 3 years ago

Thanks for the very useful package!

According to the original paper, vanilla gradients should be computed based on "the gradient of the class score with respect to the input image."
However, the VanillaGradients example computes the gradient with respect to the likelihoods (the softmax layer is the last one in VGG16). To reflect the approach in the paper, it seems that "include_top=False" would be more suitable. Using the class score instead of the likelihood also makes more sense (at least to me ;) when one class is already close to 100% probability.
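For reference, here is a minimal sketch of the paper's formulation, i.e. differentiating the raw class score rather than the softmax likelihood. It assumes a TF release where tf.keras.applications.VGG16 accepts classifier_activation=None, and the image and class index are placeholders:

import tensorflow as tf

# Sketch only: VGG16 whose last Dense layer outputs raw class scores (no softmax).
model = tf.keras.applications.VGG16(weights=None, classifier_activation=None)

image = tf.random.uniform((1, 224, 224, 3))  # placeholder input image
class_index = 42                             # placeholder target class

with tf.GradientTape() as tape:
    tape.watch(image)
    scores = model(image)            # raw class scores, not likelihoods
    score = scores[:, class_index]

# Vanilla gradients: gradient of the class score with respect to the input image.
saliency = tape.gradient(score, image)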

Also, most people will plug in their full model, classification layer included, so I think a reminder in the docs to remove the classification layer would be very helpful.

AlexKubiesa commented 3 years ago

@RaphaelMeudec here are my thoughts on the issue.

Vanilla Gradients is intended for convolutional classifiers, which are likely to have tf.keras.layers.Dense as the last layer, with a softmax activation. Unfortunately, the activation function is part of the Dense object, so there is no easy way to compute the layer's output without also applying the activation function.

Here are some potential solutions:

  1. One option would be to tell users of this function to temporarily set activation = None before calling VanillaGradients, but that wouldn't work for the callback because users would want to train the real model, not a modified version.

  2. Another option would be to tell users to create or modify their models to have an Activation or Softmax layer as the last layer, and have no activation in the previous (Dense) layer. Then VanillaGradients could operate on the second-from-last layer. This would work for both the normal and callback versions.

  3. A third option might be to make VanillaGradients.explain clone the last layer, replace the activation function on the cloned version, make a new model using the cloned layer as the output, and find the gradients of the new model. But, it seems wasteful to do this every time the function is called and I don't think it would work well within a callback.

I think Option 2 sounds the nicest, and is reasonable if we document the expectation clearly.
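For illustration, this is the kind of model Option 2 would expect (the layer sizes here are arbitrary, purely a sketch):

import tensorflow as tf

# Illustrative Option 2 architecture: the Dense layer outputs raw class scores,
# and the softmax lives in its own layer, so layers[-2] exposes the scores.
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(16, 3, activation="relu", input_shape=(32, 32, 3)),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(10),   # class scores (logits), no activation
    tf.keras.layers.Softmax(),   # probabilities as a separate layer
])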

Let me know what you think.

RaphaelMeudec commented 3 years ago

@AlexKubiesa I agree with you that it seems the most reasonable option. One thing that bothers me, though, is that it makes it nearly impossible to use the vanilla gradients callback (as you would train the model with your last activation in place). Maybe what we could do is a mix of 2 & 3? If the last layer is an activation layer, we pop it in the callback computation; otherwise we raise a warning/log to indicate to the user that the callback is potentially flawed because the last layer is a Dense + activation.
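A rough sketch of that mixed behaviour, assuming model is the user's tf.keras classifier (the layer checks and warning text here are illustrative, not actual tf-explain code):

import tensorflow as tf

# If the last layer is a standalone activation, compute on its input (effectively
# popping it); otherwise fall back to the model output and warn.
last_layer = model.layers[-1]
if isinstance(last_layer, (tf.keras.layers.Activation, tf.keras.layers.Softmax)):
    score_model = tf.keras.Model(inputs=model.inputs, outputs=last_layer.input)
else:
    tf.get_logger().warning(
        "Last layer is not a standalone activation; vanilla gradients will be "
        "computed on the model output, which may include a softmax."
    )
    score_model = model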

What do you think of this?

AlexKubiesa commented 3 years ago

I think logging a warning makes sense.

I don't see how Option 2 makes the callback difficult. We can make a new model and apply that to the input instead of the original. If the last layer is an activation layer, we can do:

import tensorflow as tf

# Use the input to the final activation layer as the (pre-softmax) score.
score_node = model.layers[-1].input
score_model = tf.keras.Model(inputs=model.inputs, outputs=[score_node])
scores = score_model(inputs)

Then take the gradient of the score with respect to the inputs.
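Something along these lines, with the forward pass under a GradientTape and class_index standing in for whichever class is being visualised:

with tf.GradientTape() as tape:
    inputs = tf.cast(inputs, tf.float32)
    tape.watch(inputs)
    scores = score_model(inputs)           # pre-softmax class scores
    class_scores = scores[:, class_index]

# Gradient of the class scores with respect to the inputs.
gradients = tape.gradient(class_scores, inputs)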

AlexKubiesa commented 3 years ago

I don't think we can perform a comprehensive check of the model architecture in all cases.

If the network is entirely sequential and has only one set of inputs and outputs, then model.layers is just the full list of layers in order, and we can check that layers[-2] is a Dense layer (or similar) and layers[-1] is a Softmax layer (or similar).
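For that simple case, the check could look something like this (which layer classes count as "similar" is an open question):

# Sketch: accept models whose scores come from a Dense layer followed by a Softmax.
is_supported = (
    isinstance(model.layers[-2], tf.keras.layers.Dense)
    and isinstance(model.layers[-1], tf.keras.layers.Softmax)
)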

I assume we're not supporting models with multiple inputs because the visualisation needs to be a single image.

There may be models with multiple outputs, where someone has trained a multi-feature classifier, and in this case it's harder to predict the order of model.layers.

It gets even more difficult with models that "branch out", applying different layers to the same inputs and then recombining them, or models that apply the same layer multiple times at different stages of the computation. Should we worry about checking for these kinds of architectures?

Maybe we could require the user to pass in the specific layer they want visualised (as well as the model)?

Or we could choose not to support the complex use cases and just require the model to be sequential (enough), with a single input and output tensor?

RaphaelMeudec commented 3 years ago

Closing this issue as the fix has been merged into master.