TL;DR :bug: If your model uses a full backward hook (`register_full_backward_hook`), computing derivatives with `torch.func` transforms throws a missing `setup_context` error. A full backward hook is implemented with a `torch.autograd.Function`, which in PyTorch 2.0 must define a `setup_context` method to work with `torch.func`, and it doesn't seem to have one by default.
Here's a minimal reproducible example,
import torch
from torch import nn, Tensor
from torch.func import vmap, jacrev, functional_call
from typing import Tuple
from collections import defaultdict
class model(nn.Module):
    """Two-layer MLP: ``fc1`` -> ``Tanh`` -> ``fc2`` with a single output unit.

    Args:
        num_input: number of input features of ``fc1``.
        num_hidden: width of the hidden layer (``fc1`` outputs / ``fc2`` inputs).
    """

    def __init__(self, num_input, num_hidden):
        super().__init__()
        # Attribute assignment registers the submodules; keep this order so
        # named_modules()/named_parameters() enumerate them unchanged.
        self.fc1 = nn.Linear(num_input, num_hidden)
        self.fc2 = nn.Linear(num_hidden, 1)
        self.act_func = nn.Tanh()

    def forward(self, x):
        """Return ``fc2(tanh(fc1(x)))``."""
        hidden = self.act_func(self.fc1(x))
        return self.fc2(hidden)
# Experiment configuration.
num_samples = 4096  # rows of the random input batch
num_input = 2       # input features
num_hidden = 64     # hidden-layer width
device = torch.device("cpu")

# Per-module statistics filled in by the hooks, keyed by module object
# (same layout as an optimizer's ``state`` dict).
state = defaultdict(dict)
def forward_pre_hook(module: nn.Module, input: Tuple[Tensor]) -> None:
    """Store the (optionally bias-augmented) mean-centred input of ``module``.

    Intended for ``register_forward_pre_hook``. Returns ``None``, so the
    module's actual input is left unchanged. The result is written to
    ``state[module]['a']``.

    Args:
        module: the hooked layer; must expose a ``bias`` attribute.
        input: positional inputs to ``module.forward``; only ``input[0]`` is used.
    """
    acts = input[0]
    if module.bias is not None:
        # Append a trailing column of ones (presumably so the bias can be
        # treated as an extra weight column — confirm).
        ones_shape = list(acts.shape)
        ones_shape[-1] = 1
        bias_col = torch.ones(*ones_shape, device=acts.device, dtype=acts.dtype)
        acts = torch.cat([acts, bias_col], dim=-1)
    # Centre by subtracting the mean over dim 0.
    state[module]['a'] = acts - torch.mean(acts, keepdim=True, dim=0)
def full_backward_hook(
    module: nn.Module,
    grad_input: Tuple[Tensor],
    grad_output: Tuple[Tensor],
) -> None:
    """Store the scaled, mean-centred output gradient of ``module``.

    Intended for ``register_full_backward_hook``. The result is written to
    ``state[module]['e']``. ``grad_input`` is required by the hook signature
    but unused.
    """
    g_out = grad_output[0]
    # Multiply by the size of dim 0 (presumably undoing a mean-reduced loss — confirm).
    scaled = g_out * g_out.size(0)
    state[module]['e'] = scaled - torch.mean(scaled, keepdim=True, dim=0)
# The input
x = torch.randn(num_samples, num_input, device=device)

# Our model, placed on the same device as the input (Module.to returns the module).
net = model(num_input=num_input,
            num_hidden=num_hidden).to(device)

# Add the hooks to every Linear layer. Keep the returned RemovableHandles so the
# hooks can be taken off again, e.g. ``for h in hook_handles: h.remove()`` before
# running torch.func transforms. A registered full backward hook makes the
# module apply BackwardHookFunction, which those transforms reject.
hook_handles = []
for mod in net.modules():
    if isinstance(mod, nn.Linear):
        hook_handles.append(mod.register_forward_pre_hook(forward_pre_hook))
        hook_handles.append(mod.register_full_backward_hook(full_backward_hook))

y = net(x)  # compute output
# Compute trace of Hessian
def calc_hessian_trace(params, x, *, module=None):
    """Return the Hessian trace of a module's output with respect to ``x``.

    The module is evaluated statelessly via ``functional_call`` with the given
    parameter dict, so this can be mapped over samples with ``vmap``.

    Args:
        params: mapping of parameter name -> tensor for ``module``.
        x: a single input sample. Assumed 1-D: the trace below takes the
            diagonal of the last two Hessian dims, which is only the trace for
            1-D ``x``.
        module: keyword-only; the ``nn.Module`` to evaluate. Defaults to the
            module-level ``net``, looked up at call time. Being keyword-only,
            it is never batched by ``vmap``'s ``in_dims``.

    Returns:
        Tensor shaped like the module output; each entry is the sum of the
        diagonal of that output element's Hessian with respect to ``x``.
    """
    if module is None:
        module = net

    def output(p, inp):
        return functional_call(module, p, inp)

    # argnums=1 differentiates w.r.t. ``x`` only (not the parameters).
    hessian = jacrev(jacrev(output, argnums=1), argnums=1)(params, x)
    return hessian.diagonal(0, -2, -1).sum(-1)
# Per-sample Laplacian: the parameter dict is shared across the batch
# (in_dims None) while x is mapped over dim 0. With full backward hooks
# registered on the layers this call fails (raises the setup_context RuntimeError).
laplacian = vmap(calc_hessian_trace, in_dims=(None, 0))(
    {name: p for name, p in net.named_parameters()},
    x,
)
The resultant error (with complete stack trace) is,
Traceback (most recent call last):
File "~/pytorch2/experimental/hooks_on_aux_loss.py", line 71, in <module>
laplacian = vmap(calc_hessian_trace, in_dims=(None, 0))(dict(net.named_parameters()), x) #fails
File "~/anaconda3/envs/pytorch2/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 434, in wrapped
return _flat_vmap(
File "~/anaconda3/envs/pytorch2/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 39, in fn
return f(*args, **kwargs)
File "~/anaconda3/envs/pytorch2/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 619, in _flat_vmap
batched_outputs = func(*batched_inputs, **kwargs)
File "~/pytorch2/experimental/hooks_on_aux_loss.py", line 68, in calc_hessian_trace
_hessian = jacrev(jacrev(output, argnums=(1)), argnums=(1))(params, x)
File "~/anaconda3/envs/pytorch2/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py", line 489, in wrapper_fn
vjp_out = _vjp_with_argnums(func, *args, argnums=argnums, has_aux=has_aux)
File "~/anaconda3/envs/pytorch2/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 39, in fn
return f(*args, **kwargs)
File "~/anaconda3/envs/pytorch2/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py", line 291, in _vjp_with_argnums
primals_out = func(*primals)
File "~/anaconda3/envs/pytorch2/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py", line 489, in wrapper_fn
vjp_out = _vjp_with_argnums(func, *args, argnums=argnums, has_aux=has_aux)
File "~/anaconda3/envs/pytorch2/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 39, in fn
return f(*args, **kwargs)
File "~/anaconda3/envs/pytorch2/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py", line 291, in _vjp_with_argnums
primals_out = func(*primals)
File "~/pytorch2/experimental/hooks_on_aux_loss.py", line 66, in output
return functional_call(net, params, x)
File "~/anaconda3/envs/pytorch2/lib/python3.10/site-packages/torch/_functorch/functional_call.py", line 143, in functional_call
return nn.utils.stateless._functional_call(
File "~/anaconda3/envs/pytorch2/lib/python3.10/site-packages/torch/nn/utils/stateless.py", line 262, in _functional_call
return module(*args, **kwargs)
File "~/anaconda3/envs/pytorch2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "~/pytorch2/experimental/hooks_on_aux_loss.py", line 19, in forward
x = self.fc1(x)
File "~/anaconda3/envs/pytorch2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
args = bw_hook.setup_input_hook(args)
File "~/anaconda3/envs/pytorch2/lib/python3.10/site-packages/torch/utils/hooks.py", line 191, in setup_input_hook
res, input_idx = self._apply_on_tensors(fn, args)
File "~/anaconda3/envs/pytorch2/lib/python3.10/site-packages/torch/utils/hooks.py", line 170, in _apply_on_tensors
new_tensors = torch.nn.modules._functions.BackwardHookFunction.apply(*tensors)
File "~/anaconda3/envs/pytorch2/lib/python3.10/site-packages/torch/autograd/function.py", line 509, in apply
raise RuntimeError(
RuntimeError: In order to use an autograd.Function with functorch transforms (vmap, grad, jvp, jacrev, ...), it must override the setup_context staticmethod. For more details, please see https://pytorch.org/docs/master/notes/extending.func.html
I think this happens because a full backward hook is implemented with a `torch.autograd.Function` (`torch.nn.modules._functions.BackwardHookFunction`, which `setup_input_hook` applies to the module's inputs, as the traceback shows). In PyTorch 2.0, any `torch.func` transform requires such a Function to define `setup_context`. In previous versions of PyTorch, `full_backward_hook` methods were skipped entirely, if I recall correctly.
Is there a way to add a setup_context to a full_backward_hook, or perhaps define a custom full_backward_hook with a setup_context method itself?
class BackwardHookFunction(torch.autograd.Function):
    """Identity ``autograd.Function`` applied to a module's tensors by ``torch.utils.hooks``.

    Written in the ctx-in-``forward`` style with no ``setup_context``
    staticmethod, which is why torch.func transforms (vmap, grad, jacrev, ...)
    refuse to run it.
    """

    @staticmethod
    def forward(ctx, *args):
        # Tensors that do not require grad come back as non-differentiable outputs.
        frozen = [t for t in args if not t.requires_grad]
        ctx.mark_non_differentiable(*frozen)
        return args

    @staticmethod
    def backward(ctx, *grads):
        # Identity: incoming gradients pass straight through.
        return grads
So, I'd assume that changing this function to something like,
class BackwardHookFunction(torch.autograd.Function):
    """Identity ``autograd.Function`` written in the torch.func-compatible style.

    ``forward`` returns its input tensors unchanged. ``setup_context`` marks
    the outputs whose input does not require grad as non-differentiable.
    ``backward`` passes gradients through unchanged.
    """

    # vmap needs a batching rule; forward only returns its inputs, so the
    # automatically generated rule is sufficient.
    generate_vmap_rule = True

    @staticmethod
    def forward(*args):
        return args

    @staticmethod
    def setup_context(ctx, inputs, output):
        # setup_context is always called as (ctx, inputs, output) with the
        # forward inputs packed in a tuple, not unpacked like forward's *args.
        # mark_non_differentiable expects outputs; each output here is the
        # corresponding input object.
        ctx.mark_non_differentiable(
            *[out for inp, out in zip(inputs, output) if not inp.requires_grad]
        )

    @staticmethod
    def backward(ctx, *grads):
        # Identity: incoming gradients pass straight through.
        return grads
Might resolve this issue? However, I have no idea if this would mess with other parts of pytorch.
Hi All,
TL;DR :bug: If your model involves using a `full_backward_hook`, computing derivatives throws a missing `setup_context` error. A full backward hook is technically a `torch.autograd.Function`, which requires a `setup_context` method in PyTorch 2, and it doesn't seem to have one by default. Here's a minimal reproducible example,
The resultant error (with complete stack trace) is,
I think this emerges because `full_backward_hook` is technically a `torch.autograd.Function`, so when you call any `torch.func` method it requires a `setup_context` in order to handle any outputs within PyTorch 2.0. In previous versions of PyTorch, `full_backward_hook` methods were skipped entirely, if I recall correctly.

Is there a way to add a `setup_context` to a `full_backward_hook`, or perhaps define a custom `full_backward_hook` with a `setup_context` method itself?

EDIT: After following through the stack trace, I believe I have found where the error comes from. In https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/_functions.py, `BackwardHookFunction` is defined as a `torch.autograd.Function` object, but in the older PyTorch 1.x style (`ctx` passed to `forward`, no `setup_context`).

So, I'd assume that changing this function to something like,
Might resolve this issue? However, I have no idea if this would mess with other parts of pytorch.