Closed: anijain2305 closed this issue 2 years ago
Also affects:
beit_base_patch16_224
crossvit_9_240
deit_base_distilled_patch16_224
vit_base_patch16_224
xcit_large_24_p8_224
Repro
import torch
from torch import tensor, device
import torch.fx as fx
from torchdynamo.testing import rand_strided
from math import inf
from torch.fx.experimental.proxy_tensor import make_fx

# torch version: 1.14.0a0+git25725fd
# torch cuda version: 11.6
# torch git version: 25725fd62448165b91647304c26d676db22b6955

# CUDA Info:
# nvcc: NVIDIA (R) Cuda compiler driver
# Copyright (c) 2005-2022 NVIDIA Corporation
# Built on Thu_Feb_10_18:23:41_PST_2022
# Cuda compilation tools, release 11.6, V11.6.112
# Build cuda_11.6.r11.6/compiler.30978841_0

# GPU Hardware Info:
# NVIDIA A100-SXM4-40GB : 8

from torch.nn import *

class Repro(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, arg3_1, arg4_1, slice_4, slice_7, add, add_1, add_8, add_15):
        cat_1 = torch.ops.aten.cat.default([arg3_1, arg4_1, slice_4, slice_7, add_1, add, add_15, add_8], 1);  arg3_1 = arg4_1 = slice_4 = slice_7 = add_1 = add = add_15 = add_8 = None
        return (cat_1,)

args = [((2, 3, 352, 352), (371712, 123904, 352, 1), torch.float32, 'cuda'),
        ((2, 3, 352, 352), (371712, 123904, 352, 1), torch.float32, 'cuda'),
        ((2, 2, 352, 352), (495616, 123904, 352, 1), torch.float16, 'cuda'),
        ((2, 2, 352, 352), (495616, 123904, 352, 1), torch.float16, 'cuda'),
        ((2, 2, 352, 352), (247808, 123904, 352, 1), torch.float32, 'cuda'),
        ((2, 2, 352, 352), (247808, 123904, 352, 1), torch.float32, 'cuda'),
        ((2, 3, 352, 352), (371712, 123904, 352, 1), torch.float32, 'cuda'),
        ((2, 3, 352, 352), (371712, 123904, 352, 1), torch.float32, 'cuda')]
args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args]

mod = make_fx(Repro().to(device="cuda"))(*args)

from torchinductor.compile_fx import compile_fx_inner
from torchdynamo.debug_utils import same_two_models

compiled = compile_fx_inner(mod, args)
compiled(args)
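For context, the traced module itself should run fine in eager mode. The check below is my own quick sketch, meant to be appended to the repro above; the dtype and shape in the comments are my expectation, not output quoted from the issue.

# Sketch continuing the repro above (assumption: executed right after make_fx).
# Eager aten.cat type-promotes the float16 inputs to float32, so this succeeds;
# only the inductor lowering hits the dtype assertion.
eager_out = mod(*args)
print(eager_out[0].dtype)   # expected: torch.float32
print(eager_out[0].shape)   # expected: torch.Size([2, 20, 352, 352]) (3+3+2+2+2+2+3+3 channels)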
Error
Traceback (most recent call last):
  File "/scratch/anijain/work/torchdynamo/torchinductor/graph.py", line 257, in call_function
    return lowerings[target](*args, **kwargs)
  File "/scratch/anijain/work/torchdynamo/torchinductor/lowering.py", line 193, in wrapped
    return decomp_fn(*args, **kwargs)
  File "/scratch/anijain/work/torchdynamo/torchinductor/lowering.py", line 752, in cat
    return TensorBox(ir.ConcatKernel.create(inputs, dim))
  File "/scratch/anijain/work/torchdynamo/torchinductor/ir.py", line 2135, in create
    assert inputs[i].get_dtype() == dtype
AssertionError

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/scratch/anijain/work/torchdynamo/repro.py", line 43, in <module>
    compiled = compile_fx_inner(mod, args)
  File "/scratch/anijain/work/torchdynamo/torchdynamo/debug_utils.py", line 446, in debug_wrapper
    compiled_fn = compiler_fn(gm, example_inputs, **kwargs)
  File "/scratch/anijain/work/torchdynamo/torchinductor/debug.py", line 180, in inner
    return fn(*args, **kwargs)
  File "/scratch/anijain/work/env/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/scratch/anijain/work/torchdynamo/torchinductor/compile_fx.py", line 103, in compile_fx_inner
    graph.run(*example_inputs)
  File "/scratch/anijain/work/torchdynamo/torchdynamo/utils.py", line 76, in time_wrapper
    r = func(*args, **kwargs)
  File "/scratch/anijain/work/torchdynamo/torchinductor/graph.py", line 146, in run
    return super().run(*args)
  File "/scratch/anijain/work/pytorch/torch/fx/interpreter.py", line 130, in run
    self.env[node] = self.run_node(node)
  File "/scratch/anijain/work/torchdynamo/torchinductor/graph.py", line 314, in run_node
    result = super().run_node(n)
  File "/scratch/anijain/work/pytorch/torch/fx/interpreter.py", line 171, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/scratch/anijain/work/torchdynamo/torchinductor/graph.py", line 259, in call_function
    raise LoweringException(e, target, args, kwargs) from e
torchinductor.exc.LoweringException: AssertionError:
  target: aten.cat.default
  args[0]: [TensorBox(StorageBox(
    InputBuffer(name='arg0_1', layout=FixedLayout('cuda', torch.float32, size=[s0, s1, s2, s2], stride=[s1*s2**2, s2**2, s2, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='arg1_1', layout=FixedLayout('cuda', torch.float32, size=[s0, s1, s2, s2], stride=[s1*s2**2, s2**2, s2, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='arg2_1', layout=FixedLayout('cuda', torch.float16, size=[s0, s0, s2, s2], stride=[s3, s2**2, s2, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='arg3_1', layout=FixedLayout('cuda', torch.float16, size=[s0, s0, s2, s2], stride=[s3, s2**2, s2, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='arg5_1', layout=FixedLayout('cuda', torch.float32, size=[s0, s0, s2, s2], stride=[s0*s2**2, s2**2, s2, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='arg4_1', layout=FixedLayout('cuda', torch.float32, size=[s0, s0, s2, s2], stride=[s0*s2**2, s2**2, s2, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='arg7_1', layout=FixedLayout('cuda', torch.float32, size=[s0, s1, s2, s2], stride=[s1*s2**2, s2**2, s2, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='arg6_1', layout=FixedLayout('cuda', torch.float32, size=[s0, s1, s2, s2], stride=[s1*s2**2, s2**2, s2, 1]))
  ))]
  args[1]: 1

While executing %cat : [#users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%arg0_1, %arg1_1, %arg2_1, %arg3_1, %arg5_1, %arg4_1, %arg7_1, %arg6_1], 1), kwargs = {})
Original traceback:
None
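The traceback shows float16 and float32 InputBuffers feeding the same aten.cat, and ConcatKernel.create asserts that every input matches a single dtype. Eager PyTorch instead type-promotes mixed-dtype cat inputs, which is why these models run fine without torchinductor. A minimal standalone sketch of that promotion behavior (my own example, not code from the issue):

import torch

a = torch.randn(2, 3, 4, 4)                        # float32
b = torch.randn(2, 2, 4, 4, dtype=torch.float16)   # float16
out = torch.cat([a, b], dim=1)                     # eager cat promotes to the common dtype
print(out.dtype)                                   # torch.float32
print(torch.result_type(a, b))                     # torch.float32 (the promotion rule cat follows)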
This might be resolved by https://github.com/pytorch/torchdynamo/pull/1614?
I already have that commit in my branch, so there is probably still a corner case it misses.
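If PR 1614 does not cover this pattern, one possible direction (a hypothetical sketch of my own, not the actual lowering code or the PR's fix) is to compute the eager-style promoted dtype across all cat inputs and cast the outliers before the dtype assertion runs:

import functools
import torch

def promote_cat_inputs(tensors):
    # functools.reduce over torch.promote_types reproduces eager cat's result dtype
    common = functools.reduce(torch.promote_types, (t.dtype for t in tensors))
    return [t if t.dtype == common else t.to(common) for t in tensors]

xs = [torch.randn(2, 3, 4, 4), torch.randn(2, 2, 4, 4, dtype=torch.float16)]
promoted = promote_cat_inputs(xs)
assert all(t.dtype == torch.float32 for t in promoted)
out = torch.cat(promoted, dim=1)   # all inputs now share one dtype, so no assertion failure

Inside the actual lowering the cast would presumably happen at the IR level rather than via Tensor.to, but the promotion rule itself is the same one eager mode applies.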