Numerically, this performs the same as the current version, but it makes registering the hook function easier to implement:
import torch

# switch the model to evaluation mode
gcn.eval()

# use a backward hook to remove negative gradients flowing through ReLU layers
def relu_hook_function(module, grad_in, grad_out):
    if isinstance(module, torch.nn.ReLU):
        return (torch.clamp(grad_in[0], min=0.0),)

# register the hook on every ReLU module in the feature extractor
# (newer PyTorch versions prefer register_full_backward_hook)
for pos, module in gcn.features._modules.items():
    if isinstance(module, torch.nn.ReLU):
        print(pos, module)  # optional: show which ReLU layers receive the hook
        module.register_backward_hook(relu_hook_function)
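The hooks only take effect once a backward pass runs through the model. The following is a minimal usage sketch, assuming gcn accepts a standard 1x3x224x224 image tensor and that the top-scoring class is the one being explained; the input shape and class selection are illustrative assumptions, not part of the original example.

# minimal usage sketch (assumes gcn takes a 1x3x224x224 image tensor)
input_image = torch.randn(1, 3, 224, 224, requires_grad=True)

output = gcn(input_image)                   # forward pass through the model
target_class = output.argmax(dim=1).item()  # assumed: explain the top prediction

# backward pass: the registered hooks clamp negative ReLU gradients on the way back
output[0, target_class].backward()

# guided gradients with respect to the input pixels
guided_grads = input_image.grad.detach()

Because each hook returns the clamped grad_in, only positive gradient signal propagates back through the ReLU layers, which is exactly the behavior guided backpropagation relies on.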