pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

memory_efficient_fusion leads to RuntimeError for higher-order gradient calculation. RuntimeError: You are attempting to call Tensor.requires_grad_() #1011

Open AlphaBetaGamma96 opened 1 year ago

AlphaBetaGamma96 commented 1 year ago

Hi All,

I've tried to speed up my code using memory_efficient_fusion; however, it raises a Tensor.requires_grad_() error and I have no idea why. The error is as follows:

RuntimeError: You are attempting to call Tensor.requires_grad_() (or perhaps using torch.autograd.functional.* APIs) inside of a function being transformed by a functorch transform. This is unsupported, please attempt to use the functorch transforms (e.g. grad, vjp, jacrev, jacfwd, hessian) or call requires_grad_() outside of a function being transformed instead.

I've attached a 'minimal' reproducible example of this behaviour below. I've tried a few different things, but nothing seems to have worked. I did see in #840 that memory_efficient_fusion is called within a context manager; however, when I use that, I get the same error.

Thanks in advance!

EDIT: When I ran it, it tried to import the networkx package, which wasn't installed by default, so I had to install it manually (which wasn't a problem). I'm just not sure whether installing from source should also install such dependencies!

import torch
from torch import nn

import functorch
from functorch import make_functional, vmap, jacrev, grad
from functorch.compile import memory_efficient_fusion

import time

_ = torch.manual_seed(1234)

#version info
print("PyTorch version:   ", torch.__version__)
print("CUDA version:      ", torch.version.cuda)
print("FuncTorch version: ", functorch.__version__)

#=============================================#

#time with torch synchronization
def sync_time() -> float:
  torch.cuda.synchronize()
  return time.perf_counter()

class model(nn.Module):

  def __init__(self, num_inputs, num_hidden):
    super(model, self).__init__()

    self.num_inputs=num_inputs
    self.func = nn.Tanh()

    self.fc1 = nn.Linear(2, num_hidden)
    self.fc2 = nn.Linear(num_hidden, num_inputs)

  def forward(self, x):
    """
    Takes x in [B,A,1] and maps it to sign/logabsdet value in Tuple([B,], [B,])
    """

    idx=len(x.shape)             #creates args for repeat if vmap is used or not
    rep=[1 for _ in range(idx)]
    rep[-2] = self.num_inputs
    g = x.mean(dim=(idx-2), keepdim=True).repeat(*rep)
    f = torch.cat((x,g), dim=-1)

    h = self.func(self.fc1(f))

    mat = self.fc2(h)
    sgn, logabs = torch.linalg.slogdet(mat)
    return sgn, logabs

#=============================================#

B=4096 #batch
N=2    #input nodes
H=64   #number of hidden nodes
device = torch.device('cuda')

x = torch.randn(B, N, 1, device=device) #input data

net = model(N, H) #our model
net=net.to(device)

fnet, params = make_functional(net)

def calc_logabs(params, x):
  _, logabs = fnet(params, x)
  return logabs

def calc_dlogabs_dx(params, x):
  dlogabs_dx = jacrev(func=calc_logabs, argnums=1)(params, x)
  return dlogabs_dx, dlogabs_dx #return aux

def local_kinetic_from_log_vmap(params, x):
  d2logabs_dx2, dlogabs_dx = jacrev(func=calc_dlogabs_dx, argnums=1, has_aux=True)(params, x)
  _local_kinetic = -0.5*(d2logabs_dx2.diagonal(0,-4,-2).sum() + dlogabs_dx.pow(2).sum())
  return _local_kinetic 

#memory efficient fusion here
#with torch.jit.fuser("fuser2"): is this needed (from functorch/issues/840)
ps_elocal = grad(local_kinetic_from_log_vmap, argnums=0)
ps_elocal_fusion = memory_efficient_fusion(grad(local_kinetic_from_log_vmap, argnums=0))

#ps_elocal_fusion(params, x) #no vmap attempt (throws size mis-match error)

t1=sync_time()

vmap(ps_elocal, in_dims=(None, 0))(params, x) #works fine 

t2=sync_time()

vmap(ps_elocal_fusion, in_dims=(None, 0))(params, x) #error (crashes on this line)

t3=sync_time()

print("Laplacian (standard): %4.2e (s)" % (t2 - t1))
print("Laplacian (fusion):   %4.2e (s)" % (t3 - t2))
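For reference, local_kinetic_from_log_vmap appears to rely on the identity lap(psi)/psi = lap(log|psi|) + |grad log|psi||^2, so -0.5 * (Laplacian of logabs + squared gradient of logabs) equals the local kinetic energy -0.5 * lap(psi)/psi. This assumes, as the function name suggests, that logabs is log|psi| of some wavefunction psi. The pure-Python sketch below checks the identity numerically with central finite differences; psi and all helper names are hypothetical and not part of the original example.

```python
import math

# Hypothetical 2D test function psi(x) = (1 + x0) * exp(-|x|^2 / 2),
# chosen only because it is smooth and positive near the test point.
def psi(x):
    return (1.0 + x[0]) * math.exp(-(x[0] ** 2 + x[1] ** 2) / 2.0)

def log_abs_psi(x):
    return math.log(abs(psi(x)))

def laplacian(f, x, h=1e-4):
    """Central-difference Laplacian: sum_i (f(x+h e_i) - 2 f(x) + f(x-h e_i)) / h^2."""
    total = 0.0
    for i in range(len(x)):
        xp = list(x)
        xp[i] += h
        xm = list(x)
        xm[i] -= h
        total += (f(xp) - 2.0 * f(x) + f(xm)) / h ** 2
    return total

def gradient(f, x, h=1e-5):
    """Central-difference gradient of a scalar function f at x."""
    g = []
    for i in range(len(x)):
        xp = list(x)
        xp[i] += h
        xm = list(x)
        xm[i] -= h
        g.append((f(xp) - f(xm)) / (2.0 * h))
    return g

def local_kinetic_from_log(x):
    # Same formula as local_kinetic_from_log_vmap:
    # -0.5 * (Laplacian of log|psi| + squared gradient of log|psi|)
    lap = laplacian(log_abs_psi, x)
    grad_sq = sum(g * g for g in gradient(log_abs_psi, x))
    return -0.5 * (lap + grad_sq)

def local_kinetic_direct(x):
    # Direct definition: -0.5 * lap(psi) / psi
    return -0.5 * laplacian(psi, x) / psi(x)

x0 = [0.3, -0.7]
print(local_kinetic_from_log(x0), local_kinetic_direct(x0))
```

Both routes should agree to finite-difference precision; working in log space is what lets the original code differentiate logabs instead of psi itself.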
AlphaBetaGamma96 commented 1 year ago

So it seems the solution is to place the vmap call inside the memory_efficient_fusion call, like so:

ps_elocal_fusion = memory_efficient_fusion(vmap(grad(local_kinetic_from_log_vmap, argnums=0), in_dims=(None, 0)))

then just call,

ps_elocal_fusion(params, x) #works now.

However, it's about an order of magnitude slower than the version without memory_efficient_fusion:

ps_elocal = vmap(grad(local_kinetic_from_log_vmap, argnums=0), in_dims=(None, 0)) #0.454 (s)
ps_elocal_fusion = memory_efficient_fusion(vmap(grad(local_kinetic_from_log_vmap, argnums=0), in_dims=(None, 0))) #5.804 (s)
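One possible factor in the 5.804 s figure, offered as an assumption rather than a diagnosis: memory_efficient_fusion traces and compiles the wrapped function when it is first called, so if these timings come from a single call (as in the original script), that one-time cost is included. The minimal sketch below reports the first call separately from warm, steady-state calls. benchmark and FakeCompiled are hypothetical helpers; FakeCompiled only uses time.sleep to mimic a function with a one-time compilation cost.

```python
import time

def benchmark(fn, *args, warmup=1, iters=5, sync=None):
    """Return (first_call_seconds, mean_steady_state_seconds) for fn(*args)."""
    def now():
        if sync is not None:
            sync()  # e.g. torch.cuda.synchronize, so queued GPU work is counted
        return time.perf_counter()

    # First call: for a compiled function this includes any one-time tracing/compilation.
    start = now()
    fn(*args)
    first_call = now() - start

    # Remaining warm-up calls are executed but not timed.
    for _ in range(max(warmup - 1, 0)):
        fn(*args)

    # Steady-state timing averaged over several calls.
    start = now()
    for _ in range(iters):
        fn(*args)
    steady = (now() - start) / iters
    return first_call, steady

class FakeCompiled:
    """Hypothetical stand-in: slow on the first call, fast afterwards."""
    def __init__(self):
        self.compiled = False

    def __call__(self, x):
        if not self.compiled:
            time.sleep(0.2)  # mimics one-time compilation
            self.compiled = True
        time.sleep(0.01)  # mimics the actual work
        return x

first, steady = benchmark(FakeCompiled(), 1.0, warmup=2, iters=5)
print("first call: %4.2e (s), steady state: %4.2e (s)" % (first, steady))
```

Applied to the script in this issue, an untested call would look something like benchmark(ps_elocal_fusion, params, x, sync=torch.cuda.synchronize), with the same done for ps_elocal so both versions are compared on warm calls.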

EDIT: versions for reference:

PyTorch version:    1.13.0.dev20220820
CUDA version:       11.6
FuncTorch version:  0.3.0a0+86a9049
samdow commented 1 year ago

cc @Chillee @anijain2305

Any thoughts? In particular re: why memory_efficient_fusion made the final case slower.

AlphaBetaGamma96 commented 1 year ago

I thought I'd also mention that memory_efficient_fusion fails if a Python scalar is included. For example, using this function (which differs from my original one by the -0.5 * factor)

def local_kinetic_from_log_vmap(params, x):
  d2logabs_dx2, dlogabs_dx = jacrev(func=calc_dlogabs_dx, argnums=1, has_aux=True)(params, x)
  _local_kinetic = -0.5*(d2logabs_dx2.diagonal(0,-4,-2).sum() + dlogabs_dx.pow(2).sum())
  return _local_kinetic 

raises the following error:

RuntimeError: aten::_to_copy() Expected a value of type 'Tensor' for argument 'self' but instead found type 'float'.
Position: 0
Value: -0.5
Declaration: aten::_to_copy(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, int? memory_format=None) -> Tensor
Cast error details: Unable to cast -0.5 to Tensor