pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

AOT Autograd fails to get correct grads for view and Inplace Relu #514

Open anijain2305 opened 2 years ago

anijain2305 commented 2 years ago

While working on the TorchDynamo + AOT integration, I came across the following bug:

import torch
from torch.nn import *
from functorch.compile import print_compile, aot_module
import copy

class Bar(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # self.conv = Conv2d(3, 2, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
        # self.instance_norm = InstanceNorm2d(2, affine=True, track_running_stats=True)
        self.relu = ReLU(inplace=True)

    def forward(self, x : torch.Tensor):
        # self_main_0 = self.conv(x)
        # self_main_1 = self.instance_norm(self_main_0)
        self_main_0 = x * 2
        self_main_1 = self_main_0.view([1, 3, 128, 128])
        self_main_2 = self.relu(self_main_1)
        return self_main_2

mod = Bar().to(device="cuda")
# Reduce randomness bits
mod.eval()

inp0 = torch.randn(1, 3, 128, 128, device='cuda', requires_grad=True)
inputs = (inp0, )

cloned_inp0 = inp0.clone().detach().requires_grad_(True)
cloned_inputs = (cloned_inp0, )

# Reference calculation
mod.zero_grad()
duplicated_mod = copy.deepcopy(mod)
ref = duplicated_mod(*inputs)
ref.sum().backward()
ref_grads = []
for param in duplicated_mod.parameters():
    ref_grads.append(param.grad)

# AOT stuff
fx_mod = torch.fx.symbolic_trace(mod)
aot_mod = aot_module(fx_mod, print_compile)
aot_mod.zero_grad()
with torch.jit.fuser("fuser2"):
    res = aot_mod(*cloned_inputs)
    res.sum().backward()

res_grads = []
for param in aot_mod.parameters():
    res_grads.append(param.grad)

assert torch.allclose(ref, res)

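# Compare parameter grads; the assert message reports the mismatching pair
for (a, b) in zip(ref_grads, res_grads):
    assert torch.allclose(a, b, atol=1e-4, rtol=1e-4), f"param grad mismatch:\n{a}\n{b}"

# Compare input grads between the eager reference and the AOT run
for (a, b) in zip(inputs, cloned_inputs):
    assert torch.allclose(a.grad, b.grad, atol=1e-4, rtol=1e-4), f"input grad mismatch:\n{a.grad}\n{b.grad}"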

A view followed by an in-place ReLU seems to produce an incorrect backward trace.
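For reference, here is a minimal CPU-only sketch of what eager mode is expected to do with the same pattern. It is not part of the original reproducer: the shapes are shrunk, `torch.relu_` stands in for `ReLU(inplace=True)`, and the expected gradient is derived by hand (2 where x > 0, else 0). It only illustrates the aliasing that the traced backward has to account for.

```python
import torch

torch.manual_seed(0)

# CPU-only sketch; shape shrunk from the report for readability.
x = torch.randn(1, 3, 4, 4, requires_grad=True)

scaled = x * 2                      # non-leaf tensor, so it may be mutated in place
viewed = scaled.view(1, 3, 4, 4)    # alias of `scaled`; both share storage
out = torch.relu_(viewed)           # in-place ReLU writes through the view

# The in-place op also changed the base tensor, because the view aliases it.
assert out.data_ptr() == scaled.data_ptr()
assert torch.equal(scaled, out)

out.sum().backward()

# Hand-derived reference: d/dx sum(relu(2x)) is 2 where x > 0 and 0 elsewhere.
expected_grad = 2.0 * (x.detach() > 0).to(x.dtype)
assert torch.allclose(x.grad, expected_grad)
```

Any compiled backward for this graph should reproduce `expected_grad`; a mismatch points at how the mutation through the view was traced.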

@Chillee @jansel

Chillee commented 2 years ago

Do you need the instance norm/batch norm stuff here?

anijain2305 commented 2 years ago

Not necessarily. I kept it there because the graph was initially larger and contained instance_norm. I minimized the example by looking at the FX graphs to make it simpler.

anijain2305 commented 2 years ago

Another issue, this one from pytorch_struct:

import torch
from torch.nn import *
import functorch
from functorch.compile import memory_efficient_fusion

class FxModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, einsum, unsqueeze_3):
        exp = einsum.exp()
        gather = torch.gather(exp, 3, unsqueeze_3)
        return gather

inp0 = torch.randn([1, 3, 2, 10], device='cuda', requires_grad=True)
inp1 = torch.ones([1, 3, 2, 1], dtype=torch.int64, device='cuda')
inps = [inp0, inp1]

cloned_inps = [x.clone().detach() for x in inps]
cloned_inps[0].requires_grad_(True)
cloned_inps[0].grad = None

mod = FxModule().to(device="cuda")
ref = mod(*inps)
ref.sum().backward()

aot_mod = memory_efficient_fusion(mod)
res = aot_mod(*cloned_inps)
res.sum().backward()

assert torch.allclose(ref, res)
print(inps[0].grad)
print(cloned_inps[0].grad)
assert torch.allclose(inps[0].grad, cloned_inps[0].grad)

print("Success")

anijain2305 commented 2 years ago

Both of these cases would need special-cased, hacky handling in AOT Autograd if they have to be supported quickly.

A better approach is the functionalization pass. CC'ing @bdhirsh to try functionalization on these two test cases as well.
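To make the idea concrete, here is a small hand-written sketch of what a functionalized version of the first test case would look like. It does not call any functionalization pass; `forward_inplace` and `forward_functional` are hypothetical names, the shapes are shrunk, and everything runs on CPU. It only shows that replacing the in-place ReLU with its out-of-place counterpart leaves outputs and input gradients unchanged in eager mode, which is the equivalence such a pass relies on.

```python
import torch


def forward_inplace(x):
    # Mirrors the first reproducer: mul -> view -> in-place ReLU.
    y = x * 2
    v = y.view(1, 3, 4, 4)
    return torch.relu_(v)   # mutates `y` through the view


def forward_functional(x):
    # Hand-written out-of-place equivalent: no tensor is mutated.
    y = x * 2
    v = y.view(1, 3, 4, 4)
    return torch.relu(v)


torch.manual_seed(0)
a = torch.randn(1, 3, 4, 4, requires_grad=True)
b = a.detach().clone().requires_grad_(True)

out_a = forward_inplace(a)
out_b = forward_functional(b)
out_a.sum().backward()
out_b.sum().backward()

outputs_match = torch.allclose(out_a, out_b)
grads_match = torch.allclose(a.grad, b.grad)
print(outputs_match, grads_match)
```

Because the functional form contains no mutation, a tracer does not have to reason about aliasing between the view and its base when it builds the backward graph.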