pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

[TS2EP] Failing on longformer #130008

Open justinchuby opened 2 months ago

justinchuby commented 2 months ago

🐛 Describe the bug

I am testing the TorchScript-to-ExportedProgram (TS2EP) conversion logic on the Hugging Face Longformer model. An exact repro is somewhat complex because I have changed some source code in transformers, but the error stack may give you some ideas.

  File "/home/justinchu/dev/torch-onnx/src/torch_onnx/_torchscript_converter.py", line 870, in convert
    ep = self.retrace_as_exported_program(
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/justinchu/dev/torch-onnx/src/torch_onnx/_torchscript_converter.py", line 879, in retrace_as_exported_program
    ep = torch.export._trace._export(
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.11/site-packages/torch/export/_trace.py", line 991, in wrapper
    raise e

  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.11/site-packages/torch/export/_trace.py", line 974, in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^

  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.11/site-packages/torch/export/exported_program.py", line 91, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^

  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.11/site-packages/torch/export/_trace.py", line 1507, in _export
    export_artifact = export_func(
                      ^^^^^^^^^^^^

  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.11/site-packages/torch/export/_trace.py", line 1366, in _non_strict_export
    aten_export_artifact = _export_to_aten_ir(
                           ^^^^^^^^^^^^^^^^^^^

  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.11/site-packages/torch/export/_trace.py", line 623, in _export_to_aten_ir
    gm, graph_signature = transform(aot_export_module)(
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.11/site-packages/torch/export/_trace.py", line 1317, in _aot_export_non_strict
    gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1147, in aot_export_module
    fx_g, metadata, in_spec, out_spec = _aot_export_function(
                                        ^^^^^^^^^^^^^^^^^^^^^

  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1366, in _aot_export_function
    fx_g, meta = create_aot_dispatcher_function(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^

  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 571, in create_aot_dispatcher_function
    fw_metadata = run_functionalized_fw_and_collect_metadata(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 163, in inner
    flat_f_outs = f(*flat_f_args)
                  ^^^^^^^^^^^^^^^

  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/utils.py", line 178, in flat_fn
    tree_out = fn(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^

  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 764, in functional_call
    out = mod(*args[params_len:], **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1566, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1575, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.11/site-packages/torch/export/_trace.py", line 1300, in forward
    tree_out = torch.fx.Interpreter(self._export_root).run(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.11/site-packages/torch/fx/interpreter.py", line 146, in run
    self.env[node] = self.run_node(node)
                     ^^^^^^^^^^^^^^^^^^^

  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.11/site-packages/torch/fx/interpreter.py", line 203, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.11/site-packages/torch/fx/interpreter.py", line 275, in call_function
    return target(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.11/site-packages/torch/_ops.py", line 670, in __call__
    return self_._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.11/site-packages/torch/_subclasses/functional_tensor.py", line 468, in __torch_dispatch__
    outs_unwrapped = func._op_dk(
                     ^^^^^^^^^^^^

  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.11/site-packages/torch/utils/_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^

  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1061, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1450, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1153, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1690, in _dispatch_impl
    return decomposition_table[func](*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.11/site-packages/torch/_prims_common/wrappers.py", line 266, in _fn
    result = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^

  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.11/site-packages/torch/_refs/__init__.py", line 2859, in constant_pad_nd
    if pad[pad_idx + 1] < 0:
       ^^^^^^^^^^^^^^^^^^^^

  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.11/site-packages/torch/__init__.py", line 620, in __bool__
    return self.node.bool_()
           ^^^^^^^^^^^^^^^^^

  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.11/site-packages/torch/fx/experimental/sym_node.py", line 476, in bool_
    return self.guard_bool("", 0)
           ^^^^^^^^^^^^^^^^^^^^^^

  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.11/site-packages/torch/fx/experimental/sym_node.py", line 414, in guard_bool
    r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.11/site-packages/torch/fx/experimental/recording.py", line 245, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^

  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py", line 5226, in evaluate_expr
    raise self._make_data_dependent_error(

torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression u0 < 0 (unhinted: u0 < 0).  (Size-like symbols: none)

Potential framework code culprit (scroll up for full backtrace):
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.11/site-packages/torch/_refs/__init__.py", line 2859, in constant_pad_nd
    if pad[pad_idx + 1] < 0:

For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

While executing %pad_default : [num_users=2] = call_function[target=torch.ops.aten.pad.default](args = (%input_ids_1, [0, %_local_scalar_dense_default], constant, 1.0), kwargs = {})
Original traceback:
None
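Some background on the error: `%_local_scalar_dense_default` is what a `.item()`-style read of a tensor lowers to, so export represents that value as the unbacked symbol u0 and knows nothing about its range. The `if pad[pad_idx + 1] < 0:` branch in `constant_pad_nd` therefore cannot be decided at trace time. The toy model below is a simplification for illustration, not PyTorch's actual ShapeEnv; it shows why the comparison fails and how an assertion that u0 >= 0 (which is what `torch._check_is_size` communicates) makes this particular comparison decidable.

```python
import math
from dataclasses import dataclass


@dataclass
class ValueRange:
    """Closed interval [lo, hi]; math.inf means unbounded. Toy model only."""
    lo: float
    hi: float


class DataDependentGuardError(Exception):
    """Toy stand-in for torch's GuardOnDataDependentSymNode."""


class UnbackedSymbol:
    """Toy stand-in for an unbacked SymInt such as u0 (e.g. produced by .item())."""

    def __init__(self, name: str):
        self.name = name
        # Nothing is known about a value read out of a tensor at trace time.
        self.range = ValueRange(-math.inf, math.inf)

    def check_is_size(self) -> None:
        # Loosely mirrors torch._check_is_size: the caller asserts value >= 0,
        # so the lower bound of the known range is raised to 0.
        self.range = ValueRange(max(self.range.lo, 0), self.range.hi)

    def lt(self, k: float) -> bool:
        """Decide `symbol < k` from the known range, or fail like a data-dependent guard."""
        if self.range.hi < k:
            return True
        if self.range.lo >= k:
            return False
        raise DataDependentGuardError(
            f"Could not guard on data-dependent expression {self.name} < {k}"
        )


u0 = UnbackedSymbol("u0")
try:
    u0.lt(0)  # range is (-inf, inf): cannot decide
except DataDependentGuardError as err:
    print(err)

u0.check_is_size()
print(u0.lt(0))  # range is now [0, inf), so "u0 < 0" is decidably False
```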

Versions

main

@jiashenC

cc @ezyang @anijain2305 @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4

ezyang commented 2 months ago

Try torch._check_is_size(pad_idx), or call it on whatever value u0 corresponds to.
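In this trace, u0 is most likely the right-padding amount for `input_ids`: the failing node pads `input_ids_1` by `[0, u0]` with value 1.0. Assuming the model still computes that amount the way upstream Hugging Face Longformer's `_pad_to_window_size` does (an assumption, since the reporter modified transformers and the model code is not shown), the amount is always nonnegative, so asserting it is a size is sound. A quick pure-Python check of that formula; `padding_len` here is a hypothetical helper written for illustration:

```python
def padding_len(seq_len: int, attention_window: int) -> int:
    """Padding needed to round seq_len up to a multiple of attention_window.

    Assumed to match the formula in upstream Longformer's _pad_to_window_size;
    this is a hypothetical helper, not the transformers function itself.
    """
    return (attention_window - seq_len % attention_window) % attention_window


# Python's % with a positive divisor never returns a negative number, so the
# result always lies in [0, attention_window). That nonnegativity is the
# property a torch._check_is_size call on this value would assert.
for window in (2, 256, 512):
    for seq_len in range(1, 3 * window):
        p = padding_len(seq_len, window)
        assert 0 <= p < window
        assert (seq_len + p) % window == 0
```

In the (hypothetical) user-code fix, the check would sit right before the pad call on `input_ids`, applied to the padding amount once it has become a Python int.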

@avikchaudhuri another reason to implement your stack walker lol

justinchuby commented 2 months ago

Is this change needed in torch/_refs or in user code? Sorry, I am not yet familiar with _check_is_size.

ezyang commented 2 months ago

User code. It's also possible that some framework code could use a check_is_size, but if pad_idx is only passed into the indexing operation, we can't do that there, because negative indices are supported.
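The same reasoning applies to the pad amount itself: the `if pad[pad_idx + 1] < 0:` branch in `constant_pad_nd` exists because a negative pad is legal and crops the tensor instead of extending it. Framework code therefore cannot assume the amount is nonnegative; only the caller knows that. A pure-Python sketch of the 1-D case (a simplified model of constant-pad semantics, not the `torch._refs` implementation; `constant_pad_1d` is a hypothetical name):

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def constant_pad_1d(xs: Sequence[T], before: int, after: int, value: T) -> List[T]:
    """Pad or crop a 1-D sequence following constant-pad semantics.

    A positive amount adds `value` on that side; a negative amount removes
    that many elements from that side instead. Hypothetical helper.
    """
    out = list(xs)
    # Negative amounts crop first.
    crop_front = max(-before, 0)
    crop_back = max(-after, 0)
    if crop_front + crop_back > len(out):
        raise ValueError("cannot crop more elements than the sequence has")
    out = out[crop_front : len(out) - crop_back]
    # Positive amounts then pad with the constant.
    return [value] * max(before, 0) + out + [value] * max(after, 0)


print(constant_pad_1d([1, 2, 3, 4], 0, 2, 0))   # [1, 2, 3, 4, 0, 0]
print(constant_pad_1d([1, 2, 3, 4], 0, -1, 0))  # [1, 2, 3]
print(constant_pad_1d([1, 2, 3, 4], -1, 1, 0))  # [2, 3, 4, 0]
```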