chr5tphr / zennit

Zennit is a high-level framework in Python, built on PyTorch, for explaining and exploring neural networks with attribution methods such as LRP.

Check for requires_grad in pre_forward #45

Closed chr5tphr closed 3 years ago

chr5tphr commented 3 years ago

Previously, when checking whether the gradient was required (to decide whether to apply hooks), only the existence of a grad_fn on the input was checked. This was insufficient: the input to the first layer may not have a grad_fn yet, but can still require a gradient. Now the input is checked for requires_grad instead; a grad_fn is not needed there because of the subsequent Identity.apply. For the output, checking for grad_fn remains sufficient, since the output will always have a grad_fn whenever a gradient is required.
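
For illustration, the difference between the two checks can be sketched with plain PyTorch tensors. This is a minimal sketch, not zennit's actual internals; the helper names are hypothetical:

```python
import torch

def needs_hook_old(tensor):
    # Old check: only looks for a grad_fn. A leaf tensor fed to the
    # first layer has no grad_fn yet, so the hook was wrongly skipped.
    return tensor.grad_fn is not None

def needs_hook_new(tensor):
    # New check: requires_grad also covers leaf tensors that will
    # receive a gradient; a grad_fn is not needed on the input because
    # the subsequent Identity.apply creates one anyway.
    return tensor.requires_grad

# A leaf input to the first layer: requires a gradient, but has no grad_fn yet.
first_layer_input = torch.randn(1, 3, requires_grad=True)
print(needs_hook_old(first_layer_input))  # False -> hook would be skipped
print(needs_hook_new(first_layer_input))  # True  -> hook is applied

# The output of a layer always has a grad_fn when a gradient is required,
# so checking grad_fn remains sufficient for the output.
output = torch.nn.Linear(3, 2)(first_layer_input)
print(output.grad_fn is not None)  # True
```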

Closes #44