Hi All,

I've been trying to use `memory_efficient_fusion` to see if I can speed up a major bottleneck in my code, but I hit a `RuntimeError`. This issue continues from #1011. The code to reproduce this is as follows:
import torch
from torch import nn
import functorch
from functorch import make_functional, vmap, jacrev, grad
from functorch.compile import memory_efficient_fusion
import time
_ = torch.manual_seed(1234)
#version info
print("PyTorch version: ", torch.__version__) #PyTorch version: 2.0.0.dev20230116
print("CUDA version: ", torch.version.cuda) #CUDA version: 11.6
print("FuncTorch version: ", functorch.__version__) #FuncTorch version: 2.0.0.dev20230116
#=============================================#
#time with torch synchronization
def sync_time() -> float:
    torch.cuda.synchronize()
    return time.perf_counter()
class model(nn.Module):
    def __init__(self, num_inputs, num_hidden):
        super(model, self).__init__()
        self.num_inputs = num_inputs
        self.func = nn.Tanh()
        self.fc1 = nn.Linear(2, num_hidden)
        self.fc2 = nn.Linear(num_hidden, num_inputs)

    def forward(self, x):
        """
        Takes x of shape [B, A] and maps it to a (sign, logabsdet) tuple of shapes ([B,], [B,])
        """
        x = x.unsqueeze(-1)
        idx = len(x.shape)  # build repeat args so this works with and without vmap
        rep = [1 for _ in range(idx)]
        rep[-2] = self.num_inputs
        g = x.mean(dim=(idx - 2), keepdim=True).repeat(*rep)
        f = torch.cat((x, g), dim=-1)
        h = self.func(self.fc1(f))
        mat = self.fc2(h)
        sgn, logabs = torch.linalg.slogdet(mat)
        return sgn, logabs
#=============================================#
B = 4096  # batch size
N = 2     # input nodes
H = 64    # number of hidden nodes
device = torch.device('cuda')
x = torch.randn(B, N, device=device)  # input data
net = model(N, H)  # our model
net = net.to(device)
sgn, logabs = net(x)  # sanity-check/warm-up forward pass
fnet, params = make_functional(net)
def calc_logabs(params, x):
    _, logabs = fnet(params, x)
    return logabs

def calc_dlogabs_dx(params, x):
    dlogabs_dx = jacrev(func=calc_logabs, argnums=1)(params, x)
    return dlogabs_dx, dlogabs_dx  # second copy is the aux output for has_aux

def local_kinetic_from_log_vmap(params, x):
    d2logabs_dx2, dlogabs_dx = jacrev(func=calc_dlogabs_dx, argnums=1, has_aux=True)(params, x)
    _local_kinetic = -0.5 * (d2logabs_dx2.diagonal(0, -2, -1).sum() + dlogabs_dx.pow(2).sum())
    return _local_kinetic
# memory-efficient fusion here
# with torch.jit.fuser("fuser2"):  # is this needed? (from functorch issue #840)
ps_elocal = vmap(grad(local_kinetic_from_log_vmap, argnums=0), in_dims=(None, 0))
ps_elocal_fusion = memory_efficient_fusion(vmap(grad(local_kinetic_from_log_vmap, argnums=0), in_dims=(None, 0)))
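# (note: grad(..., argnums=0) differentiates the local kinetic energy w.r.t. params,
#  and in_dims=(None, 0) shares params across the batch while mapping over x)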
t1 = sync_time()
out1 = ps_elocal(params, x)
t2 = sync_time()
ps_elocal_fusion(params, x)  # crashes here: no batching rule for aten::is_same_size
t3 = sync_time()
# compare walltime with and without memory_efficient_fusion
print("Laplacian (standard): %4.2e (s)" % (t2 - t1))
print("Laplacian (fusion): %4.2e (s)" % (t3 - t2))
The traceback is as follows:
PyTorch version: 2.0.0.dev20230116
CUDA version: 11.6
FuncTorch version: 2.0.0.dev20230116
Failed to collect metadata on function, produced code may be suboptimal. Known situations this can occur are inference mode only compilation involving resize_ or prims (!schema.hasAnyAliasInfo() INTERNAL ASSERT FAILED); if your situation looks different please file a bug to PyTorch.
Traceback (most recent call last):
File "~/anaconda3/envs/nightly_20230117/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1368, in aot_wrapper_dedupe
fw_metadata, _out = run_functionalized_fw_and_collect_metadata(flat_fn)(
File "~/anaconda3/envs/nightly_20230117/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 569, in inner
flat_f_outs = f(*flat_f_args)
... lot more errors
File "~/anaconda3/envs/nightly_20230117/lib/python3.10/site-packages/torch/autograd/__init__.py", line 285, in grad
grad_outputs_ = _make_grads(t_outputs, grad_outputs_, is_grads_batched=is_grads_batched)
File "~/anaconda3/envs/nightly_20230117/lib/python3.10/site-packages/torch/autograd/__init__.py", line 53, in _make_grads
if not torch.is_same_size(out, first_grad):
RuntimeError: Batching rule not implemented for aten::is_same_size. We could not generate a fallback.
This code was run with the latest nightly build.
PyTorch version: 2.0.0.dev20230116
CUDA version: 11.6
FuncTorch version: 2.0.0.dev20230116
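In case it helps with triage, here's a stripped-down sketch of what I believe is the failing pattern: `memory_efficient_fusion` around `vmap(grad(f))` where `f` itself calls `jacrev`. I haven't run this exact snippet against the nightly, so treat it as an untested guess at a minimal repro rather than a confirmed one.

```python
import torch
from functorch import vmap, grad, jacrev
from functorch.compile import memory_efficient_fusion

# untested guess at a minimal repro: grad-of-jacrev under vmap, then fused
def f(x):
    # inner reverse-mode Jacobian, so grad(f) nests two levels of autodiff
    return jacrev(lambda y: y.tanh().sum())(x).sum()

fused = memory_efficient_fusion(vmap(grad(f)))
fused(torch.randn(8, 3, device="cuda"))  # I'd expect the same aten::is_same_size error here
```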