pytorch / torchdynamo

A Python-level JIT compiler designed to make unmodified PyTorch programs faster.

[inductor] CUBLAS_STATUS_INTERNAL_ERROR: swin_base_patch4_window7_224 #1439

Closed: anijain2305 closed this issue 1 year ago

anijain2305 commented 1 year ago

Repro

benchmarks/timm_models.py -dcuda --inductor --float32 --train --accuracy --only swin_base_patch4_window7_224

The IMA (illegal memory access) still happens:

WARNING > make_fallback(aten.col2im): a decomposition exists, we should switch to it
cuda train swin_base_patch4_window7_224       torchdynamo.symbolic_convert: [WARNING] Graph break: call_function in skip_files /scratch/anijain/work/torchdynamo/torchdynamo/utils.py from user code at   File "/scratch/anijain/work/torchdynamo/benchmarks/timm_models.py", line 318, in forward_and_backward_pass
    cloned_inputs = clone_inputs(inputs)

torchdynamo.symbolic_convert: [WARNING] Graph break: call_method NNModuleVariable() zero_grad [ConstantVariable(bool)] {} from user code at   File "/scratch/anijain/work/torchdynamo/benchmarks/timm_models.py", line 319, in <graph break in forward_and_backward_pass>
    mod.zero_grad(True)

TorchDynamo optimized model failed to run because of following error
ERROR > CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Traceback (most recent call last):
  File "/scratch/anijain/work/torchdynamo/benchmarks/common.py", line 1134, in check_accuracy
    new_result = optimized_model_iter_fn(model, example_inputs)
  File "/scratch/anijain/work/torchdynamo/torchdynamo/eval_frame.py", line 166, in _fn
    return fn(*args, **kwargs)
  File "/scratch/anijain/work/torchdynamo/benchmarks/timm_models.py", line 317, in forward_and_backward_pass
    def forward_and_backward_pass(self, mod, inputs, collect_outputs=True):
  File "/scratch/anijain/work/torchdynamo/benchmarks/timm_models.py", line 318, in <graph break in forward_and_backward_pass>
    cloned_inputs = clone_inputs(inputs)
  File "/scratch/anijain/work/torchdynamo/benchmarks/timm_models.py", line 327, in <graph break in forward_and_backward_pass>
    return collect_results(mod, pred, loss, cloned_inputs)
  File "/scratch/anijain/work/torchdynamo/torchdynamo/testing.py", line 41, in collect_results
    if isinstance(loss, torch.Tensor) and loss.item() > 1:
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
FAIL
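
As the error text above notes, CUDA kernel failures are reported asynchronously, so the Python traceback may not point at the faulting kernel. A minimal sketch of re-running the same benchmark with CUDA_LAUNCH_BLOCKING=1 (assuming the script is launched with python; the flags simply mirror the repro command above) so the error surfaces at the launch that caused it:

import os
import subprocess

# Synchronous kernel launches: the stack trace then lands on the kernel
# that actually triggers the illegal memory access.
env = dict(os.environ, CUDA_LAUNCH_BLOCKING="1")
subprocess.run(
    [
        "python", "benchmarks/timm_models.py",
        "-dcuda", "--inductor", "--float32", "--train", "--accuracy",
        "--only", "swin_base_patch4_window7_224",
    ],
    env=env,
    check=False,
)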
anijain2305 commented 1 year ago

Repro - Trying to minify it further


import torch
from torch import tensor, device
import torch.fx as fx
from torchdynamo.testing import rand_strided
from math import inf
from torch.fx.experimental.proxy_tensor import make_fx

# torch version: 1.13.0a0+git071f875
# torch cuda version: 11.6
# torch git version: 071f875046202b87213865dfc180abdf8368f116

# CUDA Info:
# nvcc: NVIDIA (R) Cuda compiler driver
# Copyright (c) 2005-2022 NVIDIA Corporation
# Built on Thu_Feb_10_18:23:41_PST_2022
# Cuda compilation tools, release 11.6, V11.6.112
# Build cuda_11.6.r11.6/compiler.30978841_0

# GPU Hardware Info:
# NVIDIA A100-SXM4-40GB : 8

from torch.nn import *
# Minified graph fragment from the failing model: the minifier replays a
# slice of the captured aten ops on inputs with the original shapes/strides.
class Repro(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, arg8_1, arg10_1, arg113_1, arg156_1, arg508_1, arg509_1, arg510_1, arg515_1, arg540_1, arg541_1, mul_7, view_4, mul_9, view_13, add_75, view_650, view_531, permute_342, view_544, view_557, bmm_74, mul_373, bmm_75, mul_448):
        cat_17 = torch.ops.aten.cat.default([mul_373, permute_342, view_531]);  mul_373 = permute_342 = view_531 = None
        view_537 = torch.ops.aten.view.default(cat_17, [3, 8, 16, 49, 32]);  cat_17 = None
        permute_343 = torch.ops.aten.permute.default(view_537, [1, 3, 0, 2, 4]);  view_537 = None
        mm_148 = torch.ops.aten.mm.default(view_544, arg508_1);  arg508_1 = None
        permute_348 = torch.ops.aten.permute.default(view_544, [1, 0])
        mm_149 = torch.ops.aten.mm.default(permute_348, arg156_1);  permute_348 = arg156_1 = None
        permute_349 = torch.ops.aten.permute.default(mm_149, [1, 0]);  mm_149 = None
        sum_262 = torch.ops.aten.sum.dim_IntList(view_544, [0], True);  view_544 = None
        view_545 = torch.ops.aten.view.default(sum_262, [512]);  sum_262 = None
        view_546 = torch.ops.aten.view.default(mm_148, [2, 196, 2048]);  mm_148 = None
        permute_350 = torch.ops.aten.permute.default(permute_349, [1, 0]);  permute_349 = None
        mul_380 = torch.ops.aten.mul.Tensor(view_546, arg509_1);  view_546 = arg509_1 = None
        view_547 = torch.ops.aten.view.default(mul_380, [392, 2048]);  mul_380 = None
        mm_150 = torch.ops.aten.mm.default(view_547, arg510_1);  arg510_1 = None
        permute_351 = torch.ops.aten.permute.default(view_547, [1, 0])
        mm_151 = torch.ops.aten.mm.default(permute_351, view_13);  permute_351 = view_13 = None
        permute_352 = torch.ops.aten.permute.default(mm_151, [1, 0]);  mm_151 = None
        sum_263 = torch.ops.aten.sum.dim_IntList(view_547, [0], True);  view_547 = None
        view_548 = torch.ops.aten.view.default(sum_263, [2048]);  sum_263 = None
        view_549 = torch.ops.aten.view.default(mm_150, [2, 196, 512]);  mm_150 = None
        permute_353 = torch.ops.aten.permute.default(permute_352, [1, 0]);  permute_352 = None
        view_563 = torch.ops.aten.view.default(bmm_74, [8, 16, 32, 49]);  bmm_74 = None
        view_564 = torch.ops.aten.view.default(bmm_75, [8, 16, 49, 32]);  bmm_75 = None
        permute_361 = torch.ops.aten.permute.default(view_563, [0, 1, 3, 2]);  view_563 = None
        mul_389 = torch.ops.aten.mul.Tensor(view_564, 0.1767766952966369);  view_564 = None
        cat_18 = torch.ops.aten.cat.default([mul_389, permute_361, view_557]);  mul_389 = permute_361 = view_557 = None
        view_565 = torch.ops.aten.view.default(cat_18, [3, 8, 16, 49, 32]);  cat_18 = None
        permute_362 = torch.ops.aten.permute.default(view_565, [1, 3, 0, 2, 4]);  view_565 = None
        clone_70 = torch.ops.aten.clone.default(permute_362, memory_format = torch.contiguous_format);  permute_362 = None
        _unsafe_view_70 = torch.ops.aten._unsafe_view.default(clone_70, [8, 49, 1536]);  clone_70 = None
        view_566 = torch.ops.aten.view.default(_unsafe_view_70, [392, 1536]);  _unsafe_view_70 = None
        mm_154 = torch.ops.aten.mm.default(view_566, arg515_1);  view_566 = arg515_1 = None
        mm_180 = torch.ops.aten.mm.default(view_650, arg540_1);  arg540_1 = None
        permute_423 = torch.ops.aten.permute.default(view_650, [1, 0])
        mm_181 = torch.ops.aten.mm.default(permute_423, arg113_1);  permute_423 = arg113_1 = None
        permute_424 = torch.ops.aten.permute.default(mm_181, [1, 0]);  mm_181 = None
        sum_317 = torch.ops.aten.sum.dim_IntList(view_650, [0], True);  view_650 = None
        view_651 = torch.ops.aten.view.default(sum_317, [768]);  sum_317 = None
        view_652 = torch.ops.aten.view.default(mm_180, [32, 49, 256]);  mm_180 = None
        permute_425 = torch.ops.aten.permute.default(permute_424, [1, 0]);  permute_424 = None
        view_653 = torch.ops.aten.view.default(view_652, [32, 7, 7, 256]);  view_652 = None
        view_654 = torch.ops.aten.view.default(view_653, [2, 4, 4, 7, 7, 256]);  view_653 = None
        permute_426 = torch.ops.aten.permute.default(view_654, [0, 1, 3, 2, 4, 5]);  view_654 = None
        clone_83 = torch.ops.aten.clone.default(permute_426, memory_format = torch.contiguous_format);  permute_426 = None
        _unsafe_view_83 = torch.ops.aten._unsafe_view.default(clone_83, [2, 28, 28, 256]);  clone_83 = None
        view_655 = torch.ops.aten.view.default(_unsafe_view_83, [2, 784, 256]);  _unsafe_view_83 = None
        mul_444 = torch.ops.aten.mul.Tensor(view_655, arg10_1);  arg10_1 = None
        mul_445 = torch.ops.aten.mul.Tensor(mul_444, 256);  mul_444 = None
        mul_449 = torch.ops.aten.mul.Tensor(view_655, mul_9);  mul_9 = None
        sum_320 = torch.ops.aten.sum.dim_IntList(mul_449, [0, 1]);  mul_449 = None
        sum_321 = torch.ops.aten.sum.dim_IntList(view_655, [0, 1]);  view_655 = None
        add_76 = torch.ops.aten.add.Tensor(add_75, mul_448);  add_75 = mul_448 = None
        view_656 = torch.ops.aten.view.default(add_76, [1568, 256]);  add_76 = None
        permute_427 = torch.ops.aten.permute.default(view_656, [1, 0])
        mm_182 = torch.ops.aten.mm.default(permute_427, view_4);  permute_427 = view_4 = None
        permute_428 = torch.ops.aten.permute.default(mm_182, [1, 0]);  mm_182 = None
        mm_183 = torch.ops.aten.mm.default(view_656, arg541_1);  view_656 = arg541_1 = None
        view_657 = torch.ops.aten.view.default(mm_183, [2, 784, 512]);  mm_183 = None
        permute_429 = torch.ops.aten.permute.default(permute_428, [1, 0]);  permute_428 = None
        mul_450 = torch.ops.aten.mul.Tensor(view_657, arg8_1);  view_657 = arg8_1 = None
        mul_451 = torch.ops.aten.mul.Tensor(mul_450, 512)
        sum_322 = torch.ops.aten.sum.dim_IntList(mul_450, [2], True)
        mul_452 = torch.ops.aten.mul.Tensor(mul_450, mul_7);  mul_450 = mul_7 = None
        sum_323 = torch.ops.aten.sum.dim_IntList(mul_452, [2], True);  mul_452 = None
        return [view_545, permute_350, view_548, permute_353, view_651, permute_425, sum_320, sum_321, permute_429, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None]

args = [((512,), (1,), torch.float32, 'cuda'), ((256,), (1,), torch.float32, 'cuda'), ((1568, 256), (256, 1), torch.float32, 'cuda'), ((392, 2048), (2048, 1), torch.float32, 'cuda'), ((512, 2048), (2048, 1), torch.float32, 'cuda'), ((2, 196, 2048), (401408, 2048, 1), torch.float32, 'cuda'), ((2048, 512), (512, 1), torch.float32, 'cuda'), ((1536, 512), (512, 1), torch.float32, 'cuda'), ((768, 256), (256, 1), torch.float32, 'cuda'), ((256, 512), (512, 1), torch.float32, 'cuda'), ((2, 784, 512), (401408, 512, 1), torch.float32, 'cuda'), ((1568, 512), (512, 1), torch.float32, 'cuda'), ((2, 784, 256), (200704, 256, 1), torch.float32, 'cuda'), ((392, 512), (512, 1), torch.float32, 'cuda'), ((2, 784, 256), (200704, 256, 1), torch.float32, 'cuda'), ((1568, 768), (768, 1), torch.float32, 'cuda'), ((8, 16, 49, 32), (25088, 1568, 32, 1), torch.float32, 'cuda'), ((8, 16, 49, 32), (25088, 1568, 1, 49), torch.float32, 'cuda'), ((392, 512), (512, 1), torch.float32, 'cuda'), ((8, 16, 49, 32), (25088, 1568, 32, 1), torch.float32, 'cuda'), ((128, 32, 49), (1568, 49, 1), torch.float32, 'cuda'), ((8, 16, 49, 32), (25088, 1568, 32, 1), torch.float32, 'cuda'), ((128, 49, 32), (1568, 32, 1), torch.float32, 'cuda'), ((2, 784, 256), (200704, 256, 1), torch.float32, 'cuda')]
args = [rand_strided(shape, stride, dtype, device) for shape, stride, dtype, device in args]
mod = make_fx(Repro())(*args)

from torchinductor.compile_fx import compile_fx_inner

compiled = compile_fx_inner(mod, args)
compiled(*args)
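
A small sanity-check sketch on top of the script above (assumptions: same environment, and that the eager FX module is expected to pass): run the traced module eagerly first and synchronize after each call, so the illegal access is attributed to the Inductor-compiled kernels rather than to the captured graph itself.

ref = mod(*args)          # eager baseline: same aten ops, no Inductor codegen
torch.cuda.synchronize()  # flush pending work; the eager path should pass

out = compiled(*args)     # Inductor path from compile_fx_inner above
torch.cuda.synchronize()  # force the asynchronous IMA to surface here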
anijain2305 commented 1 year ago

Smallest repro the minifier could produce


import torch
from torch import tensor, device
import torch.fx as fx
from torchdynamo.testing import rand_strided
from math import inf
from torch.fx.experimental.proxy_tensor import make_fx

# torch version: 1.13.0a0+git071f875
# torch cuda version: 11.6
# torch git version: 071f875046202b87213865dfc180abdf8368f116

# CUDA Info:
# nvcc: NVIDIA (R) Cuda compiler driver
# Copyright (c) 2005-2022 NVIDIA Corporation
# Built on Thu_Feb_10_18:23:41_PST_2022
# Cuda compilation tools, release 11.6, V11.6.112
# Build cuda_11.6.r11.6/compiler.30978841_0

# GPU Hardware Info:
# NVIDIA A100-SXM4-40GB : 8

from torch.nn import *
# Further-minified graph fragment: the same pattern of mm/permute/view/cat
# ops, reduced to the smallest set that still reproduces the failure.
class Repro(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, arg0_1, arg3_1, arg4_1, arg5_1, arg8_1, arg9_1, arg13_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg21_1, permute_7, mul_1, view_15, permute_1, mm_2):
        cat = torch.ops.aten.cat.default([arg21_1, arg17_1, arg16_1]);  arg21_1 = arg17_1 = arg16_1 = None
        mm = torch.ops.aten.mm.default(arg18_1, arg4_1);  arg18_1 = arg4_1 = None
        mm_1 = torch.ops.aten.mm.default(permute_1, arg3_1);  permute_1 = arg3_1 = None
        permute_2 = torch.ops.aten.permute.default(mm_1, [1, 0]);  mm_1 = None
        view_2 = torch.ops.aten.view.default(mm, [2, 196, 2048]);  mm = None
        permute_3 = torch.ops.aten.permute.default(permute_2, [1, 0]);  permute_2 = None
        mul = torch.ops.aten.mul.Tensor(view_2, arg5_1);  view_2 = arg5_1 = None
        view_3 = torch.ops.aten.view.default(mul, [392, 2048]);  mul = None
        permute_4 = torch.ops.aten.permute.default(view_3, [1, 0]);  view_3 = None
        mm_3 = torch.ops.aten.mm.default(permute_4, arg13_1);  permute_4 = arg13_1 = None
        permute_5 = torch.ops.aten.permute.default(mm_3, [1, 0]);  mm_3 = None
        view_5 = torch.ops.aten.view.default(mm_2, [2, 196, 512]);  mm_2 = None
        permute_6 = torch.ops.aten.permute.default(permute_5, [1, 0]);  permute_5 = None
        cat_1 = torch.ops.aten.cat.default([mul_1, permute_7, arg19_1]);  mul_1 = permute_7 = arg19_1 = None
        mm_5 = torch.ops.aten.mm.default(arg15_1, arg8_1);  arg15_1 = arg8_1 = None
        view_11 = torch.ops.aten.view.default(mm_5, [32, 49, 256]);  mm_5 = None
        view_12 = torch.ops.aten.view.default(view_11, [32, 7, 7, 256]);  view_11 = None
        view_13 = torch.ops.aten.view.default(view_12, [2, 4, 4, 7, 7, 256]);  view_12 = None
        permute_12 = torch.ops.aten.permute.default(view_13, [0, 1, 3, 2, 4, 5]);  view_13 = None
        clone_1 = torch.ops.aten.clone.default(permute_12, memory_format = torch.contiguous_format);  permute_12 = None
        _unsafe_view_1 = torch.ops.aten._unsafe_view.default(clone_1, [2, 28, 28, 256]);  clone_1 = None
        view_14 = torch.ops.aten.view.default(_unsafe_view_1, [2, 784, 256]);  _unsafe_view_1 = None
        sum_5 = torch.ops.aten.sum.dim_IntList(view_14, [0, 1]);  view_14 = None
        mm_8 = torch.ops.aten.mm.default(view_15, arg9_1);  view_15 = arg9_1 = None
        view_16 = torch.ops.aten.view.default(mm_8, [2, 784, 512]);  mm_8 = None
        mul_5 = torch.ops.aten.mul.Tensor(view_16, arg0_1);  view_16 = arg0_1 = None
        sum_6 = torch.ops.aten.sum.dim_IntList(mul_5, [2], True);  mul_5 = None
        return [permute_3, permute_6, sum_5]

args = [((512,), (1,), torch.float32, 'cuda'), ((392, 2048), (2048, 1), torch.float32, 'cuda'), ((512, 2048), (2048, 1), torch.float32, 'cuda'), ((2, 196, 2048), (401408, 2048, 1), torch.float32, 'cuda'), ((768, 256), (256, 1), torch.float32, 'cuda'), ((256, 512), (512, 1), torch.float32, 'cuda'), ((392, 512), (512, 1), torch.float32, 'cuda'), ((1568, 768), (768, 1), torch.float32, 'cuda'), ((8, 16, 49, 32), (25088, 1568, 32, 1), torch.float32, 'cuda'), ((8, 16, 49, 32), (25088, 1568, 1, 49), torch.float32, 'cuda'), ((392, 512), (512, 1), torch.float32, 'cuda'), ((8, 16, 49, 32), (25088, 1568, 32, 1), torch.float32, 'cuda'), ((8, 16, 49, 32), (25088, 1568, 32, 1), torch.float32, 'cuda'), ((8, 16, 49, 32), (25088, 1568, 1, 49), torch.float32, 'cuda'), ((8, 16, 49, 32), (25088, 1568, 32, 1), torch.float32, 'cuda'), ((1568, 256), (256, 1), torch.float32, 'cuda'), ((512, 392), (1, 512), torch.float32, 'cuda'), ((392, 512), (512, 1), torch.float32, 'cuda')]
args = [rand_strided(shape, stride, dtype, device) for shape, stride, dtype, device in args]
mod = make_fx(Repro())(*args)

from torchinductor.compile_fx import compile_fx_inner

compiled = compile_fx_inner(mod, args)
compiled(*args)
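
To pinpoint the faulting kernel, one further option (an assumption on my part, not something from this thread; repro.py is a hypothetical filename) is to save the minified script and run it under NVIDIA's compute-sanitizer, which reports the kernel name and address of the illegal access:

import subprocess

# Assumes the minified script above was saved as repro.py and that
# compute-sanitizer (bundled with CUDA 11.x) is on PATH.
subprocess.run(
    ["compute-sanitizer", "--tool", "memcheck", "python", "repro.py"],
    check=False,
)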
ngimel commented 1 year ago

Thanks @anijain2305, these minified repros are very helpful!