pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License
1.38k stars 102 forks source link

RuntimeError: Batching rule not implemented for aten::is_same_size. We could not generate a fallback. #1104

Open AlphaBetaGamma96 opened 1 year ago

AlphaBetaGamma96 commented 1 year ago

Hi All,

I've been trying to use memory_efficient_fusion to see whether it can speed up a major bottleneck in my code, but I hit a RuntimeError. This issue is a continuation of #1011. The code to reproduce the error is as follows.

import torch
from torch import nn

import functorch
from functorch import make_functional, vmap, jacrev, grad
from functorch.compile import memory_efficient_fusion

import time

# Seed the default RNG; the returned generator is kept in a throwaway name.
_ = torch.manual_seed(1234)

# Report the library versions in use. Values seen in the original report:
#   PyTorch version:    2.0.0.dev20230116
#   CUDA version:       11.6
#   FuncTorch version:  2.0.0.dev20230116
for _label, _version in (
  ("PyTorch version:   ", torch.__version__),
  ("CUDA version:      ", torch.version.cuda),
  ("FuncTorch version: ", functorch.__version__),
):
  print(_label, _version)

#=============================================#

#time with torch synchronization
def sync_time() -> float:
  """Return a wall-clock timestamp taken after CUDA synchronization.

  The device is synchronized before the clock is read so that the difference
  between two timestamps covers work that has actually completed. When CUDA
  is not available the synchronization is skipped instead of raising, so the
  helper can also be used on CPU-only machines.

  Returns:
    The value of ``time.perf_counter()``, in seconds.
  """
  # Only synchronize when a CUDA runtime is present; without one the call fails.
  if torch.cuda.is_available():
    torch.cuda.synchronize()
  return time.perf_counter()

class model(nn.Module):
  """Two-layer network whose output matrix is reduced to (sign, log|det|)."""

  def __init__(self, num_inputs, num_hidden):
    super().__init__()

    self.num_inputs = num_inputs
    self.func = nn.Tanh()

    # fc1 consumes the 2 features assembled in forward(): the value itself
    # and the mean taken along the second-to-last axis.
    self.fc1 = nn.Linear(2, num_hidden)
    self.fc2 = nn.Linear(num_hidden, num_inputs)

  def forward(self, x):
    """
    Map x of shape [B,A] to the pair (sign, logabsdet), each of shape [B,].

    The repeat arguments are derived from the rank of x, so the same code
    handles an input with or without the leading batch axis (e.g. under vmap).
    """
    x = x.unsqueeze(-1)

    # One repeat factor per axis; only the second-to-last axis is tiled.
    ndim = x.dim()
    repeats = [1] * ndim
    repeats[-2] = self.num_inputs

    pooled = x.mean(dim=ndim - 2, keepdim=True).repeat(*repeats)
    features = torch.cat((x, pooled), dim=-1)

    hidden = self.func(self.fc1(features))
    matrix = self.fc2(hidden)

    sign, logabsdet = torch.linalg.slogdet(matrix)
    return sign, logabsdet

#=============================================#

B = 4096  # batch size
N = 2     # number of input nodes
H = 64    # number of hidden nodes
device = torch.device('cuda')

# Random input batch of shape [B, N], created directly on the device.
x = torch.randn(B, N, device=device)

# Build the model and move it to the device in one step.
net = model(N, H).to(device)

# Plain forward pass through the module.
sgn, logabs = net(x)

# Functional form of the network: parameters are passed explicitly.
fnet, params = make_functional(net)

def calc_logabs(params, x):
  """Evaluate the functional network and return only its second output (logabs)."""
  outputs = fnet(params, x)
  return outputs[1]

def calc_dlogabs_dx(params, x):
  """Return the Jacobian of calc_logabs with respect to x, twice.

  The duplicate lets a caller that differentiates this function with
  has_aux=True get the first derivative back as the auxiliary output.
  """
  first_derivative = jacrev(calc_logabs, argnums=1)(params, x)
  return first_derivative, first_derivative

def local_kinetic_from_log_vmap(params, x):
  """Compute -0.5 * (sum of second-derivative diagonal + sum of squared first derivative).

  Both derivatives of logabs with respect to x come from a single jacrev
  call: the Jacobian of calc_dlogabs_dx is the second derivative, and the
  first derivative is returned as the auxiliary output.
  """
  second_deriv, first_deriv = jacrev(calc_dlogabs_dx, argnums=1, has_aux=True)(params, x)
  laplacian_term = second_deriv.diagonal(0, -2, -1).sum()
  gradient_term = first_deriv.pow(2).sum()
  return -0.5 * (laplacian_term + gradient_term)

#memory efficient fusion here
#with torch.jit.fuser("fuser2"): #is this needed (from functorch/issues/840)
# Gradient of the local kinetic term with respect to params (argnums=0),
# vmapped over the first axis of x while params are shared (in_dims=(None, 0)).
ps_elocal = vmap(grad(local_kinetic_from_log_vmap, argnums=0), in_dims=(None, 0))
# Same transform, wrapped in memory_efficient_fusion for comparison.
ps_elocal_fusion = memory_efficient_fusion(vmap(grad(local_kinetic_from_log_vmap, argnums=0), in_dims=(None, 0)))

t1 = sync_time()

out1 = ps_elocal(params, x)

t2 = sync_time()

ps_elocal_fusion(params, x) #crashes here: aten::is_same_size no batching rule

t3 = sync_time()

#Compare memory_efficient_fusion on the function's walltime
# Apply the format with `%`; passing the value as a second print() argument
# would print the literal "%4.2e" followed by the unformatted float.
print("Laplacian (standard): %4.2e (s)" % (t2 - t1))
print("Laplacian (fusion):   %4.2e (s)" % (t3 - t2))

The traceback is as follows:

PyTorch version:    2.0.0.dev20230116
CUDA version:       11.6
FuncTorch version:  2.0.0.dev20230116
Failed to collect metadata on function, produced code may be suboptimal.  Known situations this can occur are inference mode only compilation involving resize_ or prims (!schema.hasAnyAliasInfo() INTERNAL ASSERT FAILED); if your situation looks different please file a bug to PyTorch.
Traceback (most recent call last):
  File "~/anaconda3/envs/nightly_20230117/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1368, in aot_wrapper_dedupe
    fw_metadata, _out = run_functionalized_fw_and_collect_metadata(flat_fn)(
  File "~/anaconda3/envs/nightly_20230117/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 569, in inner
    flat_f_outs = f(*flat_f_args)

    ... lot more errors

  File "~/anaconda3/envs/nightly_20230117/lib/python3.10/site-packages/torch/autograd/__init__.py", line 285, in grad
    grad_outputs_ = _make_grads(t_outputs, grad_outputs_, is_grads_batched=is_grads_batched)
  File "~/anaconda3/envs/nightly_20230117/lib/python3.10/site-packages/torch/autograd/__init__.py", line 53, in _make_grads
    if not torch.is_same_size(out, first_grad):
RuntimeError: Batching rule not implemented for aten::is_same_size. We could not generate a fallback.

This code was run with the latest nightly version:

PyTorch version:    2.0.0.dev20230116
CUDA version:       11.6
FuncTorch version:  2.0.0.dev20230116