pytorch / torchdynamo

A Python-level JIT compiler designed to make unmodified PyTorch programs faster.
BSD 3-Clause "New" or "Revised" License

[inductor] index_put - XLNetLMHeadModel #1356

Closed anijain2305 closed 2 years ago

anijain2305 commented 2 years ago

benchmarks/huggingface.py --training -dcuda --accuracy --inductor --only=XLNetLMHeadModel

Error

RuntimeError: Overloaded torch operator invoked from Python failed to many any schema:
aten::index_put_() Expected a value of type 'List[Optional[Tensor]]' for argument 'indices' but instead found type 'tuple'.
Position: 1
Value: (slice(None, None, None), slice(None, None, None), slice(None, None, None), FakeTensor(FakeTensor(..., device='meta', size=(512,), dtype=torch.int64), cuda:0))
Declaration: aten::index_put_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!)
Cast error details: Unable to cast slice(None, None, None) to Tensor

aten::index_put_() Expected a value of type 'List[Tensor]' for argument 'indices' but instead found type 'tuple'.
Position: 1
Value: (slice(None, None, None), slice(None, None, None), slice(None, None, None), FakeTensor(FakeTensor(..., device='meta', size=(512,), dtype=torch.int64), cuda:0))
Declaration: aten::index_put_.hacked_twin(Tensor(a!) self, Tensor[] indices, Tensor values, bool accumulate=False) -> Tensor(a!)
Cast error details: Unable to cast Python instance to C++ type (#define PYBIND11_DETAILED_ERROR_MESSAGES or compile in debug mode for details)

Repro


import torch
import torchdynamo
from torch import tensor, device
import torch.fx as fx
from torchdynamo.testing import rand_strided
from math import inf
from torchdynamo.debug_utils import run_fwd_maybe_bwd

args = [((512, 1, 16, 64), (1024, 1024, 64, 1), torch.float32, 'cuda', True), ((1024, 1, 16, 64), (1024, 1024, 64, 1), torch.float32, 'cuda', True), ((16, 64), (64, 1), torch.float32, 'cuda', True), ((512,), (1,), torch.int64, 'cuda', False)]
args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args]

from torch.nn import *
class Repro(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, einsum_9, einsum_12, self_self_transformer_layer_1__rel_attn_r_r_bias, arange_3):
        add_7 = einsum_9 + self_self_transformer_layer_1__rel_attn_r_r_bias;  einsum_9 = self_self_transformer_layer_1__rel_attn_r_r_bias = None
        einsum_14 = torch.functional.einsum('ibnd,jbnd->bnij', add_7, einsum_12);  add_7 = einsum_12 = None
        reshape_2 = einsum_14.reshape(1, 16, 1024, 512);  einsum_14 = None
        getitem_4 = reshape_2[(slice(None, None, None), slice(None, None, None), slice(1, None, None), slice(None, None, None))];  reshape_2 = None
        reshape_3 = getitem_4.reshape(1, 16, 512, 1023);  getitem_4 = None
        index_select_1 = torch.index_select(reshape_3, 3, arange_3);  reshape_3 = arange_3 = None
        return (index_select_1,)

mod = Repro().cuda()
opt_mod = torchdynamo.optimize("aot_inductor_debug")(mod)

with torch.cuda.amp.autocast(enabled=False):
    ref = run_fwd_maybe_bwd(mod, args)
    res = run_fwd_maybe_bwd(opt_mod, args)

anijain2305 commented 2 years ago

@voznesenskym Assigning to you arbitrarily. Let me know if that's OK.

anijain2305 commented 2 years ago

@yanboliang Assigning to you. At first glance, this looks like a decomp issue. Could you please take a look?

yanboliang commented 2 years ago

I have identified the root cause: this is actually a PyTorch core bug.

The failure is caused by this line: torch.index_put_ does not accept slice(None) objects as indices. I also checked the PyTorch unit tests. All index_add_ op tests use dim = 0, so they cannot expose this problem. Inductor removed the lowering for index_select and added decomps for index_{add,add_} in https://github.com/pytorch/torchdynamo/pull/1292. I tried reverting that change and found that it fixes this bug. If we can easily support slice(None), I think we should fix this inside PyTorch; otherwise, should we remove the decomps for these ops?

Thanks @SherlockNoMad for helping me navigate and find the root cause. cc @jansel @ngimel @lezcano Any suggestions?
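To illustrate why tests with dim = 0 cannot hit this, here is a minimal pure-Python sketch. It assumes the decomposition pads the index tuple with one slice(None) per leading dimension, which is consistent with the Value printed in the error above. The helper name build_indices and the string placeholder are hypothetical, not taken from the PyTorch source.

```python
def build_indices(dim, index, filler):
    # Hypothetical helper: one filler per leading dimension, then the index itself.
    return (filler,) * dim + (index,)


index = "index_tensor"  # placeholder standing in for a real int64 index tensor

# dim = 0: no leading dimensions, so no slice objects are ever produced.
dim0 = build_indices(0, index, slice(None))

# dim = 3: three slice(None) entries precede the index, as in the error's Value.
dim3 = build_indices(3, index, slice(None))

print(dim0)
print(dim3)
```

With dim = 0 the tuple holds only the index, so the argument contains nothing that fails to cast to a Tensor.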

lezcano commented 2 years ago

I believe changing that line from slice(None) to None should fix the issue. Could you confirm?
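As a rough sketch of that suggestion (not the actual patch): None entries are valid for a Tensor?[] argument and mean "take the whole dimension". The helper name index_add_via_index_put is hypothetical, and the shapes are chosen only for illustration.

```python
import torch


def index_add_via_index_put(x, dim, index, source):
    # Hypothetical helper: emulate x.index_add(dim, index, source).
    # None (not slice(None)) marks each leading dimension as "select everything".
    indices = (None,) * dim + (index,)
    return torch.ops.aten.index_put(x, indices, source, accumulate=True)


x = torch.zeros(2, 3, 4)
index = torch.tensor([0, 2, 2])  # duplicate index: contributions accumulate
source = torch.ones(2, 3, 3)

expected = x.index_add(2, index, source)
result = index_add_via_index_put(x, 2, index, source)

print(torch.equal(result, expected))
```

Using a nonzero dim here is the point: it exercises the leading None entries that a dim = 0 test never creates.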

lezcano commented 2 years ago

I put up a fix: https://github.com/pytorch/pytorch/pull/86266

ngimel commented 2 years ago

Thanks @lezcano for the fix. Let's not remove the decomps, and let's enable the tests once the PyTorch pin is updated. @yanboliang can you please add tests to inductor that would expose this?

yanboliang commented 2 years ago

@lezcano Thanks for your PR, I verified that it fixes my issue! @ngimel I'll add tests to inductor once the PyTorch pin is updated.