Closed anijain2305 closed 1 year ago
Repro (still trying to minify it further):
import torch
from torch import tensor, device
import torch.fx as fx
from torchdynamo.testing import rand_strided
from math import inf
from torch.fx.experimental.proxy_tensor import make_fx
# torch version: 1.13.0a0+git071f875
# torch cuda version: 11.6
# torch git version: 071f875046202b87213865dfc180abdf8368f116
# CUDA Info:
# nvcc: NVIDIA (R) Cuda compiler driver
# Copyright (c) 2005-2022 NVIDIA Corporation
# Built on Thu_Feb_10_18:23:41_PST_2022
# Cuda compilation tools, release 11.6, V11.6.112
# Build cuda_11.6.r11.6/compiler.30978841_0
# GPU Hardware Info:
# NVIDIA A100-SXM4-40GB : 8
from torch.nn import *
class Repro(torch.nn.Module):
    """Partially minified repro module: a flat sequence of ``torch.ops.aten`` calls.

    The statements, their order, and the explicit ``name = None`` releases are
    kept exactly as captured; only indentation and comments were restored.
    Do not simplify or reorder the body -- the exact graph is the repro.

    NOTE(review): the op pattern (transposed matmuls plus column sums) looks
    like part of a backward graph -- confirm against the original model.
    """

    def __init__(self):
        super().__init__()

    def forward(self, arg8_1, arg10_1, arg113_1, arg156_1, arg508_1, arg509_1, arg510_1, arg515_1, arg540_1, arg541_1, mul_7, view_4, mul_9, view_13, add_75, view_650, view_531, permute_342, view_544, view_557, bmm_74, mul_373, bmm_75, mul_448):
        """Run the captured op sequence.

        Returns a list whose first nine entries are tensors (``view_545``,
        ``permute_350``, ``view_548``, ``permute_353``, ``view_651``,
        ``permute_425``, ``sum_320``, ``sum_321``, ``permute_429``) followed by
        ``None`` placeholders.  Several intermediates (``permute_343``,
        ``view_549``, ``mm_154``, ``mul_445``, ``mul_451``, ``sum_322``,
        ``sum_323``) are computed but never returned.
        """
        # cat -> view -> permute chain; its result (permute_343) is never used.
        cat_17 = torch.ops.aten.cat.default([mul_373, permute_342, view_531]); mul_373 = permute_342 = view_531 = None
        view_537 = torch.ops.aten.view.default(cat_17, [3, 8, 16, 49, 32]); cat_17 = None
        permute_343 = torch.ops.aten.permute.default(view_537, [1, 3, 0, 2, 4]); view_537 = None
        # Products and column sum derived from view_544.
        mm_148 = torch.ops.aten.mm.default(view_544, arg508_1); arg508_1 = None
        permute_348 = torch.ops.aten.permute.default(view_544, [1, 0])
        mm_149 = torch.ops.aten.mm.default(permute_348, arg156_1); permute_348 = arg156_1 = None
        permute_349 = torch.ops.aten.permute.default(mm_149, [1, 0]); mm_149 = None
        sum_262 = torch.ops.aten.sum.dim_IntList(view_544, [0], True); view_544 = None
        view_545 = torch.ops.aten.view.default(sum_262, [512]); sum_262 = None
        view_546 = torch.ops.aten.view.default(mm_148, [2, 196, 2048]); mm_148 = None
        permute_350 = torch.ops.aten.permute.default(permute_349, [1, 0]); permute_349 = None
        # Same pattern applied to the scaled result (view_547).
        mul_380 = torch.ops.aten.mul.Tensor(view_546, arg509_1); view_546 = arg509_1 = None
        view_547 = torch.ops.aten.view.default(mul_380, [392, 2048]); mul_380 = None
        mm_150 = torch.ops.aten.mm.default(view_547, arg510_1); arg510_1 = None
        permute_351 = torch.ops.aten.permute.default(view_547, [1, 0])
        mm_151 = torch.ops.aten.mm.default(permute_351, view_13); permute_351 = view_13 = None
        permute_352 = torch.ops.aten.permute.default(mm_151, [1, 0]); mm_151 = None
        sum_263 = torch.ops.aten.sum.dim_IntList(view_547, [0], True); view_547 = None
        view_548 = torch.ops.aten.view.default(sum_263, [2048]); sum_263 = None
        view_549 = torch.ops.aten.view.default(mm_150, [2, 196, 512]); mm_150 = None
        permute_353 = torch.ops.aten.permute.default(permute_352, [1, 0]); permute_352 = None
        # Second cat -> view -> permute -> clone chain feeding mm_154 (unused).
        view_563 = torch.ops.aten.view.default(bmm_74, [8, 16, 32, 49]); bmm_74 = None
        view_564 = torch.ops.aten.view.default(bmm_75, [8, 16, 49, 32]); bmm_75 = None
        permute_361 = torch.ops.aten.permute.default(view_563, [0, 1, 3, 2]); view_563 = None
        mul_389 = torch.ops.aten.mul.Tensor(view_564, 0.1767766952966369); view_564 = None
        cat_18 = torch.ops.aten.cat.default([mul_389, permute_361, view_557]); mul_389 = permute_361 = view_557 = None
        view_565 = torch.ops.aten.view.default(cat_18, [3, 8, 16, 49, 32]); cat_18 = None
        permute_362 = torch.ops.aten.permute.default(view_565, [1, 3, 0, 2, 4]); view_565 = None
        clone_70 = torch.ops.aten.clone.default(permute_362, memory_format = torch.contiguous_format); permute_362 = None
        _unsafe_view_70 = torch.ops.aten._unsafe_view.default(clone_70, [8, 49, 1536]); clone_70 = None
        view_566 = torch.ops.aten.view.default(_unsafe_view_70, [392, 1536]); _unsafe_view_70 = None
        mm_154 = torch.ops.aten.mm.default(view_566, arg515_1); view_566 = arg515_1 = None
        # Products and column sum derived from view_650.
        mm_180 = torch.ops.aten.mm.default(view_650, arg540_1); arg540_1 = None
        permute_423 = torch.ops.aten.permute.default(view_650, [1, 0])
        mm_181 = torch.ops.aten.mm.default(permute_423, arg113_1); permute_423 = arg113_1 = None
        permute_424 = torch.ops.aten.permute.default(mm_181, [1, 0]); mm_181 = None
        sum_317 = torch.ops.aten.sum.dim_IntList(view_650, [0], True); view_650 = None
        view_651 = torch.ops.aten.view.default(sum_317, [768]); sum_317 = None
        view_652 = torch.ops.aten.view.default(mm_180, [32, 49, 256]); mm_180 = None
        permute_425 = torch.ops.aten.permute.default(permute_424, [1, 0]); permute_424 = None
        # Regroup (32, 49, 256) into (2, 784, 256) via a 6-D view, permute and clone.
        view_653 = torch.ops.aten.view.default(view_652, [32, 7, 7, 256]); view_652 = None
        view_654 = torch.ops.aten.view.default(view_653, [2, 4, 4, 7, 7, 256]); view_653 = None
        permute_426 = torch.ops.aten.permute.default(view_654, [0, 1, 3, 2, 4, 5]); view_654 = None
        clone_83 = torch.ops.aten.clone.default(permute_426, memory_format = torch.contiguous_format); permute_426 = None
        _unsafe_view_83 = torch.ops.aten._unsafe_view.default(clone_83, [2, 28, 28, 256]); clone_83 = None
        view_655 = torch.ops.aten.view.default(_unsafe_view_83, [2, 784, 256]); _unsafe_view_83 = None
        mul_444 = torch.ops.aten.mul.Tensor(view_655, arg10_1); arg10_1 = None
        mul_445 = torch.ops.aten.mul.Tensor(mul_444, 256); mul_444 = None
        mul_449 = torch.ops.aten.mul.Tensor(view_655, mul_9); mul_9 = None
        sum_320 = torch.ops.aten.sum.dim_IntList(mul_449, [0, 1]); mul_449 = None
        sum_321 = torch.ops.aten.sum.dim_IntList(view_655, [0, 1]); view_655 = None
        # Products derived from add_75 + mul_448.
        add_76 = torch.ops.aten.add.Tensor(add_75, mul_448); add_75 = mul_448 = None
        view_656 = torch.ops.aten.view.default(add_76, [1568, 256]); add_76 = None
        permute_427 = torch.ops.aten.permute.default(view_656, [1, 0])
        mm_182 = torch.ops.aten.mm.default(permute_427, view_4); permute_427 = view_4 = None
        permute_428 = torch.ops.aten.permute.default(mm_182, [1, 0]); mm_182 = None
        mm_183 = torch.ops.aten.mm.default(view_656, arg541_1); view_656 = arg541_1 = None
        view_657 = torch.ops.aten.view.default(mm_183, [2, 784, 512]); mm_183 = None
        permute_429 = torch.ops.aten.permute.default(permute_428, [1, 0]); permute_428 = None
        # Trailing elementwise ops and reductions; none of these are returned.
        mul_450 = torch.ops.aten.mul.Tensor(view_657, arg8_1); view_657 = arg8_1 = None
        mul_451 = torch.ops.aten.mul.Tensor(mul_450, 512)
        sum_322 = torch.ops.aten.sum.dim_IntList(mul_450, [2], True)
        mul_452 = torch.ops.aten.mul.Tensor(mul_450, mul_7); mul_450 = mul_7 = None
        sum_323 = torch.ops.aten.sum.dim_IntList(mul_452, [2], True); mul_452 = None
        return [view_545, permute_350, view_548, permute_353, view_651, permute_425, sum_320, sum_321, permute_429, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None]
# One (shape, stride, dtype, device) entry per positional input of
# Repro.forward, listed in signature order.
args = [
    ((512,), (1,), torch.float32, 'cuda'),
    ((256,), (1,), torch.float32, 'cuda'),
    ((1568, 256), (256, 1), torch.float32, 'cuda'),
    ((392, 2048), (2048, 1), torch.float32, 'cuda'),
    ((512, 2048), (2048, 1), torch.float32, 'cuda'),
    ((2, 196, 2048), (401408, 2048, 1), torch.float32, 'cuda'),
    ((2048, 512), (512, 1), torch.float32, 'cuda'),
    ((1536, 512), (512, 1), torch.float32, 'cuda'),
    ((768, 256), (256, 1), torch.float32, 'cuda'),
    ((256, 512), (512, 1), torch.float32, 'cuda'),
    ((2, 784, 512), (401408, 512, 1), torch.float32, 'cuda'),
    ((1568, 512), (512, 1), torch.float32, 'cuda'),
    ((2, 784, 256), (200704, 256, 1), torch.float32, 'cuda'),
    ((392, 512), (512, 1), torch.float32, 'cuda'),
    ((2, 784, 256), (200704, 256, 1), torch.float32, 'cuda'),
    ((1568, 768), (768, 1), torch.float32, 'cuda'),
    ((8, 16, 49, 32), (25088, 1568, 32, 1), torch.float32, 'cuda'),
    ((8, 16, 49, 32), (25088, 1568, 1, 49), torch.float32, 'cuda'),
    ((392, 512), (512, 1), torch.float32, 'cuda'),
    ((8, 16, 49, 32), (25088, 1568, 32, 1), torch.float32, 'cuda'),
    ((128, 32, 49), (1568, 49, 1), torch.float32, 'cuda'),
    ((8, 16, 49, 32), (25088, 1568, 32, 1), torch.float32, 'cuda'),
    ((128, 49, 32), (1568, 32, 1), torch.float32, 'cuda'),
    ((2, 784, 256), (200704, 256, 1), torch.float32, 'cuda'),
]
# Materialize each spec as a random tensor with the requested strides.
args = [rand_strided(*spec) for spec in args]
# Trace the module into an FX graph before handing it to the compiler.
mod = make_fx(Repro())(*args)
# Imported only after tracing, matching the original script's order.
from torchinductor.compile_fx import compile_fx_inner
compiled = compile_fx_inner(mod, args)
compiled(*args)
Smallest repro the minifier could produce:
import torch
from torch import tensor, device
import torch.fx as fx
from torchdynamo.testing import rand_strided
from math import inf
from torch.fx.experimental.proxy_tensor import make_fx
# torch version: 1.13.0a0+git071f875
# torch cuda version: 11.6
# torch git version: 071f875046202b87213865dfc180abdf8368f116
# CUDA Info:
# nvcc: NVIDIA (R) Cuda compiler driver
# Copyright (c) 2005-2022 NVIDIA Corporation
# Built on Thu_Feb_10_18:23:41_PST_2022
# Cuda compilation tools, release 11.6, V11.6.112
# Build cuda_11.6.r11.6/compiler.30978841_0
# GPU Hardware Info:
# NVIDIA A100-SXM4-40GB : 8
from torch.nn import *
class Repro(torch.nn.Module):
    """Minifier-produced repro module: a flat sequence of ``torch.ops.aten`` calls.

    The statements, their order, and the explicit ``name = None`` releases are
    kept exactly as emitted; only indentation and comments were restored.
    Do not simplify or reorder the body -- the exact graph is the repro.
    """

    def __init__(self):
        super().__init__()

    def forward(self, arg0_1, arg3_1, arg4_1, arg5_1, arg8_1, arg9_1, arg13_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg21_1, permute_7, mul_1, view_15, permute_1, mm_2):
        """Run the captured op sequence.

        Returns ``[permute_3, permute_6, sum_5]``.  Other intermediates
        (``cat``, ``view_5``, ``cat_1``, ``sum_6``) are computed but never
        returned.
        """
        # Concatenation whose result is never used.
        cat = torch.ops.aten.cat.default([arg21_1, arg17_1, arg16_1]); arg21_1 = arg17_1 = arg16_1 = None
        mm = torch.ops.aten.mm.default(arg18_1, arg4_1); arg18_1 = arg4_1 = None
        mm_1 = torch.ops.aten.mm.default(permute_1, arg3_1); permute_1 = arg3_1 = None
        # Two back-to-back [1, 0] permutes restore mm_1's original dimension order.
        permute_2 = torch.ops.aten.permute.default(mm_1, [1, 0]); mm_1 = None
        view_2 = torch.ops.aten.view.default(mm, [2, 196, 2048]); mm = None
        permute_3 = torch.ops.aten.permute.default(permute_2, [1, 0]); permute_2 = None
        mul = torch.ops.aten.mul.Tensor(view_2, arg5_1); view_2 = arg5_1 = None
        view_3 = torch.ops.aten.view.default(mul, [392, 2048]); mul = None
        permute_4 = torch.ops.aten.permute.default(view_3, [1, 0]); view_3 = None
        mm_3 = torch.ops.aten.mm.default(permute_4, arg13_1); permute_4 = arg13_1 = None
        permute_5 = torch.ops.aten.permute.default(mm_3, [1, 0]); mm_3 = None
        view_5 = torch.ops.aten.view.default(mm_2, [2, 196, 512]); mm_2 = None
        permute_6 = torch.ops.aten.permute.default(permute_5, [1, 0]); permute_5 = None
        # Second concatenation, also unused.
        cat_1 = torch.ops.aten.cat.default([mul_1, permute_7, arg19_1]); mul_1 = permute_7 = arg19_1 = None
        mm_5 = torch.ops.aten.mm.default(arg15_1, arg8_1); arg15_1 = arg8_1 = None
        # Regroup (32, 49, 256) into (2, 784, 256) via a 6-D view, permute and clone.
        view_11 = torch.ops.aten.view.default(mm_5, [32, 49, 256]); mm_5 = None
        view_12 = torch.ops.aten.view.default(view_11, [32, 7, 7, 256]); view_11 = None
        view_13 = torch.ops.aten.view.default(view_12, [2, 4, 4, 7, 7, 256]); view_12 = None
        permute_12 = torch.ops.aten.permute.default(view_13, [0, 1, 3, 2, 4, 5]); view_13 = None
        clone_1 = torch.ops.aten.clone.default(permute_12, memory_format = torch.contiguous_format); permute_12 = None
        _unsafe_view_1 = torch.ops.aten._unsafe_view.default(clone_1, [2, 28, 28, 256]); clone_1 = None
        view_14 = torch.ops.aten.view.default(_unsafe_view_1, [2, 784, 256]); _unsafe_view_1 = None
        sum_5 = torch.ops.aten.sum.dim_IntList(view_14, [0, 1]); view_14 = None
        # Final matmul, scale and reduction; sum_6 is not returned.
        mm_8 = torch.ops.aten.mm.default(view_15, arg9_1); view_15 = arg9_1 = None
        view_16 = torch.ops.aten.view.default(mm_8, [2, 784, 512]); mm_8 = None
        mul_5 = torch.ops.aten.mul.Tensor(view_16, arg0_1); view_16 = arg0_1 = None
        sum_6 = torch.ops.aten.sum.dim_IntList(mul_5, [2], True); mul_5 = None
        return [permute_3, permute_6, sum_5]
# One (shape, stride, dtype, device) entry per positional input of
# Repro.forward, listed in signature order.
args = [
    ((512,), (1,), torch.float32, 'cuda'),
    ((392, 2048), (2048, 1), torch.float32, 'cuda'),
    ((512, 2048), (2048, 1), torch.float32, 'cuda'),
    ((2, 196, 2048), (401408, 2048, 1), torch.float32, 'cuda'),
    ((768, 256), (256, 1), torch.float32, 'cuda'),
    ((256, 512), (512, 1), torch.float32, 'cuda'),
    ((392, 512), (512, 1), torch.float32, 'cuda'),
    ((1568, 768), (768, 1), torch.float32, 'cuda'),
    ((8, 16, 49, 32), (25088, 1568, 32, 1), torch.float32, 'cuda'),
    ((8, 16, 49, 32), (25088, 1568, 1, 49), torch.float32, 'cuda'),
    ((392, 512), (512, 1), torch.float32, 'cuda'),
    ((8, 16, 49, 32), (25088, 1568, 32, 1), torch.float32, 'cuda'),
    ((8, 16, 49, 32), (25088, 1568, 32, 1), torch.float32, 'cuda'),
    ((8, 16, 49, 32), (25088, 1568, 1, 49), torch.float32, 'cuda'),
    ((8, 16, 49, 32), (25088, 1568, 32, 1), torch.float32, 'cuda'),
    ((1568, 256), (256, 1), torch.float32, 'cuda'),
    ((512, 392), (1, 512), torch.float32, 'cuda'),
    ((392, 512), (512, 1), torch.float32, 'cuda'),
]
# Materialize each spec as a random tensor with the requested strides.
args = [rand_strided(*spec) for spec in args]
# Trace the module into an FX graph before handing it to the compiler.
mod = make_fx(Repro())(*args)
# Imported only after tracing, matching the original script's order.
from torchinductor.compile_fx import compile_fx_inner
compiled = compile_fx_inner(mod, args)
compiled(*args)
Thanks @anijain2305, these minified repros are very helpful!
Repro
The IMA (illegal memory access) still happens.