chr5tphr / zennit

Zennit is a high-level framework in Python using PyTorch for explaining/exploring neural networks using attribution methods like LRP.

Core: GradOutHooks #141

Open chr5tphr opened 2 years ago

chr5tphr commented 2 years ago

I am currently working on GradOutHook, which differs from the current zennit.core.Hook in that, instead of overwriting the full gradient of the module, it only modifies the module's output gradient. For Composites using zennit.core.Hook, only a single Hook can be attached to a module at a time, because each Hook overwrites the module's full gradient. A GradOutHook, on the other hand, can modify the output gradient multiple times and can be used together with zennit.core.Hook, which makes it possible to use multiple Composites at a time. Another way to enable multiple hooks would be to let the module_map function of Composites return a tuple of Hooks to be applied.
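The composability can be illustrated with plain PyTorch tensor hooks, which only modify the output gradient and therefore stack, unlike overwriting the full module gradient (a minimal sketch of the idea, not zennit's actual Hook mechanism):

```python
import torch

x = torch.ones(3, requires_grad=True)
y = 2.0 * x

# two grad-output hooks compose: each receives the (possibly already
# modified) gradient and returns a new one
y.register_hook(lambda grad: grad * 0.5)                            # re-weight
y.register_hook(lambda grad: grad * torch.tensor([1., 0., 1.]))     # mask a neuron

y.sum().backward()
print(x.grad)  # both hooks applied: tensor([1., 0., 1.])
```

With a full-gradient Hook in the sense of zennit.core.Hook, the second attachment would simply discard the first one's result, which is why only one can be active per module.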

The main use case for this is to mask or re-weight neurons, primarily to support LRP for GNNs. Another use case is to mask certain neurons in order to compute LRP for a subset of features/concepts.

This will somewhat change the Hook inheritance: a HookBase will be added to specify the interface required of all Hooks. I am also considering adding a Mask rule to zennit/rule.py, which takes a function or a tensor to mask the output gradient, and which can be used without subclassing the planned GradOutHook.
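A Mask rule along these lines might look as follows. This is a hypothetical sketch using a plain forward hook plus a tensor hook; the name, constructor signature, and register/remove methods are assumptions, not the eventual API:

```python
import torch

class Mask:
    """Sketch of the proposed Mask rule: multiply a module's output gradient
    by a fixed tensor, or transform it with an arbitrary callable, without
    subclassing any Hook class."""
    def __init__(self, mask):
        # mask may be a tensor (elementwise re-weighting) or a callable on the gradient
        self.mask = mask if callable(mask) else (lambda grad: grad * mask)
        self.handles = []

    def register(self, module):
        def forward_hook(module, inputs, output):
            # tensor hooks compose, so other rules may still modify this gradient
            output.register_hook(self.mask)
        self.handles.append(module.register_forward_hook(forward_hook))

    def remove(self):
        for handle in self.handles:
            handle.remove()

# mask out the second and fourth output neurons of a linear layer
linear = torch.nn.Linear(4, 4)
rule = Mask(torch.tensor([1.0, 0.0, 1.0, 0.0]))
rule.register(linear)
x = torch.ones(1, 4, requires_grad=True)
linear(x).sum().backward()
# x.grad now only receives contributions from the unmasked output neurons
```

Since the tensor hook only transforms the output gradient rather than replacing the module's gradient computation, several such rules (or a rule plus a regular Hook) could coexist on the same module.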