huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

[wav2vec2] torch.export error in Wav2Vec2ForCTC #34022

Open chrsmcgrr opened 3 days ago

chrsmcgrr commented 3 days ago

System Info

Who can help?

@ylacombe, @eustlb

Information

Tasks

Reproduction

Running the following script:

from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2ForCTC
import torch

MODEL_ID = "jonatasgrosman/wav2vec2-large-xlsr-53-english"
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
model.eval()

# Dummy raw-audio input values and attention mask.
example_inputs = torch.rand(1, 76802)
example_mask = torch.ones(1, 75000)

exported_program = torch.export.export(model, (example_inputs, example_mask))

print(exported_program.graph)

Causes the following error:

Traceback (most recent call last):
  File "/home/user/transformers/reproducer.py", line 11, in <module>
    exported_program = torch.export.export(model, (example_inputs,example_mask,))
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/conda/envs/ame/lib/python3.11/site-packages/torch/export/__init__.py", line 174, in export
    return _export(
           ^^^^^^^^
  File "/home/user/conda/envs/ame/lib/python3.11/site-packages/torch/export/_trace.py", line 945, in wrapper
    raise e
  File "/home/user/conda/envs/ame/lib/python3.11/site-packages/torch/export/_trace.py", line 928, in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
  File "/home/user/conda/envs/ame/lib/python3.11/site-packages/torch/export/exported_program.py", line 89, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/user/conda/envs/ame/lib/python3.11/site-packages/torch/export/_trace.py", line 1455, in _export
    aten_export_artifact = export_func(
                           ^^^^^^^^^^^^
  File "/home/user/conda/envs/ame/lib/python3.11/site-packages/torch/export/_trace.py", line 1158, in _strict_export
    aten_export_artifact = _export_to_aten_ir(
                           ^^^^^^^^^^^^^^^^^^^
  File "/home/user/conda/envs/ame/lib/python3.11/site-packages/torch/export/_trace.py", line 583, in _export_to_aten_ir
    gm, graph_signature = transform(aot_export_module)(
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/conda/envs/ame/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1131, in aot_export_module
    fx_g, metadata, in_spec, out_spec = _aot_export_function(
                                        ^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/conda/envs/ame/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1350, in _aot_export_function
    fx_g, meta = create_aot_dispatcher_function(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/conda/envs/ame/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/conda/envs/ame/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 687, in create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
                               ^^^^^^^^^^^^
  File "/home/user/conda/envs/ame/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 95, in aot_dispatch_export
    graph, _, _ = aot_dispatch_base_graph(
                  ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/conda/envs/ame/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py", line 138, in aot_dispatch_base_graph
    fw_module = _create_graph(
                ^^^^^^^^^^^^^^
  File "/home/user/conda/envs/ame/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py", line 46, in _create_graph
    fx_g = make_fx(
           ^^^^^^^^
  File "/home/user/conda/envs/ame/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 1421, in wrapped
    return make_fx_tracer.trace(f, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/conda/envs/ame/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 1367, in trace
    return self._trace_inner(f, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/conda/envs/ame/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 1354, in _trace_inner
    t = dispatch_trace(
        ^^^^^^^^^^^^^^^
  File "/home/user/conda/envs/ame/lib/python3.11/site-packages/torch/_compile.py", line 31, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/conda/envs/ame/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/user/conda/envs/ame/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 642, in dispatch_trace
    graph = tracer.trace(root, concrete_args)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/conda/envs/ame/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 1019, in trace
    res = super().trace(root, concrete_args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/conda/envs/ame/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/user/conda/envs/ame/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 822, in trace
    (self.create_arg(fn(*args)),),
                     ^^^^^^^^^
  File "/home/user/conda/envs/ame/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 660, in wrapped
    out = f(*tensors)
          ^^^^^^^^^^^
  File "<string>", line 1, in <lambda>
  File "/home/user/conda/envs/ame/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 388, in _functionalized_f_helper
    f_outs = fn(*f_args)
             ^^^^^^^^^^^
  File "/home/user/conda/envs/ame/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 72, in inner_fn
    outs = fn(*args)
           ^^^^^^^^^
  File "/home/user/conda/envs/ame/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/utils.py", line 178, in flat_fn
    tree_out = fn(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^
  File "/home/user/conda/envs/ame/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 744, in functional_call
    out = PropagateUnbackedSymInts(mod).run(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/conda/envs/ame/lib/python3.11/site-packages/torch/fx/interpreter.py", line 146, in run
    self.env[node] = self.run_node(node)
                     ^^^^^^^^^^^^^^^^^^^
  File "/home/user/conda/envs/ame/lib/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py", line 5461, in run_node
    result = super().run_node(n)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/user/conda/envs/ame/lib/python3.11/site-packages/torch/fx/interpreter.py", line 203, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/conda/envs/ame/lib/python3.11/site-packages/torch/fx/interpreter.py", line 275, in call_function
    return target(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/conda/envs/ame/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 705, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/conda/envs/ame/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 728, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/conda/envs/ame/lib/python3.11/site-packages/torch/_ops.py", line 781, in handler
    return torch._library.utils.handle_dispatch_mode(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/conda/envs/ame/lib/python3.11/site-packages/torch/_library/utils.py", line 244, in handle_dispatch_mode
    return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/conda/envs/ame/lib/python3.11/site-packages/torch/_subclasses/functional_tensor.py", line 468, in __torch_dispatch__
    outs_unwrapped = func._op_dk(
                     ^^^^^^^^^^^^
RuntimeError: cannot mutate tensors with frozen storage

While executing %setitem_1 : [num_users=0] = call_function[target=operator.setitem](args = (%hidden_states_37, %invert, 0), kwargs = {})
Original traceback:
  File "/home/user/transformers/src/transformers/models/wav2vec2/modeling_wav2vec2.py", line 2229, in forward
    outputs = self.wav2vec2(
  File "/home/user/conda/envs/ame/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/transformers/src/transformers/models/wav2vec2/modeling_wav2vec2.py", line 1824, in forward
    encoder_outputs = self.encoder(
  File "/home/user/conda/envs/ame/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/transformers/src/transformers/models/wav2vec2/modeling_wav2vec2.py", line 1111, in forward
    hidden_states[~expand_attention_mask] = 0

Expected behavior

torch.export completes without raising an exception

chrsmcgrr commented 3 days ago

I already have a pragmatic solution and would like to get feedback on it.

I would replace the following line in the model:

hidden_states[~expand_attention_mask] = 0

with

hidden_states = hidden_states * expand_attention_mask.to(hidden_states.dtype)

This avoids the issue in PyTorch. The real fix likely belongs in PyTorch itself, but I have not yet created a small standalone reproducer for it.
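
As a quick sanity check (not from the issue, shapes are made up purely for illustration), the out-of-place multiplication matches the original in-place masked assignment as long as expand_attention_mask is a boolean mask with the same shape as hidden_states:

import torch

# Hypothetical (batch, time, hidden) shapes, for illustration only.
hidden_states = torch.randn(2, 10, 8)
expand_attention_mask = torch.randint(0, 2, (2, 10, 8), dtype=torch.bool)

# Original in-place form, which torch.export rejects here.
masked = hidden_states.clone()
masked[~expand_attention_mask] = 0

# Proposed out-of-place form.
multiplied = hidden_states * expand_attention_mask.to(hidden_states.dtype)

assert torch.allclose(masked, multiplied)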

For now this change will unblock the model. I will open a PR shortly.