pytorch / torchdynamo

A Python-level JIT compiler designed to make unmodified PyTorch programs faster.

MMCLS Resnet-50 failed training, got 'Triton Error [CUDA]: invalid device context' #2026

Closed: C1rN09 closed this issue 1 year ago

C1rN09 commented 1 year ago

🐛 Describe the bug

Training an MMCLS ResNet-50 model with the inductor backend fails with `RuntimeError: Triton Error [CUDA]: invalid device context` during the backward pass.

PyTorch version: 2.0.0.dev20230109+cu116

Error logs

(screenshot of the traceback, ending in `RuntimeError: Triton Error [CUDA]: invalid device context`)

Minified repro

The minifier failed to reproduce the problem. I have tried to locate the failing function; the steps to reproduce are below.

1. Install the dependencies of OpenMMLab:

```
ninja
git+https://github.com/open-mmlab/mmengine@experimental/compile
git+https://github.com/open-mmlab/mmcv@2.x
```

2. Install MMCLS from source:

```
git clone https://github.com/open-mmlab/mmclassification.git
cd mmclassification
git checkout -b 1.x origin/1.x
pip install -r requirements.txt
pip install -e .
```

3. Launch training:

```
# Seems like errors come from this function
export TORCHDYNAMO_DEBUG_FUNCTION='update_params'
python tools/train.py configs/resnet/resnet50_8xb32_in1k.py --cfg-options compile.target='train_step'
```
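
For context on step 3: `TORCHDYNAMO_DEBUG_FUNCTION` tells dynamo to compile only functions whose name matches the given value, which is how `update_params` was singled out here. Below is a minimal sketch of the pattern; the assumptions (not confirmed in this thread) are that the variable is read when `torch._dynamo` is imported, so it must be set first, and `something_else` is just a hypothetical name for contrast.

```python
import os

# Assumption: TORCHDYNAMO_DEBUG_FUNCTION is read at import time, so set it
# before importing torch.
os.environ["TORCHDYNAMO_DEBUG_FUNCTION"] = "update_params"

import torch

@torch.compile
def update_params(x):    # name matches: dynamo compiles this function
    return x / 3

@torch.compile
def something_else(x):   # name differs: dynamo should leave this in eager mode
    return x * 2
```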
ngimel commented 1 year ago

If you try running the temporary file that's throwing the error standalone, does it throw the error? `python /tmp/torchinductor_zhaoqian/ch5cogc...py`. If it throws the error, can you please share this file?
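
One way to surface those generated files is sketched below; the assumption is that the `debug` flag in `torch._inductor.config` (the key is visible in the pickled config later in this thread) makes inductor report the paths of the `/tmp/torchinductor_<user>/...` files it compiles.

```python
import torch
import torch._inductor.config as inductor_config

# Assumption: debug=True makes inductor print paths to its generated
# kernel files so they can be re-run standalone with plain `python`.
inductor_config.debug = True

def fn(x):
    return x / 3

out = torch.compile(fn)(torch.randn(4, device="cuda", requires_grad=True))
out.sum().backward()  # the failing kernel in this issue is in the backward
```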

C1rN09 commented 1 year ago

No, it doesn't throw any error. Should I paste the content of the temporary file in this issue?

C1rN09 commented 1 year ago

I've pasted the `minifier_launcher.py` obtained with `TORCHDYNAMO_REPRO_AFTER='aot'` below. However, running it with `python` gives `RuntimeError: Input graph did not fail the tester`.

minifier_launcher.py:

```python
isolate_fails_code_str = None
import torch
from torch import tensor, device
import torch.fx as fx
from torch._dynamo.testing import rand_strided
from math import inf
from torch.fx.experimental.proxy_tensor import make_fx

import torch._dynamo.config
import torch._inductor.config
torch._dynamo.config.load_config(b'\x80\x04\x95\x13\x08\x00\x00\x00\x00\x00\x00}\x94(\x8c\x08__name__\x94\x8c\x14torch._dynamo.config\x94\x8c\x07__doc__\x94N\x8c\x0b__package__\x94\x8c\rtorch._dynamo\x94\x8c\n__loader__\x94\x8c\x1a_frozen_importlib_external\x94\x8c\x10SourceFileLoader\x94\x93\x94)\x81\x94}\x94(\x8c\x04name\x94h\x02\x8c\x04path\x94\x8cc/nvme/data/zhaoqian/miniconda3/envs/dynamo-test/lib/python3.8/site-packages/torch/_dynamo/config.py\x94ub\x8c\x08__spec__\x94\x8c\x11_frozen_importlib\x94\x8c\nModuleSpec\x94\x93\x94)\x81\x94}\x94(h\x0ch\x02\x8c\x06loader\x94h\n\x8c\x06origin\x94h\x0e\x8c\x0cloader_state\x94N\x8c\x1asubmodule_search_locations\x94N\x8c\r_set_fileattr\x94\x88\x8c\x07_cached\x94\x8c{/nvme/data/zhaoqian/miniconda3/envs/dynamo-test/lib/python3.8/site-packages/torch/_dynamo/__pycache__/config.cpython-38.pyc\x94\x8c\r_initializing\x94\x89ub\x8c\x08__file__\x94h\x0e\x8c\n__cached__\x94h\x1b\x8c\x07abspath\x94\x8c\tposixpath\x94h\x1f\x93\x94\x8c\x07dirname\x94h h"\x93\x94\x8c\x0eHAS_REFS_PRIMS\x94\x88\x8c\tlog_level\x94K\x1e\x8c\x0boutput_code\x94\x89\x8c\rlog_file_name\x94N\x8c\x07verbose\x94\x89\x8c\x11output_graph_code\x94\x89\x8c\x12verify_correctness\x94\x89\x8c\x12minimum_call_count\x94K\x01\x8c\x15dead_code_elimination\x94\x88\x8c\x10cache_size_limit\x94K@\x8c\x14specialize_int_float\x94\x88\x8c\x0edynamic_shapes\x94\x89\x8c\x10guard_nn_modules\x94\x89\x8c\x0cnormalize_ir\x94\x89\x8c\x1btraceable_tensor_subclasses\x94\x8f\x94\x8c\x0fsuppress_errors\x94\x89\x8c\x15replay_record_enabled\x94\x89\x8c rewrite_assert_with_torch_assert\x94\x88\x8c\x12print_graph_breaks\x94\x89\x8c\x07disable\x94\x89\x8c*allowed_functions_module_string_ignorelist\x94\x8f\x94(\x8c\rtorch._decomp\x94\x8c\rtorch.testing\x94\x8c\x13torch.distributions\x94\x8c\x0btorch._refs\x94\x8c\x0ctorch._prims\x94\x90\x8c\x16capture_scalar_outputs\x94\x89\x8c\x19enforce_cond_guards_match\x94\x88\x8c\x0coptimize_ddp\x94\x88\x8c\x1araise_on_ctx_manager_usage\x94\x88\x8c\x1craise_on_unsafe_aot_autograd\x94\x89\x8c\rdynamo_import\x94\x8c\rtorch._dynamo\x94\x8c\x0finductor_import\x94\x8c\x0ftorch._inductor\x94\x8c\x18error_on_nested_fx_trace\x94\x88\x8c\x08base_dir\x94\x8cK/nvme/data/zhaoqian/miniconda3/envs/dynamo-test/lib/python3.8/site-packages\x94\x8c\x0edebug_dir_root\x94\x8cH/nvme/data/zhaoqian/projects/Dynamo/mmclassification/torch_compile_debug\x94\x8c)DO_NOT_USE_legacy_non_fake_example_inputs\x94\x89\x8c\x15_AccessLimitingConfig\x94}\x94(\x8c\n__module__\x94h\x02\x8c\x0b__setattr__\x94h\x02\x8c!_AccessLimitingConfig.__setattr__\x94\x93\x94h\x03Nu\x8c\x15_allowed_config_names\x94\x8f\x94(hIh\x0fh\x03h$h+h\'\x8c\x02os\x94\x8c\x12constant_functions\x94h8\x8c\x0c__builtins__\x94h2hJh5hD\x8c\x03sys\x94h)h,\x8c\x0eexternal_utils\x94h\x1f\x8c!skipfiles_inline_module_allowlist\x94\x8c\x0brepro_level\x94hOh\x04\x8c\x07logging\x94h(h*h-h/hCh0hAhBhNh@h\x1eh6h\x06h4h&h9hLhGh.h%h"h1\x8c\x05torch\x94h\x01\x8c\nModuleType\x94h7h\x1d\x8c\x0brepro_after\x94hE\x90\x8c\x1cget_config_serialization_fns\x94\x8c\x1atorch._dynamo.config_utils\x94hb\x93\x94u.')
torch._inductor.config.load_config(b'\x80\x04\x95X\t\x00\x00\x00\x00\x00\x00}\x94(\x8c\x08__name__\x94\x8c\x16torch._inductor.config\x94\x8c\x07__doc__\x94N\x8c\x0b__package__\x94\x8c\x0ftorch._inductor\x94\x8c\n__loader__\x94\x8c\x1a_frozen_importlib_external\x94\x8c\x10SourceFileLoader\x94\x93\x94)\x81\x94}\x94(\x8c\x04name\x94h\x02\x8c\x04path\x94\x8ce/nvme/data/zhaoqian/miniconda3/envs/dynamo-test/lib/python3.8/site-packages/torch/_inductor/config.py\x94ub\x8c\x08__spec__\x94\x8c\x11_frozen_importlib\x94\x8c\nModuleSpec\x94\x93\x94)\x81\x94}\x94(h\x0ch\x02\x8c\x06loader\x94h\n\x8c\x06origin\x94h\x0e\x8c\x0cloader_state\x94N\x8c\x1asubmodule_search_locations\x94N\x8c\r_set_fileattr\x94\x88\x8c\x07_cached\x94\x8c}/nvme/data/zhaoqian/miniconda3/envs/dynamo-test/lib/python3.8/site-packages/torch/_inductor/__pycache__/config.cpython-38.pyc\x94\x8c\r_initializing\x94\x89ub\x8c\x08__file__\x94h\x0e\x8c\n__cached__\x94h\x1b\x8c\x05debug\x94\x89\x8c\x10disable_progress\x94\x88\x8c\x10verbose_progress\x94\x89\x8c\x0bcpp_wrapper\x94\x89\x8c\x03dce\x94\x89\x8c\x0edynamic_shapes\x94\x89\x8c\x14static_weight_shapes\x94\x88\x8c\x0csize_asserts\x94\x88\x8c\x10pick_loop_orders\x94\x88\x8c\x0finplace_buffers\x94\x88\x8c\x11benchmark_harness\x94\x88\x8c\x17realize_reads_threshold\x94K\x04\x8c\x17realize_bytes_threshold\x94M\xd0\x07\x8c\x1brealize_acc_reads_threshold\x94K\x08\x8c\x0ffallback_random\x94\x89\x8c\x12implicit_fallbacks\x94\x88\x8c\rprefuse_nodes\x94\x88\x8c\x0btune_layout\x94\x89\x8c\x11aggressive_fusion\x94\x89\x8c\x0fmax_fusion_size\x94K@\x8c\x1bunroll_reductions_threshold\x94K\x08\x8c\x0ecomment_origin\x94\x89\x8c\tis_fbcode\x94h\x02h5\x93\x94\x8c\x0fcompile_threads\x94K \x8c\x13kernel_name_max_ops\x94K\n\x8c\x0finductor_import\x94\x8c\x0ftorch._inductor\x94\x8c\rdynamo_import\x94\x8c\rtorch._dynamo\x94\x8c\rshape_padding\x94\x89\x8c\x0epermute_fusion\x94\x89\x8c\x1aprofiler_mark_wrapper_call\x94\x89\x8c\x03cpp\x94}\x94(\x8c\n__module__\x94h\x02\x8c\x07threads\x94J\xff\xff\xff\xff\x8c\x0fdynamic_threads\x94\x89\x8c\x07simdlen\x94N\x8c\x0emin_chunk_size\x94M\x00\x10\x8c\x03cxx\x94N\x8c\x03g++\x94\x86\x94\x8c\x15enable_kernel_profile\x94\x89h\x03Nu\x8c\x06triton\x94}\x94(hBh\x02\x8c\ncudagraphs\x94\x88\x8c\x10debug_sync_graph\x94\x89\x8c\x11debug_sync_kernel\x94\x89\x8c\x0bconvolution\x94\x8c\x04aten\x94\x8c\x02mm\x94hQ\x8c\x0edense_indexing\x94\x89\x8c\tmax_tiles\x94K\x02\x8c\x08autotune\x94\x88\x8c\x07use_bmm\x94\x89\x8c tiling_prevents_pointwise_fusion\x94\x88\x8c tiling_prevents_reduction_fusion\x94\x88\x8c\x14ordered_kernel_names\x94\x89\x8c\x18descriptive_kernel_names\x94\x88h\x03Nu\x8c\x05trace\x94}\x94(hBh\x02\x8c\x07enabled\x94\x89\x8c\tdebug_log\x94\x88\x8c\x08info_log\x94\x89\x8c\x08fx_graph\x94\x88\x8c\rir_pre_fusion\x94\x88\x8c\x0eir_post_fusion\x94\x88\x8c\x0boutput_code\x94\x88\x8c\rgraph_diagram\x94\x89\x8c\x0fcompile_profile\x94\x89\x8c\nupload_tar\x94Nh\x03Nu\x8c\x15InductorConfigContext\x94}\x94(hBh\x02\x8c\x0f__annotations__\x94}\x94(\x8c\rstatic_memory\x94\x8c\x08builtins\x94\x8c\x04bool\x94\x93\x94\x8c\x0bmatmul_tune\x94hl\x8c\x03str\x94\x93\x94\x8c\x0ematmul_padding\x94hn\x8c\x0ftriton_autotune\x94hn\x8c\ntriton_bmm\x94hn\x8c\ttriton_mm\x94hq\x8c\x12triton_convolution\x94hq\x8c\x17rematerialize_threshold\x94hl\x8c\x03int\x94\x93\x94\x8c\x1brematerialize_acc_threshold\x94hyu\x8c\x05_save\x94h\x02\x8c\x1bInductorConfigContext._save\x94\x93\x94\x8c\x06_apply\x94h\x02\x8c\x1cInductorConfigContext._apply\x94\x93\x94\x8c\x08__init__\x94h\x02\x8c\x1eInductorConfigContext.__init__\x94\x93\x94\x8c\t__enter__\x94h\x02\x8c\x1fInductorConfigContext.__enter__\x94\x93\x94\x8c\x08__exit__\x94h\x02\x8c\x1eInductorConfigContext.__exit__\x94\x93\x94h\x03Nu\x8c\x1cget_config_serialization_fns\x94\x8c\x1atorch._dynamo.config_utils\x94h\x8a\x93\x94u.')

# REPLACEABLE COMMENT FOR TESTING PURPOSES

# torch version: 2.0.0.dev20230109+cu116
# torch cuda version: 11.6
# torch git version: 65ff52a53b468266f7fdd93069ff04532f4d002d

# CUDA Info:
# nvcc: NVIDIA (R) Cuda compiler driver
# Copyright (c) 2005-2022 NVIDIA Corporation
# Built on Tue_Mar__8_18:18:20_PST_2022
# Cuda compilation tools, release 11.6, V11.6.124
# Build cuda_11.6.r11.6/compiler.31057947_0

# GPU Hardware Info:
# NVIDIA A100-SXM4-80GB : 8

from torch.nn import *

class Repro(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, tangents_1):
        div_1 = torch.ops.aten.div.Tensor(tangents_1, 1);  tangents_1 = None
        return [div_1]

args = [((), (), torch.float32, 'cuda')]
args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args]
mod = make_fx(Repro())(*args)

from functools import partial
from torch._dynamo.debug_utils import (
    isolate_fails,
    dump_compiler_graph_state,
)
from functorch.compile import minifier

env_variables = {"CUDA_VISIBLE_DEVICES": "1"}

minifier(
    mod,
    args,
    module_fails=partial(isolate_fails, env=env_variables, compiler_name="inductor", patch_code=isolate_fails_code_str),
    dump_state=partial(dump_compiler_graph_state, compiler_name="inductor"),
)
```
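
For reference, `TORCHDYNAMO_REPRO_AFTER` maps onto a dynamo config knob; the `repro_after` key is visible in the pickled config above. A minimal sketch of triggering the same dump from code, assuming a graph that actually fails after AOTAutograd:

```python
import torch
import torch._dynamo.config as dynamo_config

# Equivalent to TORCHDYNAMO_REPRO_AFTER='aot': when a compiled graph fails
# after AOTAutograd, dynamo dumps a minifier_launcher.py under
# debug_dir_root instead of only raising the error.
dynamo_config.repro_after = "aot"

def fn(x):
    return x / 3

out = torch.compile(fn)(torch.randn(4, device="cuda", requires_grad=True))
out.sum().backward()  # in this issue, the backward is what fails
```

The `Input graph did not fail the tester` error then simply means the extracted graph passes when re-run in isolation, so the minifier has nothing to shrink, which is consistent with the standalone temporary file also running cleanly.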

And the temporary file is shown below:

Temporary file:

```python
from ctypes import c_void_p, c_long
import torch
import random
from torch import empty_strided, as_strided, device
from torch._inductor.codecache import AsyncCompile

aten = torch.ops.aten
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
async_compile = AsyncCompile()

import triton
import triton.language as tl
from torch._inductor.triton_ops.autotune import grid
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream

triton_fused_div_1_0 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_ops.autotune import pointwise
from torch._inductor.utils import instance_descriptor

@pointwise(size_hints=[1], filename=__file__, meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [instance_descriptor(divisible_by_16=(0, 1), equal_to_1=())]})
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 1
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    tmp0 = tl.load(in_ptr0 + (0 + tl.zeros([XBLOCK], tl.int32)), None)
    tmp1 = 1
    tmp2 = tmp0 / tmp1
    tl.store(out_ptr0 + (0 + tl.zeros([XBLOCK], tl.int32)), tmp2, None)
''')

async_compile.wait(globals())
del async_compile

def call(args):
    tangents_1, = args
    args.clear()
    with torch.cuda.device(0):
        buf0 = empty_strided((), (), device='cuda', dtype=torch.float32)
        stream0 = get_cuda_stream(0)
        triton_fused_div_1_0.run(tangents_1, buf0, 1, grid=grid(1), stream=stream0)
        del tangents_1
        return (buf0, )

if __name__ == "__main__":
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    tangents_1 = rand_strided((), (), device='cuda:0', dtype=torch.float32)
    print_performance(lambda: call([tangents_1]))
```

Hope these files help.
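
For readers unfamiliar with inductor's output: the generated `triton_` kernel above is a single element-wise division of a 0-d tensor. A rough eager-mode equivalent of what `call` computes (my paraphrase, not generated code):

```python
import torch

def call_reference(tangents_1: torch.Tensor) -> torch.Tensor:
    # The kernel loads one fp32 value, divides it by 1, and stores the
    # result into a freshly allocated 0-d CUDA tensor.
    return tangents_1 / 1
```

The arithmetic is trivial, which suggests the failure is about where the kernel launch happens rather than what it computes.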

ngimel commented 1 year ago

Easy repro:

```python
import torch
import torch._inductor.config as config

def fn(x):
    return x / 3

opt_fn = torch.compile(fn)
x = torch.randn(4, device="cuda", requires_grad=True)
gO = torch.rand_like(x)
out = opt_fn(x)
out.backward(gO)
```
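
Why does this only surface in backward? Autograd executes backward nodes on its own worker threads, so the generated Triton kernel is launched from a thread that may not have a current CUDA context. The sketch below exercises that thread pattern directly; treat it as an illustration of the suspected mechanism, not a confirmed diagnosis from this thread.

```python
import threading
import torch

def fn(x):
    return x / 3

opt_fn = torch.compile(fn)
x = torch.randn(4, device="cuda")
opt_fn(x)  # compile and warm up on the main thread

# Call the already-compiled function from a fresh thread, mimicking the
# autograd worker thread that launches the backward kernel. With the bug
# present, a raw Triton launch on a thread without a current CUDA context
# could fail with "invalid device context".
t = threading.Thread(target=opt_fn, args=(x,))
t.start()
t.join()
```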
ngimel commented 1 year ago

Can you try patching https://github.com/pytorch/pytorch/pull/92055?

C1rN09 commented 1 year ago

> Can you try patching pytorch/pytorch#92055?

Yes, the patch solves this issue :smile: