pytorch / torchdynamo

A Python-level JIT compiler designed to make unmodified PyTorch programs faster.
BSD 3-Clause "New" or "Revised" License

[aot_eager] accuracy failure for jx_nest_base #1286

Closed: anijain2305 closed this issue 1 year ago

anijain2305 commented 2 years ago

Repro


import torch
import torchdynamo
from torch import tensor, device
import torch.fx as fx
from torchdynamo.testing import rand_strided
from math import inf
from torchdynamo.debug_utils import run_fwd_maybe_bwd
from torchdynamo.debug_utils import same_two_models

# each entry describes one input: (shape, stride, dtype, device, requires_grad)
args = [((2, 1, 196, 512), (100352, 100352, 512, 1), torch.float32, 'cuda', True), ((2, 1, 196, 2048), (401408, 401408, 2048, 1), torch.float32, 'cuda', True)]
args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args]

from torch.nn import *
class Repro(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.self_levels_2__transformer_encoder_19__mlp_fc2 = Linear(in_features=2048, out_features=512, bias=True)
        self.self_levels_2__transformer_encoder_19__mlp_drop2 = Dropout(p=0.0, inplace=False)
        self.self_norm = LayerNorm((512,), eps=1e-06, elementwise_affine=True)
        self.self_global_pool_pool = AdaptiveAvgPool2d(output_size=1)
        self.self_global_pool_flatten = Flatten(start_dim=1, end_dim=-1)
        self.self_head = Linear(in_features=512, out_features=1000, bias=True)

    def forward(self, add_49, self_levels_2__transformer_encoder_19__mlp_drop1):
        self_levels_2__transformer_encoder_19__mlp_fc2 = self.self_levels_2__transformer_encoder_19__mlp_fc2(self_levels_2__transformer_encoder_19__mlp_drop1);  self_levels_2__transformer_encoder_19__mlp_drop1 = None
        self_levels_2__transformer_encoder_19__mlp_drop2 = self.self_levels_2__transformer_encoder_19__mlp_drop2(self_levels_2__transformer_encoder_19__mlp_fc2);  self_levels_2__transformer_encoder_19__mlp_fc2 = None
        add_50 = add_49 + self_levels_2__transformer_encoder_19__mlp_drop2;  add_49 = self_levels_2__transformer_encoder_19__mlp_drop2 = None
        reshape_58 = add_50.reshape(2, 1, 1, 14, 14, 512);  add_50 = None
        transpose_29 = reshape_58.transpose(2, 3);  reshape_58 = None
        reshape_59 = transpose_29.reshape(2, 14, 14, 512);  transpose_29 = None
        permute_57 = reshape_59.permute(0, 3, 1, 2);  reshape_59 = None
        permute_58 = permute_57.permute(0, 2, 3, 1);  permute_57 = None
        self_norm = self.self_norm(permute_58);  permute_58 = None
        permute_59 = self_norm.permute(0, 3, 1, 2);  self_norm = None
        self_global_pool_pool = self.self_global_pool_pool(permute_59);  permute_59 = None
        self_global_pool_flatten = self.self_global_pool_flatten(self_global_pool_pool);  self_global_pool_pool = None
        self_head = self.self_head(self_global_pool_flatten);  self_global_pool_flatten = None
        return (self_head,)

mod = Repro().cuda()
opt_mod = torchdynamo.optimize("aot_eager")(mod)

mod.eval()
opt_mod.eval()

with torch.cuda.amp.autocast(enabled=False):
    # sanity check: eager against itself should always match
    assert same_two_models(mod, mod, args), "Eager itself failed"
    # accuracy check: eager against the aot_eager-compiled module
    assert same_two_models(mod, opt_mod, args), "Dynamo failed"
anijain2305 commented 2 years ago

cc @eellison

eellison commented 2 years ago

The root cause is that the layer_norm and layer_norm_backward decomps do incorrect stride propagation. (Working on cleaning up the cross-ref PR and submitting it.)

CC @ngimel - was a fix for this in the works, or am I making that up?
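
To make the stride-propagation point concrete, here is a minimal sketch, assuming only public PyTorch APIs, that compares the output strides of eager layer_norm against a hand-written decomposition (illustrative only, not the decomp PyTorch actually uses) on a non-contiguous permuted input like the ones in this repro; if the two stride tuples differ, downstream reshape/view code that assumes eager strides can misinterpret memory:

import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"

# non-contiguous (N, H, W, C) view, analogous to the permutes in the repro above
x = torch.randn(2, 512, 14, 14, device=device).permute(0, 2, 3, 1)
w = torch.randn(512, device=device)
b = torch.randn(512, device=device)

# eager reference
ref = F.layer_norm(x, (512,), w, b, eps=1e-6)

# simplified hand-written decomposition (for illustration; not the real decomp)
mean = x.mean(-1, keepdim=True)
rstd = torch.rsqrt(x.var(-1, unbiased=False, keepdim=True) + 1e-6)
dec = (x - mean) * rstd * w + b

# any mismatch between these two stride tuples is the kind of divergence discussed here
print(ref.stride(), dec.stride())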

ngimel commented 2 years ago

layer_norm should be fixed by https://github.com/pytorch/pytorch/pull/84799; the backward is probably still not respecting eager strides.
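
As a hedged sketch of what "respecting eager strides" means for the backward, assuming only public autograd APIs: the repro permutes the normalized output before pooling, so the incoming grad_output is non-contiguous, and the strides of the grad that eager layer_norm backward returns are the baseline a decomposed backward would need to match:

import torch
import torch.nn.functional as F

x = torch.randn(2, 512, 14, 14, requires_grad=True)
xp = x.permute(0, 2, 3, 1)  # non-contiguous (N, H, W, C) view, as in the repro
w = torch.randn(512, requires_grad=True)
b = torch.randn(512, requires_grad=True)

y = F.layer_norm(xp, (512,), w, b, eps=1e-6)
go = torch.randn(2, 512, 14, 14).permute(0, 2, 3, 1)  # non-contiguous grad_output
gx, = torch.autograd.grad(y, xp, grad_outputs=go)

# eager baseline strides that a decomposed layer_norm_backward would need to reproduce
print(gx.stride())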

anijain2305 commented 1 year ago

I think aot_eager does not really go through decomps.

eellison commented 1 year ago

@anijain2305 it's used for fake tensor stride propagation.
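
For context, a minimal sketch of what "fake tensor stride propagation" looks like, assuming the public torch._subclasses.FakeTensorMode API rather than the exact path aot_eager takes internally: the op runs on fake tensors that carry only metadata, and the strides it reports are what the compiled graph later relies on:

import torch
import torch.nn.functional as F
from torch._subclasses.fake_tensor import FakeTensorMode

real = torch.randn(2, 512, 14, 14).permute(0, 2, 3, 1)  # non-contiguous, as in the repro
eager_out = F.layer_norm(real, (512,), eps=1e-6)

mode = FakeTensorMode()
fake = mode.from_tensor(real)
with mode:
    fake_out = F.layer_norm(fake, (512,), eps=1e-6)

# if the fake strides diverge from eager, accuracy failures like this one can follow downstream
print(eager_out.stride(), fake_out.stride())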

eellison commented 1 year ago

Fix is in https://github.com/pytorch/pytorch/pull/85417