pytorch / torchdynamo

A Python-level JIT compiler designed to make unmodified PyTorch programs faster.
BSD 3-Clause "New" or "Revised" License
1.01k stars 124 forks source link

[inductor] `cat` - Super_slomo, dlrm, timm_vision_transformer - AMP #1648

Closed anijain2305 closed 2 years ago

anijain2305 commented 2 years ago

Also affects beit_base_patch16_224, crossvit_9_240, deit_base_distilled_patch16_224, vit_base_patch16_224, xcit_large_24_p8_224

Repro


import torch
from torch import tensor, device
import torch.fx as fx
from torchdynamo.testing import rand_strided
from math import inf
from torch.fx.experimental.proxy_tensor import make_fx

# torch version: 1.14.0a0+git25725fd
# torch cuda version: 11.6
# torch git version: 25725fd62448165b91647304c26d676db22b6955

# CUDA Info: 
# nvcc: NVIDIA (R) Cuda compiler driver 
# Copyright (c) 2005-2022 NVIDIA Corporation 
# Built on Thu_Feb_10_18:23:41_PST_2022 
# Cuda compilation tools, release 11.6, V11.6.112 
# Build cuda_11.6.r11.6/compiler.30978841_0 

# GPU Hardware Info: 
# NVIDIA A100-SXM4-40GB : 8 

from torch.nn import *
class Repro(torch.nn.Module):
    """Minimal module that feeds eight tensors into ``aten.cat``.

    ``forward`` concatenates its inputs along dimension 1 in a fixed order
    that is *not* the parameter order, and returns the result wrapped in a
    1-tuple.
    """

    def __init__(self):
        super().__init__()

    def forward(self, arg3_1, arg4_1, slice_4, slice_7, add, add_1, add_8, add_15):
        # NOTE: (add, add_1) and (add_8, add_15) appear swapped relative to
        # the signature; this order is part of the repro, so keep it.
        pieces = [
            arg3_1,
            arg4_1,
            slice_4,
            slice_7,
            add_1,
            add,
            add_15,
            add_8,
        ]
        concatenated = torch.ops.aten.cat.default(pieces, 1)
        # Release the input references before returning.
        del pieces, arg3_1, arg4_1, slice_4, slice_7, add, add_1, add_8, add_15
        return (concatenated,)

# Input specs are (shape, stride, dtype, device). The inputs mix float32 and
# float16; the traceback shows ConcatKernel.create asserting that every cat
# input shares a single dtype, which these inputs violate.
args = [
    rand_strided(*spec)
    for spec in [
        ((2, 3, 352, 352), (371712, 123904, 352, 1), torch.float32, 'cuda'),
        ((2, 3, 352, 352), (371712, 123904, 352, 1), torch.float32, 'cuda'),
        ((2, 2, 352, 352), (495616, 123904, 352, 1), torch.float16, 'cuda'),
        ((2, 2, 352, 352), (495616, 123904, 352, 1), torch.float16, 'cuda'),
        ((2, 2, 352, 352), (247808, 123904, 352, 1), torch.float32, 'cuda'),
        ((2, 2, 352, 352), (247808, 123904, 352, 1), torch.float32, 'cuda'),
        ((2, 3, 352, 352), (371712, 123904, 352, 1), torch.float32, 'cuda'),
        ((2, 3, 352, 352), (371712, 123904, 352, 1), torch.float32, 'cuda'),
    ]
]
# Trace the CUDA module into an FX graph using the example inputs.
mod = make_fx(Repro().to(device="cuda"))(*args)

from torchinductor.compile_fx import compile_fx_inner
from torchdynamo.debug_utils import same_two_models  # unused below

# While the bug is present, lowering aten.cat raises a LoweringException
# inside compile_fx_inner, so the call on the following line is never reached.
compiled = compile_fx_inner(mod, args)
compiled(args)

Error

Traceback (most recent call last):
  File "/scratch/anijain/work/torchdynamo/torchinductor/graph.py", line 257, in call_function
    return lowerings[target](*args, **kwargs)
  File "/scratch/anijain/work/torchdynamo/torchinductor/lowering.py", line 193, in wrapped
    return decomp_fn(*args, **kwargs)
  File "/scratch/anijain/work/torchdynamo/torchinductor/lowering.py", line 752, in cat
    return TensorBox(ir.ConcatKernel.create(inputs, dim))
  File "/scratch/anijain/work/torchdynamo/torchinductor/ir.py", line 2135, in create
    assert inputs[i].get_dtype() == dtype
AssertionError

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/scratch/anijain/work/torchdynamo/repro.py", line 43, in <module>
    compiled = compile_fx_inner(mod, args)
  File "/scratch/anijain/work/torchdynamo/torchdynamo/debug_utils.py", line 446, in debug_wrapper
    compiled_fn = compiler_fn(gm, example_inputs, **kwargs)
  File "/scratch/anijain/work/torchdynamo/torchinductor/debug.py", line 180, in inner
    return fn(*args, **kwargs)
  File "/scratch/anijain/work/env/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/scratch/anijain/work/torchdynamo/torchinductor/compile_fx.py", line 103, in compile_fx_inner
    graph.run(*example_inputs)
  File "/scratch/anijain/work/torchdynamo/torchdynamo/utils.py", line 76, in time_wrapper
    r = func(*args, **kwargs)
  File "/scratch/anijain/work/torchdynamo/torchinductor/graph.py", line 146, in run
    return super().run(*args)
  File "/scratch/anijain/work/pytorch/torch/fx/interpreter.py", line 130, in run
    self.env[node] = self.run_node(node)
  File "/scratch/anijain/work/torchdynamo/torchinductor/graph.py", line 314, in run_node
    result = super().run_node(n)
  File "/scratch/anijain/work/pytorch/torch/fx/interpreter.py", line 171, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/scratch/anijain/work/torchdynamo/torchinductor/graph.py", line 259, in call_function
    raise LoweringException(e, target, args, kwargs) from e
torchinductor.exc.LoweringException: AssertionError:
  target: aten.cat.default
  args[0]: [TensorBox(StorageBox(
    InputBuffer(name='arg0_1', layout=FixedLayout('cuda', torch.float32, size=[s0, s1, s2, s2], stride=[s1*s2**2, s2**2, s2, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='arg1_1', layout=FixedLayout('cuda', torch.float32, size=[s0, s1, s2, s2], stride=[s1*s2**2, s2**2, s2, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='arg2_1', layout=FixedLayout('cuda', torch.float16, size=[s0, s0, s2, s2], stride=[s3, s2**2, s2, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='arg3_1', layout=FixedLayout('cuda', torch.float16, size=[s0, s0, s2, s2], stride=[s3, s2**2, s2, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='arg5_1', layout=FixedLayout('cuda', torch.float32, size=[s0, s0, s2, s2], stride=[s0*s2**2, s2**2, s2, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='arg4_1', layout=FixedLayout('cuda', torch.float32, size=[s0, s0, s2, s2], stride=[s0*s2**2, s2**2, s2, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='arg7_1', layout=FixedLayout('cuda', torch.float32, size=[s0, s1, s2, s2], stride=[s1*s2**2, s2**2, s2, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='arg6_1', layout=FixedLayout('cuda', torch.float32, size=[s0, s1, s2, s2], stride=[s1*s2**2, s2**2, s2, 1]))
  ))]
  args[1]: 1

While executing %cat : [#users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%arg0_1, %arg1_1, %arg2_1, %arg3_1, %arg5_1, %arg4_1, %arg7_1, %arg6_1], 1), kwargs = {})
Original traceback:
None
Chillee commented 2 years ago

This might be resolved by https://github.com/pytorch/torchdynamo/pull/1614?

anijain2305 commented 2 years ago

I already have that commit in my branch, so some corner case may still be missing.