SIMEXP / gcn_package

The lab repository for GCN
MIT License

ReLU layer instead of function #10

Closed. htwangtw closed this issue 2 years ago.

htwangtw commented 2 years ago

Numerically, this performs the same as the current version, but it makes registering the hook function much simpler:

import torch

# switch the model to evaluation mode
gcn.eval()

# use a backward hook to remove the negative part of the gradient at each ReLU
def relu_hook_function(module, grad_in, grad_out):
    if isinstance(module, torch.nn.ReLU):
        return (torch.clamp(grad_in[0], min=0.0), )

# register the hook on every ReLU module of the feature extractor
for pos, module in gcn.features._modules.items():
    if isinstance(module, torch.nn.ReLU):
        print(pos, module)  # confirm which ReLU layers receive the hook
        module.register_backward_hook(relu_hook_function)  # note: newer PyTorch versions prefer register_full_backward_hook
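
For context, here is a minimal self-contained sketch of why the module-based ReLU helps. The ToyNet class and its layer sizes below are made up for illustration (the real GCN in this repo differs); it only mirrors the gcn.features attribute used above so the same registration loop applies. Because ReLU is a module rather than a call to torch.nn.functional.relu, the backward hook can attach to it, and after a backward pass the input gradient has its negative components clamped at every ReLU, which is the guided-backprop behaviour the hook implements.

import torch


class ToyNet(torch.nn.Module):
    # hypothetical stand-in for the GCN: ReLU is a module, not a function,
    # so a backward hook can be attached to it
    def __init__(self):
        super().__init__()
        self.features = torch.nn.Sequential(
            torch.nn.Linear(16, 8),
            torch.nn.ReLU(),
            torch.nn.Linear(8, 2),
        )

    def forward(self, x):
        return self.features(x)


def relu_hook_function(module, grad_in, grad_out):
    # keep only the positive part of the gradient flowing back through ReLU
    if isinstance(module, torch.nn.ReLU):
        return (torch.clamp(grad_in[0], min=0.0), )


net = ToyNet()
net.eval()

for pos, module in net.features._modules.items():
    if isinstance(module, torch.nn.ReLU):
        module.register_backward_hook(relu_hook_function)

# backward pass from the top class score; x.grad now holds gradients
# with negative values removed at each ReLU
x = torch.randn(1, 16, requires_grad=True)
out = net(x)
out[0, out.argmax()].backward()
print(x.grad)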