anijain2305 opened 2 years ago
Do you need the instance norm/batch norm stuff here?
Not necessarily. I kept it there because the graph was initially larger and contained instance_norm. I minimized the example by looking at the FX graphs to make it easier.
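For reference, here is a minimal sketch of that kind of inspection. The module below is a hypothetical stand-in, not the original graph; the point is just that symbolically tracing the module and printing its FX graph shows which nodes actually feed the failing op, so the rest can be dropped from the repro.

```python
import torch
import torch.fx

# Hypothetical stand-in for the larger module; only the structure matters here.
class BigModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = torch.nn.InstanceNorm2d(3)

    def forward(self, x, idx):
        x = self.norm(x)
        x = x.exp()
        return torch.gather(x, 3, idx)

# Trace the module and print its graph to see which nodes feed the failing
# gather, then prune everything else from the repro.
traced = torch.fx.symbolic_trace(BigModule())
print(traced.graph)
```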
Another issue, found in pytorch_struct:
```python
import torch
from torch.nn import *
import functorch
from functorch.compile import memory_efficient_fusion


class FxModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, einsum, unsqueeze_3):
        exp = einsum.exp()
        gather = torch.gather(exp, 3, unsqueeze_3); expand = unsqueeze_3 = None
        return gather


inp0 = torch.randn([1, 3, 2, 10], device='cuda', requires_grad=True)
inp1 = torch.ones([1, 3, 2, 1], dtype=torch.int64, device='cuda')
inps = [inp0, inp1]

cloned_inps = [x.clone().detach() for x in inps]
cloned_inps[0].requires_grad_(True)
cloned_inps[0].grad = None

mod = FxModule().to(device="cuda")

ref = mod(*inps)
ref.sum().backward()

aot_mod = memory_efficient_fusion(mod)
res = aot_mod(*cloned_inps)
res.sum().backward()

assert torch.allclose(ref, res)
print(inps[0].grad)
print(cloned_inps[0].grad)
assert torch.allclose(inps[0].grad, cloned_inps[0].grad)
print("Success")
```
Both of these cases require special/hacky handling in AOT Autograd if they are to be supported quickly.
A better approach is the functionalization pass. CC'ing @bdhirsh to try functionalization on these two test cases as well.
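As a rough sketch of what functionalization buys us, the snippet below traces a view + in-place op with and without it. It assumes a PyTorch build where torch.func.functionalize and make_fx are available (in older functorch releases the same pass lives under functorch.experimental); the function f is illustrative, not one of the failing graphs.

```python
import torch
from torch.func import functionalize
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    y = (x * 2).view(-1)  # view of a non-leaf intermediate
    y.relu_()             # in-place op through the view
    return y

x = torch.randn(2, 3)

# Without functionalization, the trace still contains the mutation (relu_),
# which is the kind of side effect AOT Autograd struggles to differentiate.
print(make_fx(f)(x).graph)

# With functionalization, the mutation is rewritten into out-of-place ops,
# so the traced graph is side-effect free.
print(make_fx(functionalize(f))(x).graph)
```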
While working on the TorchDynamo + AOT Autograd integration, I came across the following bug: a view followed by an in-place relu seems to give a wrong backward trace.
@Chillee @jansel
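For context, a minimal sketch of the pattern being referred to, mirroring the checking scheme of the gather repro above. The module, shapes, and names here are hypothetical, not the original failing graph.

```python
import torch
from functorch.compile import memory_efficient_fusion

class ViewReluModule(torch.nn.Module):
    def forward(self, x):
        y = x * 2                   # intermediate so the in-place op is not on a leaf
        z = y.view(y.shape[0], -1)  # view
        return z.relu_()            # in-place relu through the view

inp = torch.randn(4, 8, device='cuda', requires_grad=True)
cloned = inp.clone().detach().requires_grad_(True)

mod = ViewReluModule().to(device="cuda")
ref = mod(inp)
ref.sum().backward()

aot_mod = memory_efficient_fusion(mod)
res = aot_mod(cloned)
res.sum().backward()

# If the backward trace is wrong, the gradients below disagree even though
# the forward outputs match.
print(torch.allclose(inp.grad, cloned.grad))
```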