pytorch / torchdynamo

A Python-level JIT compiler designed to make unmodified PyTorch programs faster.
BSD 3-Clause "New" or "Revised" License
1.01k stars 124 forks source link

PT2.0 max-autotune mode fails for resnet34 #2022

Closed Puneet2000 closed 1 year ago

Puneet2000 commented 1 year ago

Hi, I am trying to compile the resnet34 model using PyTorch 2.0's max-autotune mode. However, it fails with the following error. Code:

import time

import torch

# Reproduction: compile torchvision's resnet34 with inductor in max-autotune
# mode and time a few forward passes. In the reported environment this fails
# while lowering the final `fc` layer's aten.mm (KeyError: 'BLOCK_K').
model = torch.hub.load(
    'pytorch/vision:v0.10.0', 'resnet34', pretrained=True
).to(device="cuda:0")
x = torch.randn(512, 3, 224, 224, device="cuda:0")  # Random batch, created directly on the GPU
opt_model_max = torch.compile(model, mode="max-autotune", backend="inductor")

for i in range(10):
    # CUDA kernels launch asynchronously, so synchronize before starting and
    # after stopping the clock; otherwise only host-side launch overhead is
    # measured. The previous `time.sleep(1)` / `end - begin - 1` approach never
    # included GPU execution time in the measurement.
    # Note: iteration 0 also includes compilation, since torch.compile compiles
    # lazily on the first call.
    torch.cuda.synchronize()
    begin = time.perf_counter()
    output = opt_model_max(x)
    torch.cuda.synchronize()
    end = time.perf_counter()
    print(f"Time for {i}-th forward pass is {end - begin}")

Error

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
File /opt/conda/lib/python3.10/site-packages/torch/_inductor/graph.py:295, in GraphLowering.call_function(self, target, args, kwargs)
    294 try:
--> 295     out = lowerings[target](*args, **kwargs)
    296     return out

File /opt/conda/lib/python3.10/site-packages/torch/_inductor/lowering.py:222, in _register_lowering.<locals>.wrapped(*args, **kwargs)
    218             args[i] = ExpandView.create(
    219                 args[i], list(args[indices[0]].get_size())
    220             )
--> 222 return decomp_fn(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/_inductor/lowering.py:842, in mm(a, b)
    840 @register_lowering(aten.mm)
    841 def mm(a: TensorBox, b: TensorBox):
--> 842     return TensorBox.create(ir.MatrixMultiply.create(a, b))

File /opt/conda/lib/python3.10/site-packages/torch/_inductor/ir.py:2803, in MatrixMultiply.create(cls, a, b)
   2801     from .codegen.autotuner import tuned_mm
-> 2803     kernel = tuned_mm(
   2804         a.get_size(),
   2805         b.get_size(),
   2806         a.get_stride(),
   2807         b.get_stride(),
   2808         a.get_device(),
   2809         a.get_dtype(),
   2810     )
   2812 return MatrixMultiply(
   2813     layout=FlexibleLayout(
   2814         device=a.get_device(),
   (...)
   2819     kernel=kernel,
   2820 )

File /opt/conda/lib/python3.10/site-packages/torch/_inductor/codegen/autotuner.py:202, in tuned_mm(a_shape, b_shape, a_stride, b_stride, device, dtype, adjust_triton)
    201     run_kwargs = {"out": c}
--> 202 timing, _, _ = autotune._bench(runnable_kernel, *run_args, **run_kwargs)
    203 if "triton_ops" in kernel:

File /opt/conda/lib/python3.10/site-packages/torch/_inductor/codegen/autotuner.py:40, in Autotuner._bench(self, kernel, *args, **kwargs)
     38 from triton.testing import do_bench
---> 40 return do_bench(kernel_call)

File /opt/conda/lib/python3.10/site-packages/triton/testing.py:140, in do_bench(fn, warmup, rep, grad_to_none, percentiles, record_clocks, fast_flush)
    139 # Estimate the runtime of the function
--> 140 fn()
    141 torch.cuda.synchronize()

File /opt/conda/lib/python3.10/site-packages/torch/_inductor/codegen/autotuner.py:36, in Autotuner._bench.<locals>.kernel_call()
     35 def kernel_call():
---> 36     kernel(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/_inductor/triton_ops/matmul.py:134, in _matmul_out.forward(a, b, out, allow_tf32)
    132 @staticmethod
    133 def forward(a, b, out, allow_tf32=True):
--> 134     return _matmul_out._call(a, b, out, allow_tf32)

File /opt/conda/lib/python3.10/site-packages/torch/_inductor/triton_ops/matmul.py:114, in _matmul_out._call(a, b, out, allow_tf32)
    110 # grid = lambda META: (
    111 #     triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
    112 #     META["SPLIT_K"],
    113 # )
--> 114 _kernel[grid](
    115     a,
    116     b,
    117     c,
    118     M,
    119     N,
    120     K,
    121     a.stride(0),
    122     a.stride(1),
    123     b.stride(0),
    124     b.stride(1),
    125     c.stride(0),
    126     c.stride(1),
    127     allow_tf32=allow_tf32,
    128     GROUP_M=8,
    129     ACC_TYPE=ACC_TYPE,
    130 )

File /opt/conda/lib/python3.10/site-packages/triton/runtime/jit.py:106, in KernelInterface.__getitem__.<locals>.launcher(*args, **kwargs)
    105 def launcher(*args, **kwargs):
--> 106     return self.run(*args, grid=grid, **kwargs)

File /opt/conda/lib/python3.10/site-packages/triton/runtime/autotuner.py:199, in Heuristics.run(self, *args, **kwargs)
    198 for v, heur in self.values.items():
--> 199     kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs})
    200 return self.fn.run(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/_inductor/triton_ops/autotune.py:560, in mm_heuristics.<locals>.<lambda>(args)
    556 from triton import heuristics
    558 mm_heuristic = heuristics(
    559     {
--> 560         "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
    561     }
    562 )
    563 return mm_heuristic

KeyError: 'BLOCK_K'

The above exception was the direct cause of the following exception:

LoweringException                         Traceback (most recent call last)
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py:676, in OutputGraph.call_user_compiler(self, gm)
    675 else:
--> 676     compiled_fn = compiler_fn(gm, self.fake_example_inputs())
    677 _step_logger()(logging.INFO, f"done compiler function {name}")

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/debug_utils.py:945, in wrap_backend_debug.<locals>.debug_wrapper(gm, example_inputs, **kwargs)
    944 else:
--> 945     compiled_gm = compiler_fn(gm, example_inputs, **kwargs)
    947 return compiled_gm

File /opt/conda/lib/python3.10/site-packages/torch/__init__.py:1153, in _TorchCompileInductorWrapper.__call__(self, model_, inputs_)
   1152 with self.cm:
-> 1153     return self.compile_fn(model_, inputs_)

File /opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py:398, in compile_fx(model_, example_inputs_, inner_compile)
    393 with overrides.patch_functions():
    394 
    395     # TODO: can add logging before/after the call to create_aot_dispatcher_function
    396     # in torch._functorch/aot_autograd.py::aot_module_simplified::aot_function_simplified::new_func
    397     # once torchdynamo is merged into pytorch
--> 398     return aot_autograd(
    399         fw_compiler=fw_compiler,
    400         bw_compiler=bw_compiler,
    401         decompositions=select_decomp_table(),
    402         partition_fn=functools.partial(
    403             min_cut_rematerialization_partition, compiler="inductor"
    404         ),
    405     )(model_, example_inputs_)

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/optimizations/training.py:78, in aot_autograd.<locals>.compiler_fn(gm, example_inputs)
     77 with enable_aot_logging():
---> 78     cg = aot_module_simplified(gm, example_inputs, **kwargs)
     79     counters["aot_autograd"]["ok"] += 1

File /opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:2353, in aot_module_simplified(mod, args, fw_compiler, bw_compiler, partition_fn, decompositions, hasher_type, static_argnums)
   2351 full_args.extend(args)
-> 2353 compiled_fn = create_aot_dispatcher_function(
   2354     functional_call,
   2355     full_args,
   2356     aot_config,
   2357 )
   2359 # TODO: There is something deeply wrong here; compiled_fn running with
   2360 # the boxed calling convention, but aot_module_simplified somehow
   2361 # historically returned a function that was not the boxed calling
   2362 # convention.  This should get fixed...

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py:90, in dynamo_timed.<locals>.time_wrapper(*args, **kwargs)
     89 t0 = time.time()
---> 90 r = func(*args, **kwargs)
     91 latency = time.time() - t0

File /opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:2050, in create_aot_dispatcher_function(flat_fn, flat_args, aot_config)
   2048 # You can put more passes here
-> 2050 compiled_fn = compiler_fn(flat_fn, fake_flat_tensor_args, aot_config)
   2052 if not hasattr(compiled_fn, '_boxed_call'):

File /opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:1305, in aot_wrapper_dedupe(flat_fn, flat_args, aot_config, compiler_fn)
   1304     if ok:
-> 1305         return compiler_fn(flat_fn, leaf_flat_args, aot_config)
   1307 # Strategy 2: Duplicate specialize.
   1308 #
   1309 # In Haskell types, suppose you have:
   (...)
   1341 #   }
   1342 #   keep_arg_mask = [True, True, False, True]

File /opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:1564, in aot_dispatch_autograd(flat_fn, flat_args, aot_config)
   1563     with track_graph_compiling(aot_config, "forward"):
-> 1564         compiled_fw_func = aot_config.fw_compiler(
   1565             fw_module, flat_args_with_views_handled
   1566         )
   1568 class CompiledFunction(torch.autograd.Function):

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py:90, in dynamo_timed.<locals>.time_wrapper(*args, **kwargs)
     89 t0 = time.time()
---> 90 r = func(*args, **kwargs)
     91 latency = time.time() - t0

File /opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py:373, in compile_fx.<locals>.fw_compiler(model, example_inputs)
    372 fixed = len(example_inputs) - num_example_inputs
--> 373 return inner_compile(
    374     model,
    375     example_inputs,
    376     num_fixed=fixed,
    377     cudagraphs=cudagraphs,
    378     graph_id=graph_id,
    379 )

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/debug_utils.py:507, in wrap_compiler_debug.<locals>.debug_wrapper(gm, example_inputs, **kwargs)
    506 else:
--> 507     compiled_fn = compiler_fn(gm, example_inputs, **kwargs)
    509 return compiled_fn

File /opt/conda/lib/python3.10/site-packages/torch/_inductor/debug.py:223, in DebugContext.wrap.<locals>.inner(*args, **kwargs)
    222 with DebugContext():
--> 223     return fn(*args, **kwargs)

File /opt/conda/lib/python3.10/contextlib.py:79, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
     78 with self._recreate_cm():
---> 79     return func(*args, **kwds)

File /opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py:139, in compile_fx_inner(gm, example_inputs, cudagraphs, num_fixed, is_backward, graph_id)
    138 with V.set_graph_handler(graph):
--> 139     graph.run(*example_inputs)
    140     compiled_fn = graph.compile_to_fn()

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py:90, in dynamo_timed.<locals>.time_wrapper(*args, **kwargs)
     89 t0 = time.time()
---> 90 r = func(*args, **kwargs)
     91 latency = time.time() - t0

File /opt/conda/lib/python3.10/site-packages/torch/_inductor/graph.py:169, in GraphLowering.run(self, *args)
    167 @dynamo_utils.dynamo_timed
    168 def run(self, *args):
--> 169     return super().run(*args)

File /opt/conda/lib/python3.10/site-packages/torch/fx/interpreter.py:130, in Interpreter.run(self, initial_env, enable_io_processing, *args)
    129 try:
--> 130     self.env[node] = self.run_node(node)
    131 except Exception as e:

File /opt/conda/lib/python3.10/site-packages/torch/_inductor/graph.py:368, in GraphLowering.run_node(self, n)
    367 else:
--> 368     result = super().run_node(n)
    370 # Realize if (1) any user need inputs realized, or (2) there is
    371 # already too many reads and rematerializing can be bad.

File /opt/conda/lib/python3.10/site-packages/torch/fx/interpreter.py:171, in Interpreter.run_node(self, n)
    170 assert isinstance(kwargs, dict)
--> 171 return getattr(self, n.op)(n.target, args, kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/_inductor/graph.py:298, in GraphLowering.call_function(self, target, args, kwargs)
    297 except Exception as e:
--> 298     raise LoweringException(e, target, args, kwargs) from e

LoweringException: KeyError: 'BLOCK_K'
  target: aten.mm.default
  args[0]: TensorBox(StorageBox(
    ComputedBuffer(name='buf240', layout=FixedLayout('cuda', torch.float32, size=(512, 512), stride=[512, 1]), data=Pointwise(
      'cuda',
      torch.float32,
      tmp0 = load(buf239, i1 + 512 * i0)
      tmp1 = index_expr(49, torch.float32)
      tmp2 = tmp0 / tmp1
      return tmp2
      ,
      ranges=(512, 512),
      origins={view}
    ))
  ))
  args[1]: TensorBox(
    ReinterpretView(
      StorageBox(
        InputBuffer(name='primals_109', layout=FixedLayout('cuda', torch.float32, size=[1000, 512], stride=[512, 1]))
      ),
      FixedLayout('cuda', torch.float32, size=[512, 1000], stride=[1, 512]),
      no origins?
    )
  )

While executing %mm : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%view, %permute), kwargs = {})
Original traceback:
Module stack: {'self_fc': "<class 'torch.nn.modules.linear.Linear'>"}
  File "/root/.cache/torch/hub/pytorch_vision_v0.10.0/torchvision/models/resnet.py", line 244, in _forward_impl
    x = self.fc(x)
 |   File "/root/.cache/torch/hub/pytorch_vision_v0.10.0/torchvision/models/resnet.py", line 249, in forward
    return self._forward_impl(x)

Gradient addition node due to multiple use of tensor around:
Module stack: {'self_conv1': "<class 'torch.nn.modules.conv.Conv2d'>"}
  File "/root/.cache/torch/hub/pytorch_vision_v0.10.0/torchvision/models/resnet.py", line 232, in _forward_impl
    x = self.conv1(x)
 |   File "/root/.cache/torch/hub/pytorch_vision_v0.10.0/torchvision/models/resnet.py", line 249, in forward
    return self._forward_impl(x)

The above exception was the direct cause of the following exception:

BackendCompilerFailed                     Traceback (most recent call last)
Cell In[5], line 11
      9 for i in range(10):
     10     begin = time.time()
---> 11     output = opt_model_max(x)
     12     time.sleep(1)
     13     end = time.time()

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1482, in Module._call_impl(self, *args, **kwargs)
   1477 # If we don't have any hooks, we want to skip the rest of the logic in
   1478 # this function, and just call forward.
   1479 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1480         or _global_backward_pre_hooks or _global_backward_hooks
   1481         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1482     return forward_call(*args, **kwargs)
   1483 # Do not call functions when jit is used
   1484 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:83, in OptimizedModule.forward(self, *args, **kwargs)
     82 def forward(self, *args, **kwargs):
---> 83     return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:212, in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
    210 dynamic_ctx.__enter__()
    211 try:
--> 212     return fn(*args, **kwargs)
    213 finally:
    214     set_eval_frame(prior)

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:333, in catch_errors_wrapper.<locals>.catch_errors(frame, cache_size)
    330             return hijacked_callback(frame, cache_size, hooks)
    332 with compile_lock:
--> 333     return callback(frame, cache_size, hooks)

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:480, in convert_frame.<locals>._convert_frame(frame, cache_size, hooks)
    478 counters["frames"]["total"] += 1
    479 try:
--> 480     result = inner_convert(frame, cache_size, hooks)
    481     counters["frames"]["ok"] += 1
    482     return result

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:103, in wrap_convert_context.<locals>._fn(*args, **kwargs)
    101 torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
    102 try:
--> 103     return fn(*args, **kwargs)
    104 finally:
    105     torch._C._set_grad_enabled(prior_grad_mode)

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py:90, in dynamo_timed.<locals>.time_wrapper(*args, **kwargs)
     88     compilation_metrics[key] = []
     89 t0 = time.time()
---> 90 r = func(*args, **kwargs)
     91 latency = time.time() - t0
     92 # print(f"Dynamo timer: key={key}, latency={latency:.2f} sec")

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:339, in convert_frame_assert.<locals>._convert_frame_assert(frame, cache_size, hooks)
    336 global initial_grad_state
    337 initial_grad_state = torch.is_grad_enabled()
--> 339 return _compile(
    340     frame.f_code,
    341     frame.f_globals,
    342     frame.f_locals,
    343     frame.f_builtins,
    344     compiler_fn,
    345     one_graph,
    346     export,
    347     hooks,
    348     frame,
    349 )

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:400, in _compile(code, globals, locals, builtins, compiler_fn, one_graph, export, hooks, frame)
    398 for attempt in itertools.count():
    399     try:
--> 400         out_code = transform_code_object(code, transform)
    401         orig_code_map[out_code] = code
    402         break

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py:341, in transform_code_object(code, transformations, safe)
    338 instructions = cleaned_instructions(code, safe)
    339 propagate_line_nums(instructions)
--> 341 transformations(instructions, code_options)
    343 fix_vars(instructions, code_options)
    345 dirty = True

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:387, in _compile.<locals>.transform(instructions, code_options)
    374 nonlocal output
    375 tracer = InstructionTranslator(
    376     instructions,
    377     code,
   (...)
    385     mutated_closure_cell_contents,
    386 )
--> 387 tracer.run()
    388 output = tracer.output
    389 assert output is not None

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:1684, in InstructionTranslator.run(self)
   1682 def run(self):
   1683     _step_logger()(logging.INFO, f"torchdynamo start tracing {self.f_code.co_name}")
-> 1684     super().run()

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:538, in InstructionTranslatorBase.run(self)
    533 try:
    534     self.output.push_tx(self)
    535     while (
    536         self.instruction_pointer is not None
    537         and not self.output.should_exit
--> 538         and self.step()
    539     ):
    540         pass
    541 except BackendCompilerFailed:

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:501, in InstructionTranslatorBase.step(self)
    499     if not hasattr(self, inst.opname):
    500         unimplemented(f"missing: {inst.opname}")
--> 501     getattr(self, inst.opname)(inst)
    503     return inst.opname != "RETURN_VALUE"
    504 except BackendCompilerFailed:

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:1750, in InstructionTranslator.RETURN_VALUE(self, inst)
   1745 _step_logger()(
   1746     logging.INFO,
   1747     f"torchdynamo done tracing {self.f_code.co_name} (RETURN_VALUE)",
   1748 )
   1749 log.debug("RETURN_VALUE triggered compile")
-> 1750 self.output.compile_subgraph(self)
   1751 self.output.add_output_instructions([create_instruction("RETURN_VALUE")])

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py:529, in OutputGraph.compile_subgraph(self, tx, partial_convert, reason)
    512     self.add_output_instructions(random_calls_instructions)
    514 if (
    515     stack_values
    516     and all(
   (...)
    526 
    527     # optimization to generate better code in a common case
    528     self.add_output_instructions(
--> 529         self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
    530         + [create_instruction("UNPACK_SEQUENCE", len(stack_values))]
    531     )
    532 else:
    533     graph_output_var = self.new_var("graph_out")

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py:600, in OutputGraph.compile_and_call_fx_graph(self, tx, rv, root)
    598 assert_no_fake_params_or_buffers(gm)
    599 with tracing(self.tracing_context):
--> 600     compiled_fn = self.call_user_compiler(gm)
    601 compiled_fn = disable(compiled_fn)
    603 counters["stats"]["unique_graphs"] += 1

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py:681, in OutputGraph.call_user_compiler(self, gm)
    679 except Exception as e:
    680     compiled_fn = gm.forward
--> 681     raise BackendCompilerFailed(self.compiler_fn, e) from e
    682 return compiled_fn

BackendCompilerFailed: debug_wrapper raised LoweringException: KeyError: 'BLOCK_K'
  target: aten.mm.default
  args[0]: TensorBox(StorageBox(
    ComputedBuffer(name='buf240', layout=FixedLayout('cuda', torch.float32, size=(512, 512), stride=[512, 1]), data=Pointwise(
      'cuda',
      torch.float32,
      tmp0 = load(buf239, i1 + 512 * i0)
      tmp1 = index_expr(49, torch.float32)
      tmp2 = tmp0 / tmp1
      return tmp2
      ,
      ranges=(512, 512),
      origins={view}
    ))
  ))
  args[1]: TensorBox(
    ReinterpretView(
      StorageBox(
        InputBuffer(name='primals_109', layout=FixedLayout('cuda', torch.float32, size=[1000, 512], stride=[512, 1]))
      ),
      FixedLayout('cuda', torch.float32, size=[512, 1000], stride=[1, 512]),
      no origins?
    )
  )

While executing %mm : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%view, %permute), kwargs = {})
Original traceback:
Module stack: {'self_fc': "<class 'torch.nn.modules.linear.Linear'>"}
  File "/root/.cache/torch/hub/pytorch_vision_v0.10.0/torchvision/models/resnet.py", line 244, in _forward_impl
    x = self.fc(x)
 |   File "/root/.cache/torch/hub/pytorch_vision_v0.10.0/torchvision/models/resnet.py", line 249, in forward
    return self._forward_impl(x)

Gradient addition node due to multiple use of tensor around:
Module stack: {'self_conv1': "<class 'torch.nn.modules.conv.Conv2d'>"}
  File "/root/.cache/torch/hub/pytorch_vision_v0.10.0/torchvision/models/resnet.py", line 232, in _forward_impl
    x = self.conv1(x)
 |   File "/root/.cache/torch/hub/pytorch_vision_v0.10.0/torchvision/models/resnet.py", line 249, in forward
    return self._forward_impl(x)

Set torch._dynamo.config.verbose=True for more information

You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True
jansel commented 1 year ago

Should be fixed by https://github.com/pytorch/pytorch/pull/91575

ngimel commented 1 year ago

Closing, since #91575 has landed; please reopen if it still doesn't work.