Hi, I am trying to compile the ResNet-34 model with PyTorch 2.0's `max-autotune` mode. However, it fails with the following error:
Code
# Benchmark a torch.compile'd ResNet-34 (inductor backend, max-autotune mode) on a GPU.
#
# NOTE(review): the `KeyError: 'BLOCK_K'` raised while lowering `aten.mm` comes from
# inductor's legacy Triton matmul autotuner (torch/_inductor/codegen/autotuner.py ->
# triton heuristics). It presumably reflects a torch/triton version mismatch in that
# build rather than a problem in this script. Upgrading to a PyTorch 2.0+ release with
# its bundled triton is the likely fix (TODO confirm). As a stop-gap, setting
# `torch._dynamo.config.suppress_errors = True` falls back to eager execution.
import time

import torch

device = torch.device("cuda:0")

# torchvision's hub entry point at v0.10.0 still uses the legacy `pretrained` flag.
model = torch.hub.load("pytorch/vision:v0.10.0", "resnet34", pretrained=True).to(device)
# This is an inference benchmark: eval mode makes BatchNorm use its running stats.
model.eval()

# Allocate the random batch directly on the GPU instead of creating it on the CPU and copying.
x = torch.randn(512, 3, 224, 224, device=device)

opt_model_max = torch.compile(model, mode="max-autotune", backend="inductor")

# Without no_grad, the compiled graph includes autograd bookkeeping
# (the traceback shows aot_dispatch_autograd / primals_* inputs).
with torch.no_grad():
    for i in range(10):
        # CUDA launches are asynchronous, so synchronize before and after the call.
        # This makes the wall-clock interval cover the GPU work itself, replacing the
        # original sleep(1)-then-subtract-1 approximation.
        torch.cuda.synchronize(device)
        begin = time.perf_counter()
        output = opt_model_max(x)
        torch.cuda.synchronize(device)
        end = time.perf_counter()
        # Iteration 0 also includes compilation and autotuning time.
        print(f"Time for {i}-th forward pass is {end - begin}")
Error
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
File /opt/conda/lib/python3.10/site-packages/torch/_inductor/graph.py:295, in GraphLowering.call_function(self, target, args, kwargs)
294 try:
--> 295 out = lowerings[target](*args, **kwargs)
296 return out
File /opt/conda/lib/python3.10/site-packages/torch/_inductor/lowering.py:222, in _register_lowering.<locals>.wrapped(*args, **kwargs)
218 args[i] = ExpandView.create(
219 args[i], list(args[indices[0]].get_size())
220 )
--> 222 return decomp_fn(*args, **kwargs)
File /opt/conda/lib/python3.10/site-packages/torch/_inductor/lowering.py:842, in mm(a, b)
840 @register_lowering(aten.mm)
841 def mm(a: TensorBox, b: TensorBox):
--> 842 return TensorBox.create(ir.MatrixMultiply.create(a, b))
File /opt/conda/lib/python3.10/site-packages/torch/_inductor/ir.py:2803, in MatrixMultiply.create(cls, a, b)
2801 from .codegen.autotuner import tuned_mm
-> 2803 kernel = tuned_mm(
2804 a.get_size(),
2805 b.get_size(),
2806 a.get_stride(),
2807 b.get_stride(),
2808 a.get_device(),
2809 a.get_dtype(),
2810 )
2812 return MatrixMultiply(
2813 layout=FlexibleLayout(
2814 device=a.get_device(),
(...)
2819 kernel=kernel,
2820 )
File /opt/conda/lib/python3.10/site-packages/torch/_inductor/codegen/autotuner.py:202, in tuned_mm(a_shape, b_shape, a_stride, b_stride, device, dtype, adjust_triton)
201 run_kwargs = {"out": c}
--> 202 timing, _, _ = autotune._bench(runnable_kernel, *run_args, **run_kwargs)
203 if "triton_ops" in kernel:
File /opt/conda/lib/python3.10/site-packages/torch/_inductor/codegen/autotuner.py:40, in Autotuner._bench(self, kernel, *args, **kwargs)
38 from triton.testing import do_bench
---> 40 return do_bench(kernel_call)
File /opt/conda/lib/python3.10/site-packages/triton/testing.py:140, in do_bench(fn, warmup, rep, grad_to_none, percentiles, record_clocks, fast_flush)
139 # Estimate the runtime of the function
--> 140 fn()
141 torch.cuda.synchronize()
File /opt/conda/lib/python3.10/site-packages/torch/_inductor/codegen/autotuner.py:36, in Autotuner._bench.<locals>.kernel_call()
35 def kernel_call():
---> 36 kernel(*args, **kwargs)
File /opt/conda/lib/python3.10/site-packages/torch/_inductor/triton_ops/matmul.py:134, in _matmul_out.forward(a, b, out, allow_tf32)
132 @staticmethod
133 def forward(a, b, out, allow_tf32=True):
--> 134 return _matmul_out._call(a, b, out, allow_tf32)
File /opt/conda/lib/python3.10/site-packages/torch/_inductor/triton_ops/matmul.py:114, in _matmul_out._call(a, b, out, allow_tf32)
110 # grid = lambda META: (
111 # triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
112 # META["SPLIT_K"],
113 # )
--> 114 _kernel[grid](
115 a,
116 b,
117 c,
118 M,
119 N,
120 K,
121 a.stride(0),
122 a.stride(1),
123 b.stride(0),
124 b.stride(1),
125 c.stride(0),
126 c.stride(1),
127 allow_tf32=allow_tf32,
128 GROUP_M=8,
129 ACC_TYPE=ACC_TYPE,
130 )
File /opt/conda/lib/python3.10/site-packages/triton/runtime/jit.py:106, in KernelInterface.__getitem__.<locals>.launcher(*args, **kwargs)
105 def launcher(*args, **kwargs):
--> 106 return self.run(*args, grid=grid, **kwargs)
File /opt/conda/lib/python3.10/site-packages/triton/runtime/autotuner.py:199, in Heuristics.run(self, *args, **kwargs)
198 for v, heur in self.values.items():
--> 199 kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs})
200 return self.fn.run(*args, **kwargs)
File /opt/conda/lib/python3.10/site-packages/torch/_inductor/triton_ops/autotune.py:560, in mm_heuristics.<locals>.<lambda>(args)
556 from triton import heuristics
558 mm_heuristic = heuristics(
559 {
--> 560 "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
561 }
562 )
563 return mm_heuristic
KeyError: 'BLOCK_K'
The above exception was the direct cause of the following exception:
LoweringException Traceback (most recent call last)
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py:676, in OutputGraph.call_user_compiler(self, gm)
675 else:
--> 676 compiled_fn = compiler_fn(gm, self.fake_example_inputs())
677 _step_logger()(logging.INFO, f"done compiler function {name}")
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/debug_utils.py:945, in wrap_backend_debug.<locals>.debug_wrapper(gm, example_inputs, **kwargs)
944 else:
--> 945 compiled_gm = compiler_fn(gm, example_inputs, **kwargs)
947 return compiled_gm
File /opt/conda/lib/python3.10/site-packages/torch/__init__.py:1153, in _TorchCompileInductorWrapper.__call__(self, model_, inputs_)
1152 with self.cm:
-> 1153 return self.compile_fn(model_, inputs_)
File /opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py:398, in compile_fx(model_, example_inputs_, inner_compile)
393 with overrides.patch_functions():
394
395 # TODO: can add logging before/after the call to create_aot_dispatcher_function
396 # in torch._functorch/aot_autograd.py::aot_module_simplified::aot_function_simplified::new_func
397 # once torchdynamo is merged into pytorch
--> 398 return aot_autograd(
399 fw_compiler=fw_compiler,
400 bw_compiler=bw_compiler,
401 decompositions=select_decomp_table(),
402 partition_fn=functools.partial(
403 min_cut_rematerialization_partition, compiler="inductor"
404 ),
405 )(model_, example_inputs_)
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/optimizations/training.py:78, in aot_autograd.<locals>.compiler_fn(gm, example_inputs)
77 with enable_aot_logging():
---> 78 cg = aot_module_simplified(gm, example_inputs, **kwargs)
79 counters["aot_autograd"]["ok"] += 1
File /opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:2353, in aot_module_simplified(mod, args, fw_compiler, bw_compiler, partition_fn, decompositions, hasher_type, static_argnums)
2351 full_args.extend(args)
-> 2353 compiled_fn = create_aot_dispatcher_function(
2354 functional_call,
2355 full_args,
2356 aot_config,
2357 )
2359 # TODO: There is something deeply wrong here; compiled_fn running with
2360 # the boxed calling convention, but aot_module_simplified somehow
2361 # historically returned a function that was not the boxed calling
2362 # convention. This should get fixed...
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py:90, in dynamo_timed.<locals>.time_wrapper(*args, **kwargs)
89 t0 = time.time()
---> 90 r = func(*args, **kwargs)
91 latency = time.time() - t0
File /opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:2050, in create_aot_dispatcher_function(flat_fn, flat_args, aot_config)
2048 # You can put more passes here
-> 2050 compiled_fn = compiler_fn(flat_fn, fake_flat_tensor_args, aot_config)
2052 if not hasattr(compiled_fn, '_boxed_call'):
File /opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:1305, in aot_wrapper_dedupe(flat_fn, flat_args, aot_config, compiler_fn)
1304 if ok:
-> 1305 return compiler_fn(flat_fn, leaf_flat_args, aot_config)
1307 # Strategy 2: Duplicate specialize.
1308 #
1309 # In Haskell types, suppose you have:
(...)
1341 # }
1342 # keep_arg_mask = [True, True, False, True]
File /opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:1564, in aot_dispatch_autograd(flat_fn, flat_args, aot_config)
1563 with track_graph_compiling(aot_config, "forward"):
-> 1564 compiled_fw_func = aot_config.fw_compiler(
1565 fw_module, flat_args_with_views_handled
1566 )
1568 class CompiledFunction(torch.autograd.Function):
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py:90, in dynamo_timed.<locals>.time_wrapper(*args, **kwargs)
89 t0 = time.time()
---> 90 r = func(*args, **kwargs)
91 latency = time.time() - t0
File /opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py:373, in compile_fx.<locals>.fw_compiler(model, example_inputs)
372 fixed = len(example_inputs) - num_example_inputs
--> 373 return inner_compile(
374 model,
375 example_inputs,
376 num_fixed=fixed,
377 cudagraphs=cudagraphs,
378 graph_id=graph_id,
379 )
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/debug_utils.py:507, in wrap_compiler_debug.<locals>.debug_wrapper(gm, example_inputs, **kwargs)
506 else:
--> 507 compiled_fn = compiler_fn(gm, example_inputs, **kwargs)
509 return compiled_fn
File /opt/conda/lib/python3.10/site-packages/torch/_inductor/debug.py:223, in DebugContext.wrap.<locals>.inner(*args, **kwargs)
222 with DebugContext():
--> 223 return fn(*args, **kwargs)
File /opt/conda/lib/python3.10/contextlib.py:79, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
78 with self._recreate_cm():
---> 79 return func(*args, **kwds)
File /opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py:139, in compile_fx_inner(gm, example_inputs, cudagraphs, num_fixed, is_backward, graph_id)
138 with V.set_graph_handler(graph):
--> 139 graph.run(*example_inputs)
140 compiled_fn = graph.compile_to_fn()
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py:90, in dynamo_timed.<locals>.time_wrapper(*args, **kwargs)
89 t0 = time.time()
---> 90 r = func(*args, **kwargs)
91 latency = time.time() - t0
File /opt/conda/lib/python3.10/site-packages/torch/_inductor/graph.py:169, in GraphLowering.run(self, *args)
167 @dynamo_utils.dynamo_timed
168 def run(self, *args):
--> 169 return super().run(*args)
File /opt/conda/lib/python3.10/site-packages/torch/fx/interpreter.py:130, in Interpreter.run(self, initial_env, enable_io_processing, *args)
129 try:
--> 130 self.env[node] = self.run_node(node)
131 except Exception as e:
File /opt/conda/lib/python3.10/site-packages/torch/_inductor/graph.py:368, in GraphLowering.run_node(self, n)
367 else:
--> 368 result = super().run_node(n)
370 # Realize if (1) any user need inputs realized, or (2) there is
371 # already too many reads and rematerializing can be bad.
File /opt/conda/lib/python3.10/site-packages/torch/fx/interpreter.py:171, in Interpreter.run_node(self, n)
170 assert isinstance(kwargs, dict)
--> 171 return getattr(self, n.op)(n.target, args, kwargs)
File /opt/conda/lib/python3.10/site-packages/torch/_inductor/graph.py:298, in GraphLowering.call_function(self, target, args, kwargs)
297 except Exception as e:
--> 298 raise LoweringException(e, target, args, kwargs) from e
LoweringException: KeyError: 'BLOCK_K'
target: aten.mm.default
args[0]: TensorBox(StorageBox(
ComputedBuffer(name='buf240', layout=FixedLayout('cuda', torch.float32, size=(512, 512), stride=[512, 1]), data=Pointwise(
'cuda',
torch.float32,
tmp0 = load(buf239, i1 + 512 * i0)
tmp1 = index_expr(49, torch.float32)
tmp2 = tmp0 / tmp1
return tmp2
,
ranges=(512, 512),
origins={view}
))
))
args[1]: TensorBox(
ReinterpretView(
StorageBox(
InputBuffer(name='primals_109', layout=FixedLayout('cuda', torch.float32, size=[1000, 512], stride=[512, 1]))
),
FixedLayout('cuda', torch.float32, size=[512, 1000], stride=[1, 512]),
no origins?
)
)
While executing %mm : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%view, %permute), kwargs = {})
Original traceback:
Module stack: {'self_fc': "<class 'torch.nn.modules.linear.Linear'>"}
File "/root/.cache/torch/hub/pytorch_vision_v0.10.0/torchvision/models/resnet.py", line 244, in _forward_impl
x = self.fc(x)
| File "/root/.cache/torch/hub/pytorch_vision_v0.10.0/torchvision/models/resnet.py", line 249, in forward
return self._forward_impl(x)
Gradient addition node due to multiple use of tensor around:
Module stack: {'self_conv1': "<class 'torch.nn.modules.conv.Conv2d'>"}
File "/root/.cache/torch/hub/pytorch_vision_v0.10.0/torchvision/models/resnet.py", line 232, in _forward_impl
x = self.conv1(x)
| File "/root/.cache/torch/hub/pytorch_vision_v0.10.0/torchvision/models/resnet.py", line 249, in forward
return self._forward_impl(x)
The above exception was the direct cause of the following exception:
BackendCompilerFailed Traceback (most recent call last)
Cell In[5], line 11
9 for i in range(10):
10 begin = time.time()
---> 11 output = opt_model_max(x)
12 time.sleep(1)
13 end = time.time()
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1482, in Module._call_impl(self, *args, **kwargs)
1477 # If we don't have any hooks, we want to skip the rest of the logic in
1478 # this function, and just call forward.
1479 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1480 or _global_backward_pre_hooks or _global_backward_hooks
1481 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1482 return forward_call(*args, **kwargs)
1483 # Do not call functions when jit is used
1484 full_backward_hooks, non_full_backward_hooks = [], []
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:83, in OptimizedModule.forward(self, *args, **kwargs)
82 def forward(self, *args, **kwargs):
---> 83 return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:212, in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
210 dynamic_ctx.__enter__()
211 try:
--> 212 return fn(*args, **kwargs)
213 finally:
214 set_eval_frame(prior)
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:333, in catch_errors_wrapper.<locals>.catch_errors(frame, cache_size)
330 return hijacked_callback(frame, cache_size, hooks)
332 with compile_lock:
--> 333 return callback(frame, cache_size, hooks)
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:480, in convert_frame.<locals>._convert_frame(frame, cache_size, hooks)
478 counters["frames"]["total"] += 1
479 try:
--> 480 result = inner_convert(frame, cache_size, hooks)
481 counters["frames"]["ok"] += 1
482 return result
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:103, in wrap_convert_context.<locals>._fn(*args, **kwargs)
101 torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
102 try:
--> 103 return fn(*args, **kwargs)
104 finally:
105 torch._C._set_grad_enabled(prior_grad_mode)
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py:90, in dynamo_timed.<locals>.time_wrapper(*args, **kwargs)
88 compilation_metrics[key] = []
89 t0 = time.time()
---> 90 r = func(*args, **kwargs)
91 latency = time.time() - t0
92 # print(f"Dynamo timer: key={key}, latency={latency:.2f} sec")
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:339, in convert_frame_assert.<locals>._convert_frame_assert(frame, cache_size, hooks)
336 global initial_grad_state
337 initial_grad_state = torch.is_grad_enabled()
--> 339 return _compile(
340 frame.f_code,
341 frame.f_globals,
342 frame.f_locals,
343 frame.f_builtins,
344 compiler_fn,
345 one_graph,
346 export,
347 hooks,
348 frame,
349 )
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:400, in _compile(code, globals, locals, builtins, compiler_fn, one_graph, export, hooks, frame)
398 for attempt in itertools.count():
399 try:
--> 400 out_code = transform_code_object(code, transform)
401 orig_code_map[out_code] = code
402 break
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py:341, in transform_code_object(code, transformations, safe)
338 instructions = cleaned_instructions(code, safe)
339 propagate_line_nums(instructions)
--> 341 transformations(instructions, code_options)
343 fix_vars(instructions, code_options)
345 dirty = True
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:387, in _compile.<locals>.transform(instructions, code_options)
374 nonlocal output
375 tracer = InstructionTranslator(
376 instructions,
377 code,
(...)
385 mutated_closure_cell_contents,
386 )
--> 387 tracer.run()
388 output = tracer.output
389 assert output is not None
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:1684, in InstructionTranslator.run(self)
1682 def run(self):
1683 _step_logger()(logging.INFO, f"torchdynamo start tracing {self.f_code.co_name}")
-> 1684 super().run()
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:538, in InstructionTranslatorBase.run(self)
533 try:
534 self.output.push_tx(self)
535 while (
536 self.instruction_pointer is not None
537 and not self.output.should_exit
--> 538 and self.step()
539 ):
540 pass
541 except BackendCompilerFailed:
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:501, in InstructionTranslatorBase.step(self)
499 if not hasattr(self, inst.opname):
500 unimplemented(f"missing: {inst.opname}")
--> 501 getattr(self, inst.opname)(inst)
503 return inst.opname != "RETURN_VALUE"
504 except BackendCompilerFailed:
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:1750, in InstructionTranslator.RETURN_VALUE(self, inst)
1745 _step_logger()(
1746 logging.INFO,
1747 f"torchdynamo done tracing {self.f_code.co_name} (RETURN_VALUE)",
1748 )
1749 log.debug("RETURN_VALUE triggered compile")
-> 1750 self.output.compile_subgraph(self)
1751 self.output.add_output_instructions([create_instruction("RETURN_VALUE")])
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py:529, in OutputGraph.compile_subgraph(self, tx, partial_convert, reason)
512 self.add_output_instructions(random_calls_instructions)
514 if (
515 stack_values
516 and all(
(...)
526
527 # optimization to generate better code in a common case
528 self.add_output_instructions(
--> 529 self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
530 + [create_instruction("UNPACK_SEQUENCE", len(stack_values))]
531 )
532 else:
533 graph_output_var = self.new_var("graph_out")
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py:600, in OutputGraph.compile_and_call_fx_graph(self, tx, rv, root)
598 assert_no_fake_params_or_buffers(gm)
599 with tracing(self.tracing_context):
--> 600 compiled_fn = self.call_user_compiler(gm)
601 compiled_fn = disable(compiled_fn)
603 counters["stats"]["unique_graphs"] += 1
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py:681, in OutputGraph.call_user_compiler(self, gm)
679 except Exception as e:
680 compiled_fn = gm.forward
--> 681 raise BackendCompilerFailed(self.compiler_fn, e) from e
682 return compiled_fn
BackendCompilerFailed: debug_wrapper raised LoweringException: KeyError: 'BLOCK_K'
target: aten.mm.default
args[0]: TensorBox(StorageBox(
ComputedBuffer(name='buf240', layout=FixedLayout('cuda', torch.float32, size=(512, 512), stride=[512, 1]), data=Pointwise(
'cuda',
torch.float32,
tmp0 = load(buf239, i1 + 512 * i0)
tmp1 = index_expr(49, torch.float32)
tmp2 = tmp0 / tmp1
return tmp2
,
ranges=(512, 512),
origins={view}
))
))
args[1]: TensorBox(
ReinterpretView(
StorageBox(
InputBuffer(name='primals_109', layout=FixedLayout('cuda', torch.float32, size=[1000, 512], stride=[512, 1]))
),
FixedLayout('cuda', torch.float32, size=[512, 1000], stride=[1, 512]),
no origins?
)
)
While executing %mm : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%view, %permute), kwargs = {})
Original traceback:
Module stack: {'self_fc': "<class 'torch.nn.modules.linear.Linear'>"}
File "/root/.cache/torch/hub/pytorch_vision_v0.10.0/torchvision/models/resnet.py", line 244, in _forward_impl
x = self.fc(x)
| File "/root/.cache/torch/hub/pytorch_vision_v0.10.0/torchvision/models/resnet.py", line 249, in forward
return self._forward_impl(x)
Gradient addition node due to multiple use of tensor around:
Module stack: {'self_conv1': "<class 'torch.nn.modules.conv.Conv2d'>"}
File "/root/.cache/torch/hub/pytorch_vision_v0.10.0/torchvision/models/resnet.py", line 232, in _forward_impl
x = self.conv1(x)
| File "/root/.cache/torch/hub/pytorch_vision_v0.10.0/torchvision/models/resnet.py", line 249, in forward
return self._forward_impl(x)
Set torch._dynamo.config.verbose=True for more information
You can suppress this exception and fall back to eager by setting:
torch._dynamo.config.suppress_errors = True
Hi, I am trying to compile the ResNet-34 model with PyTorch 2.0's `max-autotune` mode. However, it fails with the following error (code below):
Error