pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

vmap fails if your model includes full_backward_hook in pytorch2.0 #1124

Closed AlphaBetaGamma96 closed 1 year ago

AlphaBetaGamma96 commented 1 year ago

Hi All,

TL;DR :bug: If your model uses a full_backward_hook, computing derivatives with torch.func throws a missing setup_context error. Registering a full backward hook routes the module's inputs through a torch.autograd.Function, which in PyTorch 2.0 must override the setup_context staticmethod to work with functorch transforms, and by default it doesn't.

Here's a minimal reproducible example:

import torch
from torch import nn, Tensor
from torch.func import vmap, jacrev, functional_call

from typing import Tuple
from collections import defaultdict

class model(nn.Module):

  def __init__(self, num_input, num_hidden):
    super(model, self).__init__()

    self.fc1 = nn.Linear(num_input, num_hidden)
    self.fc2 = nn.Linear(num_hidden, 1)

    self.act_func = nn.Tanh()

  def forward(self, x):
    x = self.fc1(x)
    x = self.act_func(x)
    x = self.fc2(x)
    return x

num_samples = 4096
num_input=2
num_hidden=64

device=torch.device("cpu")

state = defaultdict(dict) #equivalent to optim.state

def forward_pre_hook(module: nn.Module, input: Tuple[Tensor]) -> None:
  #Cache the layer's input activations in `state` (batch-centred, with a
  #column of ones appended when the layer has a bias).
  a = input[0]
  if module.bias is not None:
    _shape = [*a.shape]
    _shape[-1] = 1
    ones = torch.ones(*_shape, device=a.device, dtype=a.dtype)
    a = torch.cat([a, ones], dim=-1)
  a = a - torch.mean(a, keepdim=True, dim=0)
  state[module]['a'] = a

def full_backward_hook(module: nn.Module, grad_input: Tuple[Tensor], grad_output: Tuple[Tensor]) -> None:
  #Cache the layer's output gradients (scaled by the batch size, batch-centred).
  e = grad_output[0] * grad_output[0].size(0)
  e = e - torch.mean(e, keepdim=True, dim=0)
  state[module]['e'] = e

#The input
x = torch.randn(num_samples, num_input, device=device)

#Our model
net = model(num_input=num_input,
            num_hidden=num_hidden)

#Add the hooks
for mod in net.modules():
  if isinstance(mod, nn.Linear):
    mod.register_forward_pre_hook(forward_pre_hook)
    mod.register_full_backward_hook(full_backward_hook)

y = net(x) #compute output

#Compute trace of Hessian
def calc_hessian_trace(params, x):

  def output(params, x):
    return functional_call(net, params, x)

  _hessian = jacrev(jacrev(output, argnums=(1)), argnums=(1))(params, x)
  return _hessian.diagonal(0,-2,-1).sum(-1)

laplacian = vmap(calc_hessian_trace, in_dims=(None, 0))(dict(net.named_parameters()), x) #fails 

The resulting error, with the complete stack trace, is:

Traceback (most recent call last):
  File "~/pytorch2/experimental/hooks_on_aux_loss.py", line 71, in <module>
    laplacian = vmap(calc_hessian_trace, in_dims=(None, 0))(dict(net.named_parameters()), x) #fails 
  File "~/anaconda3/envs/pytorch2/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 434, in wrapped
    return _flat_vmap(
  File "~/anaconda3/envs/pytorch2/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 39, in fn
    return f(*args, **kwargs)
  File "~/anaconda3/envs/pytorch2/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 619, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "~/pytorch2/experimental/hooks_on_aux_loss.py", line 68, in calc_hessian_trace
    _hessian = jacrev(jacrev(output, argnums=(1)), argnums=(1))(params, x)
  File "~/anaconda3/envs/pytorch2/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py", line 489, in wrapper_fn
    vjp_out = _vjp_with_argnums(func, *args, argnums=argnums, has_aux=has_aux)
  File "~/anaconda3/envs/pytorch2/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 39, in fn
    return f(*args, **kwargs)
  File "~/anaconda3/envs/pytorch2/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py", line 291, in _vjp_with_argnums
    primals_out = func(*primals)
  File "~/anaconda3/envs/pytorch2/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py", line 489, in wrapper_fn
    vjp_out = _vjp_with_argnums(func, *args, argnums=argnums, has_aux=has_aux)
  File "~/anaconda3/envs/pytorch2/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 39, in fn
    return f(*args, **kwargs)
  File "~/anaconda3/envs/pytorch2/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py", line 291, in _vjp_with_argnums
    primals_out = func(*primals)
  File "~/pytorch2/experimental/hooks_on_aux_loss.py", line 66, in output
    return functional_call(net, params, x)
  File "~/anaconda3/envs/pytorch2/lib/python3.10/site-packages/torch/_functorch/functional_call.py", line 143, in functional_call
    return nn.utils.stateless._functional_call(
  File "~/anaconda3/envs/pytorch2/lib/python3.10/site-packages/torch/nn/utils/stateless.py", line 262, in _functional_call
    return module(*args, **kwargs)
  File "~/anaconda3/envs/pytorch2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "~/pytorch2/experimental/hooks_on_aux_loss.py", line 19, in forward
    x = self.fc1(x)
  File "~/anaconda3/envs/pytorch2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
    args = bw_hook.setup_input_hook(args)
  File "~/anaconda3/envs/pytorch2/lib/python3.10/site-packages/torch/utils/hooks.py", line 191, in setup_input_hook
    res, input_idx = self._apply_on_tensors(fn, args)
  File "~/anaconda3/envs/pytorch2/lib/python3.10/site-packages/torch/utils/hooks.py", line 170, in _apply_on_tensors
    new_tensors = torch.nn.modules._functions.BackwardHookFunction.apply(*tensors)
  File "~/anaconda3/envs/pytorch2/lib/python3.10/site-packages/torch/autograd/function.py", line 509, in apply
    raise RuntimeError(
RuntimeError: In order to use an autograd.Function with functorch transforms (vmap, grad, jvp, jacrev, ...), it must override the setup_context staticmethod. For more details, please see https://pytorch.org/docs/master/notes/extending.func.html

I think this happens because registering a full_backward_hook wraps the module's inputs in a torch.autograd.Function, and in PyTorch 2.0 any autograd.Function hit by a torch.func transform must define setup_context to handle its outputs. In previous versions of PyTorch, full backward hooks were skipped entirely, if I recall correctly.

Is there a way to add a setup_context to a full_backward_hook, or perhaps define a custom full_backward_hook with a setup_context method itself?
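For now, the only stopgap I can think of is to temporarily detach the full backward hooks around the torch.func call. Below is a rough, untested sketch of my own (not a confirmed fix); it assumes the registration loop above is changed to keep the RemovableHandle objects, and it of course means the hooks don't fire during that call:

#Register as before, but keep the handles for the full backward hooks
bw_handles = []
for mod in net.modules():
  if isinstance(mod, nn.Linear):
    mod.register_forward_pre_hook(forward_pre_hook)
    bw_handles.append(mod.register_full_backward_hook(full_backward_hook))

#Detach the full backward hooks so BackwardHookFunction is never applied...
for h in bw_handles:
  h.remove()

laplacian = vmap(calc_hessian_trace, in_dims=(None, 0))(dict(net.named_parameters()), x)

#...and re-register them afterwards for ordinary forward/backward passes.
for mod in net.modules():
  if isinstance(mod, nn.Linear):
    mod.register_full_backward_hook(full_backward_hook)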

EDIT: After following the stack trace, I believe I've found where the error originates. In https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/_functions.py, BackwardHookFunction is defined as a torch.autograd.Function, but it's written in the pre-2.0 style (forward takes ctx and there is no setup_context):

class BackwardHookFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, *args):
        ctx.mark_non_differentiable(*[arg for arg in args if not arg.requires_grad])
        return args

    @staticmethod
    def backward(ctx, *args):
        return args

So, I'd assume that changing this function to something like,

class BackwardHookFunction(torch.autograd.Function):
    @staticmethod
    def forward(*args):
        return args

    @staticmethod
    def setup_context(ctx, inputs, output):
        #In the 2.0-style API, setup_context receives (ctx, inputs, output)
        ctx.mark_non_differentiable(*[arg for arg in inputs if not arg.requires_grad])

    @staticmethod
    def backward(ctx, *args):
        return args

might resolve this issue? However, I have no idea whether this would break other parts of PyTorch.
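In the meantime, a monkey-patch along these lines might let someone test the idea from user code without rebuilding PyTorch. Again, this is purely a sketch of my own and unverified; it relies on torch/utils/hooks.py looking the class up through the torch.nn.modules._functions module attribute (as seen in the traceback above):

import torch
import torch.nn.modules._functions as _functions

class PatchedBackwardHookFunction(torch.autograd.Function):
    @staticmethod
    def forward(*args):
        return args

    @staticmethod
    def setup_context(ctx, inputs, output):
        #Same logic as the stock BackwardHookFunction, moved into setup_context
        ctx.mark_non_differentiable(*[arg for arg in inputs if not arg.requires_grad])

    @staticmethod
    def backward(ctx, *args):
        return args

#hooks.py calls torch.nn.modules._functions.BackwardHookFunction.apply(...),
#so swapping the module attribute before registering hooks should be enough
#for new hook registrations to pick up the patched version.
_functions.BackwardHookFunction = PatchedBackwardHookFunction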

zou3519 commented 1 year ago

Closing as transferred to pytorch via https://github.com/pytorch/pytorch/issues/99556