chr5tphr / zennit

Zennit is a high-level framework in Python using PyTorch for explaining/exploring neural networks using attribution methods like LRP.
183 stars, 33 forks

Core: Second Order Gradients #159

Closed by chr5tphr 1 year ago

chr5tphr commented 1 year ago

Implements #142 and fixes #125

HeinrichAD commented 1 year ago

Hi @chr5tphr

I will provide my feedback here instead of the issue itself.

I didn't check the attribution maps in detail, but your code example (here) seems to work fine. (My following examples are based on this example.)

There are some points I would like to mention:

  • What if we do not use composites?
  • Why not add the possibility to activate/deactivate the rule hooks in general? (Sometimes you need to do many other things before you calculate the second order gradient, and you do not want to keep this in mind all the time.)
  • Why does composite.inactive only work inside the original composite.context?

Code Example 1

```python
import torch
from zennit.composites import EpsilonGammaBox

canonizers = None
composite = EpsilonGammaBox(low=-3., high=3., canonizers=canonizers)


def explain_LRP(model: torch.nn.Module, input: torch.Tensor, target: torch.Tensor):
    with composite.context(model) as modified_model:
        outputs = modified_model(input)
        relevance, = torch.autograd.grad(outputs, input, target, create_graph=True)
        return outputs, relevance


outputs, relevance = explain_LRP(model, input, target)
# create a target heatmap, rolled 12 pixels south east
target_heat = torch.roll(relevance.detach(), (12, 12), (2, 3))
loss = ((relevance - target_heat) ** 2).mean()

# deactivate the rule hooks in order to leave the second order gradient untouched
# version 1
with composite.inactive():
    adv_grad, = torch.autograd.grad(loss, input)  # <<-- error because `hook.active` is still True
# version 2
with composite.context(model):
    with composite.inactive():
        adv_grad, = torch.autograd.grad(loss, input)  # <<-- error because `hook.active` is still True
```

  • Attributors are currently not supported?

Code Example 2

```python
from zennit.attribution import Gradient

canonizers = None
composite = EpsilonGammaBox(low=-3., high=3., canonizers=canonizers)

with Gradient(model=model, composite=composite) as attributor:
    outputs, relevance = attributor(inputs, torch.eye(1000)[targets])

# create a target heatmap, rolled 12 pixels south east
target_heat = torch.roll(relevance.detach(), (12, 12), (2, 3))
loss = ((relevance - target_heat) ** 2).mean()

# deactivate the rule hooks in order to leave the second order gradient untouched
with attributor.composite.inactive():
    adv_grad, = torch.autograd.grad(loss, inputs)  # <<-- Error
    # RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
```
  • What do you think should be the default behaviour? (Maybe the most important question 😄) Would it not be more plausible if the default were inactive, and the hooks only became active once you enter an attributor or composite? In my opinion, the following should be possible by default:

Code Example 3

```python
def explain_LRP(model: torch.nn.Module, input: torch.Tensor, target: torch.Tensor):
    canonizers = None
    composite = EpsilonGammaBox(low=-3., high=3., canonizers=canonizers)
    with Gradient(model=model, composite=composite) as attributor:
        outputs, attributions = attributor(input, target)
    return outputs, attributions


outputs, relevance = explain_LRP(model, input, target)
# create a target heatmap, rolled 12 pixels south east
target_heat = torch.roll(relevance.detach(), (12, 12), (2, 3))
loss = ((relevance - target_heat) ** 2).mean()
# now the gradient calculation should be possible by default without any further deactivation etc.
adv_grad, = torch.autograd.grad(loss, input)  # <<-- this should work by default
```

Code Example 4

```python
def explain_LRP(model: torch.nn.Module, input: torch.Tensor, target: torch.Tensor):
    canonizers = None
    composite = EpsilonGammaBox(low=-3., high=3., canonizers=canonizers)
    with composite.context(model) as modified_model:
        outputs = modified_model(input)
        relevance, = torch.autograd.grad(outputs, input, target, create_graph=True)
    return outputs, relevance


outputs, relevance = explain_LRP(model, input, target)
# create a target heatmap, rolled 12 pixels south east
target_heat = torch.roll(relevance.detach(), (12, 12), (2, 3))
loss = ((relevance - target_heat) ** 2).mean()
# now the gradient calculation should be possible by default without any further deactivation etc.
adv_grad, = torch.autograd.grad(loss, input)  # <<-- this should work by default
```

Edit: To be honest, I expected something like this to be the default:

Code Example 5

```python
def explain_LRP(model: torch.nn.Module, input: torch.Tensor, target: torch.Tensor):
    canonizers = None
    composite = EpsilonGammaBox(low=-3., high=3., canonizers=canonizers)
    with composite.context(model) as modified_model:
        outputs = modified_model(input)
        relevance, = torch.autograd.grad(outputs, input, target, create_graph=True)
        for hook in composite.hook_refs:
            hook.active = False
    return outputs, relevance


outputs, relevance = explain_LRP(model, input, target)
# create a target heatmap, rolled 12 pixels south east
target_heat = torch.roll(relevance.detach(), (12, 12), (2, 3))
loss = ((relevance - target_heat) ** 2).mean()
# now the gradient calculation should be possible by default without any further deactivation etc.
adv_grad, = torch.autograd.grad(loss, input)  # <<-- OK
```
chr5tphr commented 1 year ago

Hey @HeinrichAD

thanks a lot for your feedback!

For clarification: hooks are only ever active after composite.register and before composite.remove have been called, which happens at the beginning and at the end of the with composite.context(...) block, respectively. The new composite.inactive context relies on a new attribute, hook.active, which only has an effect while the hook exists. It is True by default, because its only purpose is to temporarily deactivate a hook that is still alive. This is necessary to compute the second order gradient, since the hooks would otherwise interfere with it.
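To make this lifecycle concrete, here is a minimal pure-PyTorch sketch. It is not Zennit's implementation: GradScaleHook, register and remove are hypothetical names chosen for illustration. The hook only affects gradients between register and remove, and because its active flag is read when the backward pass runs, it can be switched off temporarily while it is still registered.

```python
import torch


class GradScaleHook:
    """Hypothetical hook that rescales the gradient flowing back through a module's output."""

    def __init__(self, scale: float):
        self.scale = scale
        self.active = True  # True by default; only used to temporarily switch the hook off
        self.handle = None

    def forward_hook(self, module, inputs, output):
        # attach a tensor hook to the output; it reads `self.active` when backward runs
        def rescale(grad):
            return grad * self.scale if self.active else grad

        output.register_hook(rescale)

    def register(self, module: torch.nn.Module):
        self.handle = module.register_forward_hook(self.forward_hook)

    def remove(self):
        if self.handle is not None:
            self.handle.remove()
            self.handle = None


layer = torch.nn.Linear(3, 1)
hook = GradScaleHook(2.0)
hook.register(layer)
x = torch.ones(1, 3, requires_grad=True)

# hook registered and active: the gradient is rescaled
grad_active, = torch.autograd.grad(layer(x).sum(), x)

# hook still registered, but temporarily inactive
hook.active = False
grad_inactive, = torch.autograd.grad(layer(x).sum(), x)
hook.active = True

# analogous to leaving the `with composite.context(...)` block: the hook is destroyed
hook.remove()
grad_removed, = torch.autograd.grad(layer(x).sum(), x)
```

Checking the flag inside the backward closure, rather than at registration time, is what lets a hook be switched off after the forward pass has already run, which is exactly the situation when computing a second order gradient.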

  • What if we do not use composites?
  • Why not add the possibility to activate/deactivate the rule hooks in general? (Sometimes you need to do many other things before you calculate the second order gradient and you do not want to keep this always in mind.)
  • Why does composite.inactive only work inside the original composite.context?

Code Example 1

```python
# deactivate the rule hooks in order to leave the second order gradient untouched
# version 1
with composite.inactive():
    adv_grad, = torch.autograd.grad(loss, input)  # <<-- error because `hook.active` is still True
# version 2
with composite.context(model):
    with composite.inactive():
        adv_grad, = torch.autograd.grad(loss, input)  # <<-- error because `hook.active` is still True
```
  • Attributors are currently not supported?

Code Example 2

  • What do you think should be the default behaviour? (Maybe the most important question 😄) Would it not be more plausible if the default were inactive, and the hooks only became active once you enter an attributor or composite? In my opinion, the following should be possible by default:

Code Example 3

Code Example 4

Edit: To be honest, I expected something like this to be the default:

Code Example 5

To summarize: it should be possible to compute the second order gradient either after the hooks have been destroyed, or while they still exist, by using with composite.inactive(): or, for single hooks, hook.active = False; compute_gradient(); hook.active = True. A bug in the code made the first case (computing it after destroying the hooks) impossible.
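Independently of any hooks, the second order gradient in these examples is plain double backpropagation: the first gradient is computed with create_graph=True so that the heatmap itself stays differentiable, and the loss on the heatmap is then differentiated with respect to the input. The sketch below assumes a small hypothetical model and random input; Softplus is used instead of ReLU so that the second order gradient does not vanish.

```python
import torch

torch.manual_seed(0)

# hypothetical toy model; Softplus keeps the second order gradient non-zero
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 4, 3, padding=1),
    torch.nn.Softplus(),
    torch.nn.Flatten(),
    torch.nn.Linear(4 * 16 * 16, 10),
)
inputs = torch.randn(1, 3, 16, 16, requires_grad=True)
target = torch.eye(10)[[3]]  # one-hot output gradient for class 3, shape (1, 10)

outputs = model(inputs)
# first order: create_graph=True keeps the graph, so the heatmap is differentiable
relevance, = torch.autograd.grad(outputs, inputs, target, create_graph=True)

# target heatmap rolled 12 pixels south east, as in the examples above
target_heat = torch.roll(relevance.detach(), (12, 12), (2, 3))
loss = ((relevance - target_heat) ** 2).mean()

# second order: gradient of the heatmap loss with respect to the input
adv_grad, = torch.autograd.grad(loss, inputs)
```

With a plain ReLU in place of Softplus, the heatmap depends on the input only through the piecewise-constant ReLU mask, so this second order gradient would be expected to vanish; this is the problem addressed by the smooth ReLU rule mentioned later in the thread.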

Edit:

Actually, for the intended behaviour, you do not need to loop over the hooks in Code Example 5; it should work without the loop, since leaving the context destroys the hooks. composite.inactive does exactly the loop you show, but that only makes sense if the hooks are set back to True afterwards, when the composite is re-used.
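As a rough illustration of what such a context does, the sketch below assumes only that each hook exposes a boolean active attribute; inactive_hooks is a hypothetical helper, not Zennit's code, and SimpleNamespace objects stand in for real hooks. The flags are restored in a finally clause, so the hooks are re-enabled for later re-use even if the gradient computation inside the block raises.

```python
import contextlib
from types import SimpleNamespace


@contextlib.contextmanager
def inactive_hooks(hooks):
    """Hypothetical helper: switch hooks off for the block, then switch them back on."""
    hooks = list(hooks)
    for hook in hooks:
        hook.active = False
    try:
        yield hooks
    finally:
        # restore the default so the same hooks can be re-used afterwards
        for hook in hooks:
            hook.active = True


# stand-ins for registered hooks; only the `active` attribute matters here
hooks = [SimpleNamespace(active=True) for _ in range(3)]

with inactive_hooks(hooks):
    states_inside = [hook.active for hook in hooks]

states_after = [hook.active for hook in hooks]
```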

HeinrichAD commented 1 year ago

Hi @chr5tphr,

It makes much more sense now. Thank you for the detailed clarification. If you want, I can start a test run now, provided you think it makes sense at this stage. Otherwise, I will wait a little longer.

chr5tphr commented 1 year ago

Hey @HeinrichAD

I'm currently working on the rest of the documentation for this, but functionality and tests are now finished (unless I find a bug or something missing).

If you would like to try it out, you can either do so now, or wait a little bit until I also finished the documentation, at which point I will mark this PR ready and merge it in the following days.

Everything should work as expected, and as a bonus, Attributors now also have a .inactive function to compute second order gradients within the with block. There is also a new rule from (Dombrowski et al., 2019) that replaces the gradient of ReLU with the gradient of a smooth variant, in order to deal with its otherwise zero (and undefined at zero) second order gradients.
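The idea behind such a smooth ReLU rule can be sketched with a custom autograd function: the forward pass stays an exact ReLU, while the backward pass uses the derivative of softplus with parameter beta, i.e. sigmoid(beta * x), which has a non-zero derivative of its own. SmoothGradReLU and beta = 10.0 are hypothetical choices for illustration, not the rule's actual implementation or default parameters in Zennit.

```python
import torch


class SmoothGradReLU(torch.autograd.Function):
    """Hypothetical sketch: exact ReLU forward, softplus-derivative backward."""

    @staticmethod
    def forward(ctx, x, beta):
        ctx.save_for_backward(x)
        ctx.beta = beta
        return x.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # d/dx softplus_beta(x) = sigmoid(beta * x); this is built from differentiable
        # ops, so autograd can differentiate it again when create_graph=True
        return grad_output * torch.sigmoid(ctx.beta * x), None


beta = 10.0
x = torch.tensor([-1.0, -0.05, 0.0, 0.05, 1.0], requires_grad=True)
y = SmoothGradReLU.apply(x, beta)

# first order gradient, kept differentiable
first, = torch.autograd.grad(y.sum(), x, create_graph=True)
# second order gradient: beta * s * (1 - s) with s = sigmoid(beta * x), non-zero everywhere
second, = torch.autograd.grad(first.sum(), x)
```

Larger values of beta make the backward pass closer to the true ReLU gradient, at the cost of sharper (and, far from zero, smaller) second order gradients.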

chr5tphr commented 1 year ago

Hey @HeinrichAD

I am done with the PR and would like to merge. If you have not checked yet, you can see whether it works as expected for you. A preview of the documentation is also available here. Otherwise, I will just merge.

HeinrichAD commented 1 year ago

Hi @chr5tphr, I will try to find some time tomorrow.

HeinrichAD commented 1 year ago

Hi @chr5tphr,

First, thank you for your effort!

Code

In general, the code looks good. It also works as I expected. Only one thing is confusing me: your example code from issue #142 now generates a completely different output than before (for the 2nd derivative).

The same code now generates this output:

[image: new 2nd-derivative output of the example code from #142]

NOTE: I do not know which output is correct.

Typos

Since my IDE already points this out for me, here is a list of typos:

Also sometimes it's layer-wise relevance propagation and sometimes layerwise relevance propagation.

chr5tphr commented 1 year ago

Hey @HeinrichAD

thanks a lot again for your feedback.

In general, the code looks good. It also works as I expected. Only one thing is confusing me: your example code from issue #142 now generates a completely different output than before (for the 2nd derivative).

The same code now generates this output:

NOTE: I do not know which output is correct.

The first version was actually wrong. The problem was that BasicHook's backward function called torch.autograd.grad without setting create_graph=True, so the contribution weighting of the input was treated as a constant when differentiating. As a result, the second order gradient was not computed through the whole model, but only through the first layer. This is also why that gradient looked so clean: it was just a difference of the contributions in the first layer (divided by x).
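The effect can be reproduced in a toy setting. In the sketch below, which is a simplified stand-in rather than BasicHook's actual code, an LRP-like relevance is formed as input times the gradient of f(x) = sum(x**3); toy_relevance is a hypothetical helper. If the inner gradient is taken without create_graph=True, it enters the relevance as a constant, and the second order gradient only sees the outer multiplication by the input.

```python
import torch


def toy_relevance(x: torch.Tensor, create_graph: bool) -> torch.Tensor:
    """Hypothetical helper: gradient-times-input relevance of f(x) = sum(x**3)."""
    out = (x ** 3).sum()
    # inner gradient 3 * x**2; only tracked by autograd if create_graph=True
    grad, = torch.autograd.grad(out, x, create_graph=create_graph)
    return x * grad


x = torch.tensor([1.0, 2.0], requires_grad=True)

# correct: differentiates through the inner gradient as well
full, = torch.autograd.grad(toy_relevance(x, create_graph=True).sum(), x)
# buggy variant: the inner gradient is treated as a constant
truncated, = torch.autograd.grad(toy_relevance(x, create_graph=False).sum(), x)
```

With create_graph=True, the product rule yields 3x² + x · 6x = 9x², i.e. [9, 36]; without it, only the 3x² term survives, i.e. [3, 12].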

Typos

Since my IDE already points this out for me, here is a list of typos:

* CONTRIBUTING.md#50 numpy codestyle

* docs/source/index.rst#5 Propagation

* docs/source/getting-started.rst#149 instantiate

* docs/source/how-to/visualize-results.rst#444 accessed

* docs/source/how-to/write-custom-canonizer.rst#113 torch

* docs/source/tutorial/image-...ipynb section 3.2 cell 1 line 10 gamma

* src/zennit/core.py#165 lengths

* tests/test_attribution.py#91 preferred

* tests/test_attribution.py#140 SmoothGrad

* tests/test_canonizers.py#120 AttributeCanonizer

* tests/test_canonizers.py#141 whether

* shared/scripts/palette_fit.py#48 brightness

Also sometimes it's layer-wise relevance propagation and sometimes layerwise relevance propagation.

I will add a quick follow-up PR to fix these, since many of these files were not touched in this PR, and I prefer not to touch files just for typos if they were not otherwise changed.

HeinrichAD commented 1 year ago

Thank you for the explanation. In that case, the PR is ready to go from my side 😄.