Explanation
This issue aims to implement the Gradients*Inputs method.
At the moment, only Vanilla Gradients is implemented. The difference between Vanilla Gradients (VG) and Gradients*Inputs (GI) is that, instead of returning the raw gradients for a given input, GI returns the gradients weighted by the input values.
The operation takes the input (shape (H, W, 3)) and the gradients (shape (H, W, 3)) and multiplies them element-wise, channel by channel, to obtain the weighted gradients (shape (H, W, 3)). From these, we create a visualization (either (H, W, 3) or (H, W, 1)).
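A minimal NumPy sketch of the operation described above, for a single image. The function names `gradients_dot_inputs` and `to_heatmap` are hypothetical, and the visualization choices (absolute values, summing over channels for the (H, W, 1) case, scaling by the maximum to [0, 255]) are assumptions, not a specification of the final implementation.

```python
import numpy as np


def gradients_dot_inputs(inputs: np.ndarray, gradients: np.ndarray) -> np.ndarray:
    """Weight the gradients by the input values (both of shape (H, W, 3))."""
    if inputs.shape != gradients.shape:
        # Refuse to broadcast silently, e.g. an (H, W, 1) gradient against an (H, W, 3) input.
        raise ValueError(f"shape mismatch: {inputs.shape} vs {gradients.shape}")
    # Element-wise product: each gradient channel is scaled by the same input channel.
    return inputs * gradients


def to_heatmap(weighted: np.ndarray, keep_channels: bool = False) -> np.ndarray:
    """Map weighted gradients to a uint8 image.

    keep_channels=True returns (H, W, 3); otherwise the absolute values are
    summed over the channel axis and the result has shape (H, W, 1).
    """
    magnitudes = np.abs(weighted)
    if not keep_channels:
        magnitudes = magnitudes.sum(axis=-1, keepdims=True)
    peak = magnitudes.max()
    if peak > 0:
        magnitudes = magnitudes / peak  # scale to [0, 1]
    return (magnitudes * 255).astype(np.uint8)
```

Taking absolute values gives a magnitude-only heatmap; keeping the sign would instead distinguish input regions that raise the class score from those that lower it.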
How
Part 1: Create the core algorithm
[x] Create a gradients_inputs.py file
[x] Create a GradientsInputs class
[x] Implement the .explain method
[x] Signature should be (validation_data, model, class_index)
[x] Complete docstring
[x] Decompose images, _ = validation_data
[x] Add a compute_gradients_dot_inputs() method in the class
In compute_gradients_dot_inputs(), use tf.GradientTape to compute the gradients, then multiply them by the inputs
Decorate compute_gradients_dot_inputs() with tf.function
Add a save method (check this PR)
Register the class in tf_explain.core.__init__
Part 2: Add the corresponding callback
Create a tf_explain.callbacks.gradients_inputs module, following tf_explain.callbacks.gradients
Register the callback in tf_explain.callbacks.__init__
Add a test in tests/integration/test_keras_api.py
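A hedged sketch of what the Part 2 callback might look like, written against the public `tf.keras.callbacks.Callback` and `tf.summary` APIs rather than tf_explain's internals. The class name `GradientsInputsCallback`, its constructor arguments, and the `last_heatmaps` attribute are hypothetical; the gradients × inputs product is computed inline so the block is self-contained.

```python
import tensorflow as tf


class GradientsInputsCallback(tf.keras.callbacks.Callback):
    """Hypothetical callback: logs a Gradients*Inputs heatmap to TensorBoard after each epoch."""

    def __init__(self, validation_data, class_index, output_dir):
        super().__init__()
        # Stored under its own name to avoid clashing with attributes Keras may set.
        self.explain_data = validation_data
        self.class_index = class_index
        # One writer for the whole run, so every epoch goes to the same event file.
        self.file_writer = tf.summary.create_file_writer(output_dir)
        self.last_heatmaps = None

    def on_epoch_end(self, epoch, logs=None):
        images, _ = self.explain_data
        images = tf.cast(images, tf.float32)
        with tf.GradientTape() as tape:
            tape.watch(images)
            scores = self.model(images)[:, self.class_index]
        # Gradients weighted by the inputs, reduced to one grayscale channel.
        weighted = tape.gradient(scores, images) * images
        magnitudes = tf.reduce_sum(tf.abs(weighted), axis=-1, keepdims=True)
        peak = tf.reduce_max(magnitudes, axis=(1, 2, 3), keepdims=True)
        heatmaps = tf.cast(tf.math.divide_no_nan(magnitudes, peak) * 255.0, tf.uint8)
        self.last_heatmaps = heatmaps.numpy()
        with self.file_writer.as_default():
            tf.summary.image("gradients_inputs", heatmaps, step=epoch,
                             max_outputs=heatmaps.shape[0])
        self.file_writer.flush()
```

A short `model.fit(..., callbacks=[...])` run on a tiny model, as in the checks below, is the kind of smoke test the integration test could perform.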
Part 3: Examples, Docs
In examples.mnist, add the implemented callback
Add an examples.core file for this method
Add the method to README.md (follow the existing scheme)
Add the method to docs/source/methods
Checks
Check that the docs build (cd docs/ && make html, then open build/html/index.html)