pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Shape Error when training HF deberta-base with Inductor #96456

Closed: Lokiiiiii closed this issue 9 months ago

Lokiiiiii commented 1 year ago

๐Ÿ› Describe the bug

When training with Hugging Face's Trainer API, PyTorch eager mode succeeds as expected, but the Inductor backend fails with a shape-mismatch error:

ValueError: Cannot view a tensor with shape torch.Size([1, 256, 12, 64]) and strides (196608, 64, 16384, 1) as a tensor with shape (1, 256, 768)!
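The shape and strides in this message match what a contiguous [1, 12, 256, 64] tensor looks like after permute(0, 2, 1, 3). Reading the axes as batch, heads, sequence, head dim is an assumption made here from the sizes (12 × 64 = 768). Because the head axis has stride 16384 rather than 64, the last two dimensions cannot be merged into 768 without copying. The sketch below is plain eager PyTorch, not the DeBERTa code. It rebuilds that layout and shows that view() rejects it, while contiguous().view() and reshape() succeed. Eager mode reports this case as a RuntimeError with different wording than the ValueError above.

```python
import torch

# Assumed layout: a contiguous [batch, heads, seq, head_dim] = [1, 12, 256, 64]
# tensor permuted to [1, 256, 12, 64]. permute() only reorders strides.
x = torch.randn(1, 12, 256, 64).permute(0, 2, 1, 3)
print(x.shape, x.stride())  # torch.Size([1, 256, 12, 64]) (196608, 64, 16384, 1)

# Merging (12, 64) into 768 requires stride[2] == 64 * stride[3] == 64,
# but the head axis has stride 16384, so view() cannot reuse the storage.
try:
    x.view(1, 256, 768)
    view_failed = False
except RuntimeError as err:  # eager mode raises RuntimeError, not ValueError
    view_failed = True
    print("view failed:", err)

# Copying into a contiguous layout first, or letting reshape() copy, works.
y = x.contiguous().view(1, 256, 768)
z = x.reshape(1, 256, 768)
print(y.shape, torch.equal(y, z))  # torch.Size([1, 256, 768]) True
```

In eager DeBERTa training the view evidently succeeds, so the tensor must already be contiguous at that point there. The error suggests the traced graph saw a non-contiguous layout for the same value.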

The failure only occurs with the deberta-base model.
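Eager mode succeeds and the compiled run fails, so one way to narrow the problem down is to run the same model under torch.compile with progressively more of the stack enabled:

- backend="eager" runs Dynamo capture only.
- backend="aot_eager" adds AOTAutograd tracing but no Inductor code generation.
- backend="inductor" runs the full stack.

The traceback below shows the error being raised while AOTAutograd traces the joint forward/backward graph. Whether aot_eager alone reproduces the error is not established here, because the Inductor backend also passes its own decomposition table. The sketch uses a small hypothetical module (PermuteView, not DeBERTa code) that mimics the permute, contiguous, view layout implied by the error. It checks that each backend matches eager on outputs and input gradients, and it assumes PyTorch 2.x.

```python
import torch
import torch._dynamo
import torch.nn as nn


class PermuteView(nn.Module):
    """Hypothetical stand-in (not DeBERTa): merges attention heads back into
    the hidden dimension via permute -> contiguous -> view, then projects."""

    def __init__(self, heads: int = 12, head_dim: int = 64):
        super().__init__()
        self.proj = nn.Linear(heads * head_dim, heads * head_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, h, s, d = x.shape  # [batch, heads, seq, head_dim]
        x = x.permute(0, 2, 1, 3).contiguous()  # [batch, seq, heads, head_dim]
        return self.proj(x.view(b, s, h * d))  # [batch, seq, heads * head_dim]


def run(backend=None):
    """Run forward + backward, optionally under torch.compile with `backend`."""
    torch.manual_seed(0)  # identical weights and inputs on every call
    model = PermuteView()
    if backend is not None:
        torch._dynamo.reset()  # drop code cached by earlier backends
        model = torch.compile(model, backend=backend)
    x = torch.randn(1, 12, 256, 64, requires_grad=True)
    out = model(x)
    out.sum().backward()
    return out.detach(), x.grad


ref_out, ref_grad = run()  # plain eager reference, no compilation
matches = {}
# Add "inductor" to this tuple when a C++ (CPU) or Triton (GPU) toolchain is available.
for backend in ("eager", "aot_eager"):
    out, grad = run(backend)
    matches[backend] = torch.allclose(out, ref_out, atol=1e-6) and torch.allclose(
        grad, ref_grad, atol=1e-6
    )
    print(backend, "matches eager:", matches[backend])
```

The first backend that fails on the real model marks the layer of the stack to investigate.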

Error logs

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /opt/conda/lib/python3.8/site-packages/torch/_dynamo/output_graph.py:670 in call_user_compiler   │
│                                                                                                  │
│   667 │   │   │   elif config.DO_NOT_USE_legacy_non_fake_example_inputs:                         │
│   668 │   │   │   │   compiled_fn = compiler_fn(gm, self.example_inputs())                       │
│   669 │   │   │   else:                                                                          │
│ ❱ 670 │   │   │   │   compiled_fn = compiler_fn(gm, self.fake_example_inputs())                  │
│   671 │   │   │   _step_logger()(logging.INFO, f"done compiler function {name}")                 │
│   672 │   │   │   assert callable(compiled_fn), "compiler_fn did not return callable"            │
│   673 │   │   except Exception as e:                                                             │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/_dynamo/debug_utils.py:1055 in debug_wrapper        │
│                                                                                                  │
│   1052 │   │   │   │   │   )                                                                     │
│   1053 │   │   │   │   │   raise                                                                 │
│   1054 │   │   else:                                                                             │
│ ❱ 1055 │   │   │   compiled_gm = compiler_fn(gm, example_inputs)                                 │
│   1056 │   │                                                                                     │
│   1057 │   │   return compiled_gm                                                                │
│   1058                                                                                           │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/__init__.py:1390 in __call__                        │
│                                                                                                  │
│   1387 │   def __call__(self, model_, inputs_):                                                  │
│   1388 │   │   from torch._inductor.compile_fx import compile_fx                                 │
│   1389 │   │                                                                                     │
│ ❱ 1390 │   │   return compile_fx(model_, inputs_, config_patches=self.config)                    │
│   1391                                                                                           │
│   1392                                                                                           │
│   1393 def compile(model: Optional[Callable] = None, *,                                          │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/_inductor/compile_fx.py:455 in compile_fx           │
│                                                                                                  │
│   452 │   │   # TODO: can add logging before/after the call to create_aot_dispatcher_function    │
│   453 │   │   # in torch._functorch/aot_autograd.py::aot_module_simplified::aot_function_simpl   │
│   454 │   │   # once torchdynamo is merged into pytorch                                          │
│ ❱ 455 │   │   return aot_autograd(                                                               │
│   456 │   │   │   fw_compiler=fw_compiler,                                                       │
│   457 │   │   │   bw_compiler=bw_compiler,                                                       │
│   458 │   │   │   decompositions=select_decomp_table(),                                          │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/_dynamo/backends/common.py:48 in compiler_fn        │
│                                                                                                  │
│    45 │   │   try:                                                                               │
│    46 │   │   │   # NB: NOT cloned!                                                              │
│    47 │   │   │   with enable_aot_logging():                                                     │
│ ❱  48 │   │   │   │   cg = aot_module_simplified(gm, example_inputs, **kwargs)                   │
│    49 │   │   │   │   counters["aot_autograd"]["ok"] += 1                                        │
│    50 │   │   │   │   return eval_frame.disable(cg)                                              │
│    51 │   │   except Exception:                                                                  │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py:2805 in                  │
│ aot_module_simplified                                                                            │
│                                                                                                  │
│   2802 │   full_args.extend(params_flat)                                                         │
│   2803 │   full_args.extend(args)                                                                │
│   2804 │                                                                                         │
│ ❱ 2805 │   compiled_fn = create_aot_dispatcher_function(                                         │
│   2806 │   │   functional_call,                                                                  │
│   2807 │   │   full_args,                                                                        │
│   2808 │   │   aot_config,                                                                       │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/_dynamo/utils.py:163 in time_wrapper                │
│                                                                                                  │
│    160 │   │   │   if key not in compilation_metrics:                                            │
│    161 │   │   │   │   compilation_metrics[key] = []                                             │
│    162 │   │   │   t0 = time.time()                                                              │
│ ❱  163 │   │   │   r = func(*args, **kwargs)                                                     │
│    164 │   │   │   time_spent = time.time() - t0                                                 │
│    165 │   │   │   # print(f"Dynamo timer: key={key}, latency={latency:.2f} sec")                │
│    166 │   │   │   compilation_metrics[key].append(time_spent)                                   │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py:2498 in                  │
│ create_aot_dispatcher_function                                                                   │
│                                                                                                  │
│   2495 │   │   compiler_fn = partial(aot_wrapper_dedupe, compiler_fn=compiler_fn)                │
│   2496 │   │   # You can put more passes here                                                    │
│   2497 │   │                                                                                     │
│ ❱ 2498 │   │   compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config)                    │
│   2499 │   │                                                                                     │
│   2500 │   │   if not hasattr(compiled_fn, "_boxed_call"):                                       │
│   2501 │   │   │   compiled_fn = make_boxed_func(compiled_fn)                                    │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py:1713 in                  │
│ aot_wrapper_dedupe                                                                               │
│                                                                                                  │
│   1710 │   │   │   │   break                                                                     │
│   1711 │   │                                                                                     │
│   1712 │   │   if ok:                                                                            │
│ ❱ 1713 │   │   │   return compiler_fn(flat_fn, leaf_flat_args, aot_config)                       │
│   1714 │                                                                                         │
│   1715 │   # Strategy 2: Duplicate specialize.                                                   │
│   1716 │   #                                                                                     │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py:2087 in                  │
│ aot_dispatch_autograd                                                                            │
│                                                                                                  │
│   2084 │   if config.use_functionalize:                                                          │
│   2085 │   │   with enable_python_dispatcher():                                                  │
│   2086 │   │   │   flattened_joints, _ = pytree.tree_flatten(joint_inputs)                       │
│ ❱ 2087 │   │   │   fx_g = make_fx(joint_forward_backward, aot_config.decompositions)(            │
│   2088 │   │   │   │   *joint_inputs                                                             │
│   2089 │   │   │   )                                                                             │
│   2090                                                                                           │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py:714 in wrapped      │
│                                                                                                  │
│   711 │   │   # thus irrelevant to any external functional trace.                                │
│   712 │   │   with decompose(decomposition_table), fake_tensor_mode, python_dispatcher_mode, \   │
│   713 │   │   │    sym_mode, proxy_mode, disable_autocast_cache(), disable_proxy_modes_tracing   │
│ ❱ 714 │   │   │   t = dispatch_trace(wrap_key(func, args, fx_tracer), tracer=fx_tracer, concre   │
│   715 │   │                                                                                      │
│   716 │   │   # TODO: kind of a bad way to do it, should maybe figure out a better way           │
│   717 │   │   if tracing_mode == "symbolic":                                                     │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py:209 in _fn                    │
│                                                                                                  │
│   206 │   │   │   dynamic_ctx = enable_dynamic(self.dynamic)                                     │
│   207 │   │   │   dynamic_ctx.__enter__()                                                        │
│   208 │   │   │   try:                                                                           │
│ ❱ 209 │   │   │   │   return fn(*args, **kwargs)                                                 │
│   210 │   │   │   finally:                                                                       │
│   211 │   │   │   │   set_eval_frame(prior)                                                      │
│   212 │   │   │   │   dynamic_ctx.__exit__(None, None, None)                                     │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py:443 in              │
│ dispatch_trace                                                                                   │
│                                                                                                  │
│   440 │   │   tracer: Tracer,                                                                    │
│   441 │   │   concrete_args: Optional[Tuple[Any, ...]] = None,                                   │
│   442 ) -> GraphModule:                                                                          │
│ ❱ 443 │   graph = tracer.trace(root, concrete_args)                                              │
│   444 │   name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name   │
│   445 │   return GraphModule(tracer.root, graph, name)                                           │
│   446                                                                                            │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py:209 in _fn                    │
│                                                                                                  │
│   206 │   │   │   dynamic_ctx = enable_dynamic(self.dynamic)                                     │
│   207 │   │   │   dynamic_ctx.__enter__()                                                        │
│   208 │   │   │   try:                                                                           │
│ ❱ 209 │   │   │   │   return fn(*args, **kwargs)                                                 │
│   210 │   │   │   finally:                                                                       │
│   211 │   │   │   │   set_eval_frame(prior)                                                      │
│   212 │   │   │   │   dynamic_ctx.__exit__(None, None, None)                                     │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/fx/_symbolic_trace.py:778 in trace                  │
│                                                                                                  │
│    775 │   │   │   │   self.create_node(                                                         │
│    776 │   │   │   │   │   "output",                                                             │
│    777 │   │   │   │   │   "output",                                                             │
│ ❱  778 │   │   │   │   │   (self.create_arg(fn(*args)),),                                        │
│    779 │   │   │   │   │   {},                                                                   │
│    780 │   │   │   │   │   type_expr=fn.__annotations__.get("return", None),                     │
│    781 │   │   │   │   )                                                                         │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/fx/_symbolic_trace.py:652 in flatten_fn             │
│                                                                                                  │
│    649 │   │   │                                                                                 │
│    650 │   │   │   def flatten_fn(*args):                                                        │
│    651 │   │   │   │   tree_args = pytree.tree_unflatten(list(args), in_spec)                    │
│ ❱  652 │   │   │   │   tree_out = root_fn(*tree_args)                                            │
│    653 │   │   │   │   out_args, out_spec = pytree.tree_flatten(tree_out)                        │
│    654 │   │   │   │   assert isinstance(self.graph._codegen, _PyTreeCodeGen)                    │
│    655 │   │   │   │   self.graph._codegen.pytree_info = (                                       │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py:459 in wrapped      │
│                                                                                                  │
│   456 │   │   with _pop_mode_temporarily():                                                      │
│   457 │   │   │   track_tensor_tree(flat_tensors, flat_proxies, constant=None, tracer=tracer)    │
│   458 │   │                                                                                      │
│ ❱ 459 │   │   out = f(*tensors)                                                                  │
│   460 │   │   out = pytree.tree_map_only(                                                        │
│   461 │   │   │   torch.Tensor,                                                                  │
│   462 │   │   │   lambda t: get_proxy_slot(t, tracer, t, lambda x: x.proxy),                     │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py:1156 in traced_joint     │
│                                                                                                  │
│   1153 │   # the joint needs have args named "primals" and "tangents",                           │
│   1154 │   # which are hardcoded into the partitioning logic.                                    │
│   1155 │   def traced_joint(primals, tangents):                                                  │
│ ❱ 1156 │   │   return functionalized_f_helper(primals, tangents)                                 │
│   1157 │                                                                                         │
│   1158 │   def traced_forward(*primals):                                                         │
│   1159 │   │   return functionalized_f_helper(primals)                                           │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py:1108 in                  │
│ functionalized_f_helper                                                                          │
│                                                                                                  │
│   1105 │   │   torch._enable_functionalization(reapply_views=True)                               │
│   1106 │   │   try:                                                                              │
│   1107 │   │   │   # Run the joint                                                               │
│ ❱ 1108 │   │   │   f_outs = flat_fn_no_input_mutations(fn, f_primals, f_tangents, meta, keep_in  │
│   1109 │   │   finally:                                                                          │
│   1110 │   │   │   torch._disable_functionalization()                                            │
│   1111                                                                                           │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py:1076 in                  │
│ flat_fn_no_input_mutations                                                                       │
│                                                                                                  │
│   1073 │   │   ]                                                                                 │
│   1074 │   else:                                                                                 │
│   1075 │   │   primals_after_cloning = primals                                                   │
│ ❱ 1076 │   outs = flat_fn_with_synthetic_bases_expanded(fn, primals, primals_after_cloning, may  │
│   1077 │   return outs                                                                           │
│   1078                                                                                           │
│   1079 # This creates the final function that we want to trace using make_fx(),                  │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py:1048 in                  │
│ flat_fn_with_synthetic_bases_expanded                                                            │
│                                                                                                  │
│   1045 │   # *after* we clone inputs for autograd (see below), to preserve the view relationshi  │
│   1046 │   primals = unpack_synthetic_bases(primals_after_cloning, meta.synthetic_base_info)     │
│   1047 │   assert len(meta.fw_metadata.input_info) == len(primals)                               │
│ ❱ 1048 │   outs = forward_or_joint(fn, primals_before_cloning, primals, maybe_tangents, meta, k  │
│   1049 │   return outs                                                                           │
│   1050                                                                                           │
│   1051 # This function adds extra clone() calls on any inputs in the forward that get mutated.   │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py:1017 in forward_or_joint │
│                                                                                                  │
│   1014 │   # Call the backwards pass                                                             │
│   1015 │   if grad_primals:                                                                      │
│   1016 │   │   with fx_traceback.preserve_node_meta():                                           │
│ ❱ 1017 │   │   │   backward_out = torch.autograd.grad(                                           │
│   1018 │   │   │   │   needed_outs,                                                              │
│   1019 │   │   │   │   grad_primals,                                                             │
│   1020 │   │   │   │   grad_outputs=needed_tangents,                                             │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/autograd/__init__.py:269 in grad                    │
│                                                                                                  │
│   266 │   t_inputs = cast(Tuple[torch.Tensor, ...], (inputs,) if is_tensor_like(inputs) else t   │
│   267 │   overridable_args = t_outputs + t_inputs                                                │
│   268 │   if has_torch_function(overridable_args):                                               │
│ ❱ 269 │   │   return handle_torch_function(                                                      │
│   270 │   │   │   grad,                                                                          │
│   271 │   │   │   overridable_args,                                                              │
│   272 │   │   │   t_outputs,                                                                     │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/overrides.py:1534 in handle_torch_function          │
│                                                                                                  │
│   1531 │   │   # if we're here, the mode must be set to a TorchFunctionStackMode                 │
│   1532 │   │   # this unsets it and calls directly into TorchFunctionStackMode's torch function  │
│   1533 │   │   with _pop_mode_temporarily() as mode:                                             │
│ ❱ 1534 │   │   │   result = mode.__torch_function__(public_api, types, args, kwargs)             │
│   1535 │   │   if result is not NotImplemented:                                                  │
│   1536 │   │   │   return result                                                                 │
│   1537                                                                                           │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/_inductor/overrides.py:38 in __torch_function__     │
│                                                                                                  │
│    35 │   │   │   and replacements[func] in replacements_using_triton_random                     │
│    36 │   │   ):                                                                                 │
│    37 │   │   │   return replacements[func](*args, **kwargs)                                     │
│ ❱  38 │   │   return func(*args, **kwargs)                                                       │
│    39                                                                                            │
│    40                                                                                            │
│    41 patch_functions = AutogradMonkeypatch                                                      │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/autograd/__init__.py:303 in grad                    │
│                                                                                                  │
│   300 │   │   │   │   allow_unused, accumulate_grad=False)  # Calls into the C++ engine to run   │
│   301 │   │   return _vmap_internals._vmap(vjp, 0, 0, allow_none_pass_through=True)(grad_outpu   │
│   302 │   else:                                                                                  │
│ ❱ 303 │   │   return Variable._execution_engine.run_backward(  # Calls into the C++ engine to    │
│   304 │   │   │   t_outputs, grad_outputs_, retain_graph, create_graph, t_inputs,                │
│   305 │   │   │   allow_unused, accumulate_grad=False)  # Calls into the C++ engine to run the   │
│   306                                                                                            │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/utils/_stats.py:20 in wrapper                       │
│                                                                                                  │
│   17 │   │   if fn.__qualname__ not in simple_call_counter:                                      │
│   18 │   │   │   simple_call_counter[fn.__qualname__] = 0                                        │
│   19 │   │   simple_call_counter[fn.__qualname__] = simple_call_counter[fn.__qualname__] + 1     │
│ ❱ 20 │   │   return fn(*args, **kwargs)                                                          │
│   21 │   return wrapper                                                                          │
│   22                                                                                             │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py:487 in              │
│ __torch_dispatch__                                                                               │
│                                                                                                  │
│   484 │   @count                                                                                 │
│   485 │   def __torch_dispatch__(self, func, types, args=(), kwargs=None):                       │
│   486 │   │   with self.sym_mode.enable(False):                                                  │
│ ❱ 487 │   │   │   return self.inner_torch_dispatch(func, types, args, kwargs)                    │
│   488 │                                                                                          │
│   489 │   def __enter__(self):                                                                   │
│   490 │   │   # sym mode first, then us...                                                       │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py:512 in              │
│ inner_torch_dispatch                                                                             │
│                                                                                                  │
│   509 │   │   if func in [prim.device.default]:                                                  │
│   510 │   │   │   return func(*args, **kwargs)                                                   │
│   511 │   │                                                                                      │
│ ❱ 512 │   │   out = proxy_call(self, func, args, kwargs)                                         │
│   513 │   │   return out                                                                         │
│   514                                                                                            │
│   515                                                                                            │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/fx/experimental/proxy_tensor.py:345 in proxy_call   │
│                                                                                                  │
│   342 │   │   else:                                                                              │
│   343 │   │   │   args[0].proxy = proxy_out                                                      │
│   344 │                                                                                          │
│ ❱ 345 │   out = func(*args, **kwargs)                                                            │
│   346 │                                                                                          │
│   347 │   # In some circumstances, we will be tracing in a situation where a tensor              │
โ”‚   348 โ”‚   # is *statically* known to be a constant (currently, this only happens if              โ”‚
โ”‚                                                                                                  โ”‚
โ”‚ /opt/conda/lib/python3.8/site-packages/torch/_ops.py:284 in __call__                             โ”‚
โ”‚                                                                                                  โ”‚
โ”‚   281 โ”‚   โ”‚   )                                                                                  โ”‚
โ”‚   282 โ”‚                                                                                          โ”‚
โ”‚   283 โ”‚   def __call__(self, *args, **kwargs):                                                   โ”‚
โ”‚ โฑ 284 โ”‚   โ”‚   return self._op(*args, **kwargs or {})                                             โ”‚
โ”‚   285 โ”‚                                                                                          โ”‚
โ”‚   286 โ”‚   def __hash__(self):                                                                    โ”‚
โ”‚   287 โ”‚   โ”‚   return hash(self._op)                                                              โ”‚
โ”‚                                                                                                  โ”‚
โ”‚ /opt/conda/lib/python3.8/site-packages/torch/utils/_stats.py:20 in wrapper                       โ”‚
โ”‚                                                                                                  โ”‚
โ”‚   17 โ”‚   โ”‚   if fn.__qualname__ not in simple_call_counter:                                      โ”‚
โ”‚   18 โ”‚   โ”‚   โ”‚   simple_call_counter[fn.__qualname__] = 0                                        โ”‚
โ”‚   19 โ”‚   โ”‚   simple_call_counter[fn.__qualname__] = simple_call_counter[fn.__qualname__] + 1     โ”‚
โ”‚ โฑ 20 โ”‚   โ”‚   return fn(*args, **kwargs)                                                          โ”‚
โ”‚   21 โ”‚   return wrapper                                                                          โ”‚
โ”‚   22                                                                                             โ”‚
โ”‚                                                                                                  โ”‚
โ”‚ /opt/conda/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py:987 in                   โ”‚
โ”‚ __torch_dispatch__                                                                               โ”‚
โ”‚                                                                                                  โ”‚
โ”‚    984 โ”‚   @count                                                                                โ”‚
โ”‚    985 โ”‚   def __torch_dispatch__(self, func, types, args=(), kwargs=None):                      โ”‚
โ”‚    986 โ”‚   โ”‚   try:                                                                              โ”‚
โ”‚ โฑ  987 โ”‚   โ”‚   โ”‚   return self.dispatch(func, types, args, kwargs)                               โ”‚
โ”‚    988 โ”‚   โ”‚   except TypeError:                                                                 โ”‚
โ”‚    989 โ”‚   โ”‚   โ”‚   log.exception("fake tensor raised TypeError")                                 โ”‚
โ”‚    990 โ”‚   โ”‚   โ”‚   raise                                                                         โ”‚
โ”‚                                                                                                  โ”‚
โ”‚ /opt/conda/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py:1170 in dispatch         โ”‚
โ”‚                                                                                                  โ”‚
โ”‚   1167 โ”‚   โ”‚   # python meta registrations, prims, decomps, and c++ meta fns (structured kernel  โ”‚
โ”‚   1168 โ”‚   โ”‚   try:                                                                              โ”‚
โ”‚   1169 โ”‚   โ”‚   โ”‚   with in_kernel_invocation_manager(self):                                      โ”‚
โ”‚ โฑ 1170 โ”‚   โ”‚   โ”‚   โ”‚   r = func(*args, **kwargs)                                                 โ”‚
โ”‚   1171 โ”‚   โ”‚   except NotImplementedError as not_implemented_error:                              โ”‚
โ”‚   1172 โ”‚   โ”‚   โ”‚   # no meta kernel registered, fallback to kernel for the device                โ”‚
โ”‚   1173 โ”‚   โ”‚   โ”‚   if not self.allow_fallback_kernels:                                           โ”‚
โ”‚                                                                                                  โ”‚
โ”‚ /opt/conda/lib/python3.8/site-packages/torch/_ops.py:284 in __call__                             โ”‚
โ”‚                                                                                                  โ”‚
โ”‚   281 โ”‚   โ”‚   )                                                                                  โ”‚
โ”‚   282 โ”‚                                                                                          โ”‚
โ”‚   283 โ”‚   def __call__(self, *args, **kwargs):                                                   โ”‚
โ”‚ โฑ 284 โ”‚   โ”‚   return self._op(*args, **kwargs or {})                                             โ”‚
โ”‚   285 โ”‚                                                                                          โ”‚
โ”‚   286 โ”‚   def __hash__(self):                                                                    โ”‚
โ”‚   287 โ”‚   โ”‚   return hash(self._op)                                                              โ”‚
โ”‚                                                                                                  โ”‚
โ”‚ /opt/conda/lib/python3.8/site-packages/torch/_refs/__init__.py:3988 in view                      โ”‚
โ”‚                                                                                                  โ”‚
โ”‚   3985 # TODO: Turn this into a decomposition (currently fails on reshape meta tests)            โ”‚
โ”‚   3986 @register_decomposition(aten.view)                                                        โ”‚
โ”‚   3987 def view(a: TensorLikeType, *shape: ShapeType) -> TensorLikeType:                         โ”‚
โ”‚ โฑ 3988 โ”‚   return _reshape_view_helper(a, *shape, allow_copy=False)                              โ”‚
โ”‚   3989                                                                                           โ”‚
โ”‚   3990                                                                                           โ”‚
โ”‚   3991 # CompositeImplicitAutograd - don't register decomp                                       โ”‚
โ”‚                                                                                                  โ”‚
โ”‚ /opt/conda/lib/python3.8/site-packages/torch/_refs/__init__.py:3237 in _reshape_view_helper      โ”‚
โ”‚                                                                                                  โ”‚
โ”‚   3234 โ”‚   โ”‚   โ”‚   โ”‚   msg = "Cannot view a tensor with shape {0} and strides {1} as a tensor w  โ”‚
โ”‚   3235 โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   a.shape, a.stride(), shape                                            โ”‚
โ”‚   3236 โ”‚   โ”‚   โ”‚   โ”‚   )                                                                         โ”‚
โ”‚ โฑ 3237 โ”‚   โ”‚   โ”‚   โ”‚   raise ValueError(msg)                                                     โ”‚
โ”‚   3238 โ”‚   โ”‚   โ”‚                                                                                 โ”‚
โ”‚   3239 โ”‚   โ”‚   โ”‚   a_ = flatten(a_, idx, end)                                                    โ”‚
โ”‚   3240                                                                                           โ”‚
โ•ฐโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ•ฏ
ValueError: Cannot view a tensor with shape torch.Size([1, 256, 12, 64]) and strides (196608, 64, 16384, 1) as a tensor with shape (1, 256, 768)!
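Why the view fails: viewing `[1, 256, 12, 64]` as `(1, 256, 768)` merges the last two dims, which is only possible without a copy if stepping one index along the 12-dim moves exactly 64 elements (the span of the 64-dim). With a stride of 16384 it does not. Below is a simplified pure-Python sketch of that merge check; `can_flatten` is a hypothetical helper, and the real logic in `torch._refs._reshape_view_helper` is more general.

```python
from typing import Sequence


def can_flatten(shape: Sequence[int], strides: Sequence[int], start: int, end: int) -> bool:
    """Return True if dims start..end (inclusive) can be merged into one dim without a copy.

    Simplified check: each dim in the range must step over exactly the
    elements spanned by the next one. Size-1 dims are skipped because their
    stride never affects which elements are addressed.
    """
    dims = [d for d in range(start, end + 1) if shape[d] != 1]
    for outer, inner in zip(dims, dims[1:]):
        if strides[outer] != strides[inner] * shape[inner]:
            return False
    return True


# Shape and strides reported in the error message.
shape = (1, 256, 12, 64)
strides = (196608, 64, 16384, 1)

# Viewing as (1, 256, 768) merges dims 2 and 3 (12 * 64 = 768).
print(can_flatten(shape, strides, 2, 3))  # False: 16384 != 1 * 64

# The same shape with contiguous (row-major) strides can be merged.
contiguous = (196608, 768, 64, 1)
print(can_flatten(shape, contiguous, 2, 3))  # True: 64 == 1 * 64
```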

Minified repro

Minifier was unable to repro the error

pip3 install numpy --pre torch --force-reinstall --index-url https://download.pytorch.org/whl/nightly/cu117
git clone https://github.com/huggingface/transformers.git
cd transformers
pip install -e .

cd examples/pytorch/language-modeling
pip install -r requirements.txt
WANDB_DISABLED=true python run_mlm.py --model_name_or_path microsoft/deberta-base --output_dir . --fp16 --dataloader_drop_last --dataset_config_name wikitext-2-raw-v1 --dataset_name wikitext --do_train --evaluation_strategy no --logging_strategy epoch --max_seq_length 128 --num_train_epochs 1 --overwrite_output_dir --per_device_train_batch_size 128 --save_strategy no --torch_compile_backend inductor

Versions

Collecting environment information...
PyTorch version: 2.0.0a0+git9cfa076
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: version 3.25.2
Libc version: glibc-2.31

Python version: 3.8.16 | packaged by conda-forge | (default, Feb  1 2023, 16:01:55)  [GCC 11.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-1028-aws-x86_64-with-glibc2.10
Is CUDA available: True
CUDA runtime version: 11.7.99
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA A10G
GPU 1: NVIDIA A10G
GPU 2: NVIDIA A10G
GPU 3: NVIDIA A10G
GPU 4: NVIDIA A10G
GPU 5: NVIDIA A10G
GPU 6: NVIDIA A10G

Nvidia driver version: 515.65.01
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.5.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.5.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.5.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.5.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.5.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.5.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.5.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Byte Order:                      Little Endian
Address sizes:                   48 bits physical, 48 bits virtual
CPU(s):                          192
On-line CPU(s) list:             0-191
Thread(s) per core:              2
Core(s) per socket:              48
Socket(s):                       2
NUMA node(s):                    2
Vendor ID:                       AuthenticAMD
CPU family:                      23
Model:                           49
Model name:                      AMD EPYC 7R32
Stepping:                        0
CPU MHz:                         2799.534
BogoMIPS:                        5599.06
Hypervisor vendor:               KVM
Virtualization type:             full
L1d cache:                       3 MiB
L1i cache:                       3 MiB
L2 cache:                        48 MiB
L3 cache:                        384 MiB
NUMA node0 CPU(s):               0-47,96-143
NUMA node1 CPU(s):               48-95,144-191
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Mmio stale data:   Not affected
Vulnerability Retbleed:          Mitigation; untrained return thunk; SMT enabled with STIBP protection
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; Retpolines, IBPB conditional, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch topoext perfctr_core ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr rdpru wbnoinvd arat npt nrip_save rdpid

Versions of relevant libraries:
[pip3] clip-anytorch==2.5.2
[pip3] CoCa-pytorch==0.0.7
[pip3] dalle2-pytorch==1.10.5
[pip3] ema-pytorch==0.2.1
[pip3] functorch==1.14.0a0+408bcf1
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.23.5
[pip3] pytorch-transformers==1.2.0
[pip3] pytorch-warmup==0.1.1
[pip3] rotary-embedding-torch==0.2.1
[pip3] sagemaker-pytorch-training==2.7.0
[pip3] torch==2.0.0a0+git9cfa076
[pip3] torch-fidelity==0.3.0
[pip3] torch-struct==0.5
[pip3] torchaudio==2.0.0a0+b96a7eb
[pip3] torchdata==0.5.1+a246b31
[pip3] torchmetrics==0.11.3
[pip3] torchrec-nightly==2023.3.6
[pip3] torchtext==0.14.0a0+5b78d07
[pip3] torchvision==0.14.1a0+b69fce3
[pip3] vector-quantize-pytorch==1.1.1
[conda] clip-anytorch             2.5.2                    pypi_0    pypi
[conda] coca-pytorch              0.0.7                    pypi_0    pypi
[conda] dalle2-pytorch            1.10.5                   pypi_0    pypi
[conda] ema-pytorch               0.2.1                    pypi_0    pypi
[conda] functorch                 1.14.0a0+408bcf1          pypi_0    pypi
[conda] magma-cuda117             2.6.1                         1    pytorch
[conda] mkl                       2022.2.1         h84fe81f_16997    conda-forge
[conda] mkl-include               2023.0.0         h84fe81f_26648    conda-forge
[conda] numpy                     1.21.2                   pypi_0    pypi
[conda] pytorch                   1.13.1          cpu_py38hbac4b8a_1    conda-forge
[conda] pytorch-transformers      1.2.0                    pypi_0    pypi
[conda] pytorch-warmup            0.1.1                    pypi_0    pypi
[conda] rotary-embedding-torch    0.2.1                    pypi_0    pypi
[conda] sagemaker-pytorch-training 2.7.0                    pypi_0    pypi
[conda] torch                     2.0.0a0+git9cfa076          pypi_0    pypi
[conda] torch-fidelity            0.3.0                    pypi_0    pypi
[conda] torch-struct              0.5                      pypi_0    pypi
[conda] torchaudio                2.0.0a0+b96a7eb          pypi_0    pypi
[conda] torchdata                 0.5.1            py38h60d003c_1    conda-forge
[conda] torchmetrics              0.11.3                   pypi_0    pypi
[conda] torchrec-nightly          2023.3.6                 pypi_0    pypi
[conda] torchtext                 0.14.0a0+5b78d07          pypi_0    pypi
[conda] torchvision               0.15.0a0+0bdd01a          pypi_0    pypi
[conda] vector-quantize-pytorch   1.1.1                    pypi_0    pypi

cc @ezyang @eellison @bdhirsh @msaroufim @wconstab @anijain2305 @zou3519 @ngimel @soumith

ezyang commented 1 year ago

Looks like a stride propagation error.

cc @dagitses for stride-agnostic PyTorch

davidberard98 commented 1 year ago

Managed to get this more minimal repro; haven't looked at it much yet. (Note: if you're trying to repro the original transformers issue, you need to run with a single GPU, or else you'll run into a different FakeTensor issue.)

import torch

x = torch.rand((1, 12, 256*64), requires_grad=True)

def transpose_for_scores(x):
    new_x_shape = x.size()[:-1] + (256, -1)
    x = x.view(new_x_shape)
    return x.permute(0, 2, 1, 3)

def fn(x):
    scale_factor = 0.5
    x = x.relu()
    x = transpose_for_scores(x)
    x /= torch.sqrt(torch.tensor(x.size(-1), dtype=torch.float) * scale_factor)
    return x.transpose(-1, -2)

fn(x)
torch.compile(fn)(x)
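Tracing the layouts through `fn` by hand shows how this repro relates to the original report: after `view` and `permute(0, 2, 1, 3)`, the intermediate has shape `[1, 256, 12, 64]` and strides `(196608, 64, 16384, 1)`, the same pair as in the original error message. The sketch below is pure Python and only models shape/stride bookkeeping for a contiguous input; `contiguous_strides`, `permute`, and `transpose` are hypothetical stand-ins for the corresponding tensor operations, not torch APIs.

```python
from typing import Sequence, Tuple

Layout = Tuple[Tuple[int, ...], Tuple[int, ...]]  # (shape, strides)


def contiguous_strides(shape: Sequence[int]) -> Tuple[int, ...]:
    """Row-major strides: each dim steps over the product of the sizes after it."""
    strides, step = [], 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def permute(layout: Layout, dims: Sequence[int]) -> Layout:
    """Permuting reorders shape and strides together; no data moves."""
    shape, strides = layout
    return tuple(shape[d] for d in dims), tuple(strides[d] for d in dims)


def transpose(layout: Layout, a: int, b: int) -> Layout:
    """Swap two dims, expressed as a permute."""
    dims = list(range(len(layout[0])))
    dims[a], dims[b] = dims[b], dims[a]
    return permute(layout, dims)


# relu on the contiguous [1, 12, 16384] input stays contiguous, and
# x.view(1, 12, 256, -1) of a contiguous tensor has contiguous strides.
viewed_shape = (1, 12, 256, 64)
viewed = (viewed_shape, contiguous_strides(viewed_shape))

permuted = permute(viewed, (0, 2, 1, 3))
out = transpose(permuted, 2, 3)  # transpose(-1, -2) on a 4-D tensor

print(viewed)    # ((1, 12, 256, 64), (196608, 16384, 64, 1))
print(permuted)  # ((1, 256, 12, 64), (196608, 64, 16384, 1))
print(out)       # ((1, 256, 64, 12), (196608, 64, 1, 16384))
```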
eellison commented 1 year ago

Hmm, neither CrossRefFakeMode nor DebugInterpreter catches this.

anijain2305 commented 1 year ago

Even aot_eager fails here.

import torch

x = torch.rand((1, 12, 256*64), requires_grad=True)

def transpose_for_scores(x):
    new_x_shape = x.size()[:-1] + (256, -1)
    x = x.view(new_x_shape)
    return x.permute(0, 2, 1, 3)

def fn(x):
    scale_factor = 0.5
    x = x.relu()
    x = transpose_for_scores(x)
    x /= torch.sqrt(torch.tensor(x.size(-1), dtype=torch.float) * scale_factor)
    return x.transpose(-1, -2)

fn(x)
torch.compile(fn, backend="aot_eager")(x)
anijain2305 commented 1 year ago

cc @ezyang @bdhirsh to advise.

ngimel commented 1 year ago

The minimal repros throw a different error ("one of the variables needed for gradient computation has been modified by an inplace operation"). The original view error is probably due to the copy_ decomposition producing wrong strides; @bdhirsh has a fix for this that is blocked by C++ codegen in fbcode.
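A pure-Python sketch of how wrong strides from a copy can surface later as a view error. It does not reproduce the actual decomposition or the fix mentioned above; `copy_with_dst_layout` and `copy_with_src_layout` are hypothetical names for a correct result (which keeps the destination's contiguous strides) and a buggy one (which reports the source's permuted strides). Only the buggy result fails the merge check that a later `.view(1, 256, 768)` needs.

```python
from typing import Sequence, Tuple

Layout = Tuple[Tuple[int, ...], Tuple[int, ...]]  # (shape, strides)


def contiguous_strides(shape: Sequence[int]) -> Tuple[int, ...]:
    """Row-major strides: each dim steps over the product of the sizes after it."""
    strides, step = [], 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def can_flatten(shape: Sequence[int], strides: Sequence[int], start: int, end: int) -> bool:
    """Simplified check: can dims start..end (inclusive) be merged without a copy?"""
    dims = [d for d in range(start, end + 1) if shape[d] != 1]
    return all(strides[o] == strides[i] * shape[i] for o, i in zip(dims, dims[1:]))


def copy_with_dst_layout(dst: Layout, src: Layout) -> Layout:
    # Correct behavior: copy_ writes src's values into dst, so the result keeps dst's strides.
    assert dst[0] == src[0], "this sketch only handles matching shapes"
    return dst


def copy_with_src_layout(dst: Layout, src: Layout) -> Layout:
    # Hypothetical bug: the result is reported with src's strides instead of dst's.
    assert dst[0] == src[0], "this sketch only handles matching shapes"
    return src


shape = (1, 256, 12, 64)
src = (shape, (196608, 64, 16384, 1))     # permuted, non-contiguous source
dst = (shape, contiguous_strides(shape))  # freshly allocated contiguous destination

good = copy_with_dst_layout(dst, src)
bad = copy_with_src_layout(dst, src)

# A later .view(1, 256, 768) has to merge dims 2 and 3.
print(can_flatten(*good, 2, 3))  # True
print(can_flatten(*bad, 2, 3))   # False
```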

bdhirsh commented 1 year ago

That looks like something that should be fixed by this PR https://github.com/pytorch/pytorch/issues/96456#issuecomment-1562284376. I can't test it at the moment (allocation was nuked) but I can try to confirm later.

bdhirsh commented 1 year ago

Unfortunately, even with the copy() decomp fix in inductor, the repro now gives this error for me:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [1, 12, 16384]], which is output 0 of ReluBackward0, is at version 1; expected version 0 instead.
bdhirsh commented 1 year ago

Actually, I realized that the small repro above is broken: the same error also shows up if you run it in eager mode and actually call .backward().
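A minimal eager-mode check of this, assuming only a standard PyTorch install: run the repro's `fn` without `torch.compile` and call `.backward()`. `relu` saves its output for the backward pass, and the in-place `/=` on a view of that output bumps its version counter, so backward raises the same RuntimeError.

```python
import torch


def transpose_for_scores(x):
    new_x_shape = x.size()[:-1] + (256, -1)
    x = x.view(new_x_shape)
    return x.permute(0, 2, 1, 3)


def fn(x):
    scale_factor = 0.5
    x = x.relu()                 # autograd saves this output for ReluBackward0
    x = transpose_for_scores(x)  # view + permute: still aliases the relu output
    # In-place division on the alias bumps the relu output's version counter.
    x /= torch.sqrt(torch.tensor(x.size(-1), dtype=torch.float) * scale_factor)
    return x.transpose(-1, -2)


x = torch.rand((1, 12, 256 * 64), requires_grad=True)
out = fn(x)  # the forward pass alone succeeds

error = None
try:
    out.sum().backward()  # plain eager mode, no torch.compile involved
except RuntimeError as err:
    error = err

print(error)
```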

bdhirsh commented 1 year ago

I tried running the HuggingFace repro, but I get an OOM on my 40 GB machine. It would be great if someone could patch this PR in locally and try to repro! https://github.com/pytorch/pytorch/issues/96456.

ezyang commented 1 year ago

@davidberard98's repro still fails for me in AOTAutograd https://github.com/pytorch/pytorch/issues/96456#issuecomment-1467355129

  File "/data/users/ezyang/b/pytorch/torch/fx/experimental/proxy_tensor.py", line 532, in __torch_dispatch__
    return self.inner_torch_dispatch(func, types, args, kwargs)
  File "/data/users/ezyang/b/pytorch/torch/fx/experimental/proxy_tensor.py", line 557, in inner_torch_dispatch
    return proxy_call(self, func, self.pre_dispatch, args, kwargs)
  File "/data/users/ezyang/b/pytorch/torch/fx/experimental/proxy_tensor.py", line 367, in proxy_call
    out = func(*args, **kwargs)
  File "/data/users/ezyang/b/pytorch/torch/_ops.py", line 429, in __call__
    return self._op(*args, **kwargs or {})
  File "/data/users/ezyang/b/pytorch/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/data/users/ezyang/b/pytorch/torch/_subclasses/fake_tensor.py", line 1160, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/data/users/ezyang/b/pytorch/torch/_subclasses/fake_tensor.py", line 1404, in dispatch
    r = func(*args, **kwargs)
  File "/data/users/ezyang/b/pytorch/torch/_ops.py", line 429, in __call__
    return self._op(*args, **kwargs or {})
  File "/data/users/ezyang/b/pytorch/torch/_refs/__init__.py", line 4138, in view
    return _reshape_view_helper(a, *shape, allow_copy=False)
  File "/data/users/ezyang/b/pytorch/torch/_refs/__init__.py", line 3352, in _reshape_view_helper
    raise ValueError(msg)
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
ValueError: Cannot view a tensor with shape torch.Size([1, 12, 256, 64]) and strides (196608, 64, 768, 1) as a tensor with shape (1, 12, 16384)!

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
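One way to read the strides in this message: sorting the dims by stride, largest first, gives their order in memory. Here that order is `[0, 2, 1, 3]`, so the buffer is a dense `[1, 256, 12, 64]` tensor seen through a permute, while a view as `(1, 12, 16384)` needs dims 1, 2, 3 to be laid out in that order. A pure-Python sketch, with hypothetical helpers `memory_order` and `is_dense`:

```python
from typing import List, Sequence


def memory_order(shape: Sequence[int], strides: Sequence[int]) -> List[int]:
    """Dims sorted from outermost to innermost in memory (largest stride first).

    Ties are broken by dim index.
    """
    return sorted(range(len(shape)), key=lambda d: (-strides[d], d))


def is_dense(shape: Sequence[int], strides: Sequence[int]) -> bool:
    """True if the dims, taken in memory order, tile the buffer with no gaps or overlap."""
    expected = 1
    for d in reversed(memory_order(shape, strides)):
        if shape[d] != 1 and strides[d] != expected:
            return False
        expected *= shape[d]
    return True


shape = (1, 12, 256, 64)
strides = (196608, 64, 768, 1)  # from the error message above

order = memory_order(shape, strides)
print(order)                      # [0, 2, 1, 3]
print([shape[d] for d in order])  # [1, 256, 12, 64]: the physical layout
print(is_dense(shape, strides))   # True: a permuted view of a dense buffer
```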
williamwen42 commented 9 months ago

I get a different error when I try to run @davidberard98's repro today:

/data/users/williamwen/pytorch/torch/autograd/__init__.py:411: UserWarning: Error detected in ReluBackward0. Traceback of forward call that caused the error:
  File "/data/users/williamwen/pytorch/playground5.py", line 12, in fn
    x = x.relu()
 (Triggered internally at /data/users/williamwen/pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:113.)
  result = Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "/data/users/williamwen/pytorch/playground5.py", line 18, in <module>
    torch.compile(fn)(x)
  File "/data/users/williamwen/pytorch/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/data/users/williamwen/pytorch/torch/_dynamo/eval_frame.py", line 655, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/data/users/williamwen/pytorch/torch/_dynamo/convert_frame.py", line 721, in _convert_frame
    result = inner_convert(frame, cache_entry, hooks, frame_state)
  File "/data/users/williamwen/pytorch/torch/_dynamo/convert_frame.py", line 383, in _convert_frame_assert
    compiled_product = _compile(
  File "/data/users/williamwen/pytorch/torch/_dynamo/convert_frame.py", line 645, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/data/users/williamwen/pytorch/torch/_dynamo/utils.py", line 244, in time_wrapper
    r = func(*args, **kwargs)
  File "/data/users/williamwen/pytorch/torch/_dynamo/convert_frame.py", line 562, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/data/users/williamwen/pytorch/torch/_dynamo/bytecode_transformation.py", line 1033, in transform_code_object
    transformations(instructions, code_options)
  File "/data/users/williamwen/pytorch/torch/_dynamo/convert_frame.py", line 151, in _fn
    return fn(*args, **kwargs)
  File "/data/users/williamwen/pytorch/torch/_dynamo/convert_frame.py", line 527, in transform
    tracer.run()
  File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 2123, in run
    super().run()
  File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 818, in run
    and self.step()
  File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 781, in step
    getattr(self, inst.opname)(inst)
  File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 2238, in RETURN_VALUE
    self.output.compile_subgraph(
  File "/data/users/williamwen/pytorch/torch/_dynamo/output_graph.py", line 912, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/data/users/williamwen/py310-env/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/data/users/williamwen/pytorch/torch/_dynamo/output_graph.py", line 1080, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/data/users/williamwen/pytorch/torch/_dynamo/utils.py", line 244, in time_wrapper
    r = func(*args, **kwargs)
  File "/data/users/williamwen/pytorch/torch/_dynamo/output_graph.py", line 1152, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/data/users/williamwen/pytorch/torch/_dynamo/output_graph.py", line 1133, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/data/users/williamwen/pytorch/torch/_dynamo/repro/after_dynamo.py", line 117, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/data/users/williamwen/pytorch/torch/__init__.py", line 1657, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/data/users/williamwen/pytorch/torch/_inductor/compile_fx.py", line 1168, in compile_fx
    return aot_autograd(
  File "/data/users/williamwen/pytorch/torch/_dynamo/backends/common.py", line 55, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
  File "/data/users/williamwen/pytorch/torch/_functorch/aot_autograd.py", line 4938, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
  File "/data/users/williamwen/pytorch/torch/_dynamo/utils.py", line 244, in time_wrapper
    r = func(*args, **kwargs)
  File "/data/users/williamwen/pytorch/torch/_functorch/aot_autograd.py", line 4478, in create_aot_dispatcher_function
    compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config, fw_metadata=fw_metadata)
  File "/data/users/williamwen/pytorch/torch/_functorch/aot_autograd.py", line 2813, in aot_wrapper_dedupe
    return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata)
  File "/data/users/williamwen/pytorch/torch/_functorch/aot_autograd.py", line 2999, in aot_wrapper_synthetic_base
    return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
  File "/data/users/williamwen/pytorch/torch/_functorch/aot_autograd.py", line 3700, in aot_dispatch_autograd
    fx_g, joint_inputs, maybe_subclass_meta = aot_dispatch_autograd_graph(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
  File "/data/users/williamwen/pytorch/torch/_functorch/aot_autograd.py", line 3680, in aot_dispatch_autograd_graph
    fx_g = create_graph(joint_fn_to_trace, updated_joint_inputs, aot_config=aot_config)
  File "/data/users/williamwen/pytorch/torch/_functorch/aot_autograd.py", line 1943, in create_graph
    fx_g = make_fx(f, decomposition_table=aot_config.decompositions)(*args)
  File "/data/users/williamwen/pytorch/torch/fx/experimental/proxy_tensor.py", line 869, in wrapped
    t = dispatch_trace(wrap_key(func, args, fx_tracer, pre_dispatch), tracer=fx_tracer, concrete_args=tuple(phs))
  File "/data/users/williamwen/pytorch/torch/_compile.py", line 24, in inner
    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
  File "/data/users/williamwen/pytorch/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/data/users/williamwen/pytorch/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/data/users/williamwen/pytorch/torch/fx/experimental/proxy_tensor.py", line 481, in dispatch_trace
    graph = tracer.trace(root, concrete_args)
  File "/data/users/williamwen/pytorch/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/data/users/williamwen/pytorch/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/data/users/williamwen/pytorch/torch/fx/_symbolic_trace.py", line 821, in trace
    (self.create_arg(fn(*args)),),
  File "/data/users/williamwen/pytorch/torch/fx/_symbolic_trace.py", line 688, in flatten_fn
    tree_out = root_fn(*tree_args)
  File "/data/users/williamwen/pytorch/torch/fx/experimental/proxy_tensor.py", line 517, in wrapped
    out = f(*tensors)
  File "/data/users/williamwen/pytorch/torch/_functorch/aot_autograd.py", line 1929, in joint_helper
    return functionalized_f_helper(primals, tangents)
  File "/data/users/williamwen/pytorch/torch/_functorch/aot_autograd.py", line 1882, in functionalized_f_helper
    f_outs = fn(*f_args)
  File "/data/users/williamwen/pytorch/torch/_functorch/aot_autograd.py", line 1850, in inner_fn_with_anomaly
    return inner_fn(*args)
  File "/data/users/williamwen/pytorch/torch/_functorch/aot_autograd.py", line 1833, in inner_fn
    backward_out = torch.autograd.grad(
  File "/data/users/williamwen/pytorch/torch/autograd/__init__.py", line 411, in grad
    result = Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [1, 12, 16384]], which is output 0 of ReluBackward0, is at version 1; expected version 0 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
bdhirsh commented 9 months ago

@Lokiiiiii, can you re-open if you're still seeing an issue? David's smaller repro above no longer fails with the original error, as Yanbo pointed out. The new error is actually because the minimized repro isn't quite valid: even in eager mode, that code will fail if you call out.sum().backward(), because the repro code mutates the output of relu(), which was saved for backward.
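For anyone adapting the minimized repro, here is a sketch of a variant that is valid under autograd: the in-place `x /= ...` becomes an out-of-place division, so relu's saved output is never mutated. This variant is an illustration rather than a repro posted in the thread, and it is not known to reproduce the original deberta failure.

```python
import torch


def transpose_for_scores(x):
    new_x_shape = x.size()[:-1] + (256, -1)
    return x.view(new_x_shape).permute(0, 2, 1, 3)


def fn_out_of_place(x):
    scale_factor = 0.5
    x = x.relu()
    x = transpose_for_scores(x)
    # Out-of-place: allocates a new tensor, so relu's saved output stays at version 0.
    x = x / torch.sqrt(torch.tensor(x.size(-1), dtype=torch.float) * scale_factor)
    return x.transpose(-1, -2)


x = torch.rand((1, 12, 256 * 64), requires_grad=True)
out = fn_out_of_place(x)
out.sum().backward()  # succeeds in eager mode
print(out.shape, x.grad.shape)
```

Passing `fn_out_of_place` to `torch.compile` (for example with `backend="aot_eager"`, as used earlier in the thread) would then exercise the compiled backward without tripping the eager-mode autograd error.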