huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

torch.onnx.export failure for llava model #33637

Open symphonylyh opened 7 hours ago

symphonylyh commented 7 hours ago

System Info

transformers == 4.42.4: export works
transformers >= 4.43.0: export fails on every version tried

Who can help?

No response

Reproduction

Steps to reproduce:

import os

import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

MODEL_ID = "llava-hf/llava-1.5-7b-hf"


def export_onnx(model,
                input,
                onnx_dir,
                onnx_name='model.onnx',
                input_names=['input'],
                output_names=['output'],
                dynamic_axes={'input': {
                    0: 'batch'
                }}):
    # Trace the module and write it to <onnx_dir>/<onnx_name>.
    os.makedirs(onnx_dir, exist_ok=True)
    torch.onnx.export(model,
                      input,
                      f'{onnx_dir}/{onnx_name}',
                      opset_version=17,
                      input_names=input_names,
                      output_names=output_names,
                      dynamic_axes=dynamic_axes)


class LlavaVisionWrapper(torch.nn.Module):
    """Vision tower plus projector, exported as a single module."""

    def __init__(self, tower, projector, feature_layer):
        super().__init__()
        self.tower = tower
        self.projector = projector
        self.feature_layer = feature_layer

    def forward(self, image):
        all_hidden_states = self.tower(
            image, output_hidden_states=True).hidden_states
        # Take the configured hidden layer and drop the first (CLS) token.
        features = all_hidden_states[self.feature_layer][:, 1:]
        return self.projector(features)


model = LlavaForConditionalGeneration.from_pretrained(MODEL_ID)
# `processor` was not defined in the original snippet; loading it from the
# same checkpoint via AutoProcessor is an assumption.
processor = AutoProcessor.from_pretrained(MODEL_ID)
wrapper = LlavaVisionWrapper(
    model.vision_tower,
    model.multi_modal_projector,
    model.config.vision_feature_layer)

raw_image = Image.new('RGB', (10, 10))  # dummy image
image = processor(text="dummy", images=raw_image,
                  return_tensors="pt")['pixel_values']

export_onnx(wrapper, image, 'tmp/onnx')

This leads to the following error:

line 116, in export_onnx
    torch.onnx.export(model,
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py", line 511, in export
    _export(
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py", line 1607, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py", line 1133, in _model_to_graph
    graph = _optimize_graph(
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py", line 672, in _optimize_graph
    graph = _C._jit_pass_onnx(graph, operator_export_type)
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py", line 1956, in _run_symbolic_function
    return symbolic_fn(graph_context, *inputs, **attrs)
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/symbolic_helper.py", line 291, in wrapper
    return fn(g, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/symbolic_opset14.py", line 176, in scaled_dot_product_attention
    query_scaled = g.op("Mul", query, g.op("Sqrt", scale))
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/jit_utils.py", line 92, in op
    return _add_op(self, opname, *raw_args, outputs=outputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/jit_utils.py", line 243, in _add_op
    inputs = [_const_if_tensor(graph_context, arg) for arg in args]
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/jit_utils.py", line 243, in <listcomp>
    inputs = [_const_if_tensor(graph_context, arg) for arg in args]
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/jit_utils.py", line 275, in _const_if_tensor
    return _add_op(graph_context, "onnx::Constant", value_z=arg)
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/jit_utils.py", line 251, in _add_op
    node = _create_node(
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/jit_utils.py", line 311, in _create_node
    _add_attribute(node, key, value, aten=aten)
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/jit_utils.py", line 362, in _add_attribute
    return getattr(node, f"{kind}_")(name, value)
TypeError: z_(): incompatible function arguments. The following argument types are supported:
    1. (self: torch._C.Node, arg0: str, arg1: torch.Tensor) -> torch._C.Node

This only occurs with transformers >= 4.43.0.
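
For anyone who wants to guard an export script against the affected releases, here is a minimal sketch of a version check. It assumes only what is reported above (4.42.4 works, 4.43.0 and later fail). The helper names `parse_version`, `is_affected` and `installed_transformers_is_affected` are made up for illustration and are not part of transformers.

```python
import re
from importlib import metadata

# Boundary taken from the report above: 4.42.4 exports fine, 4.43.0 and later fail.
FIRST_FAILING = (4, 43, 0)


def parse_version(text):
    """Return the leading numeric release of a version string as a 3-tuple of ints.

    Pre-release or dev suffixes (e.g. ".dev0") are ignored.
    """
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", text)
    if match is None:
        raise ValueError(f"unrecognised version string: {text!r}")
    major, minor, patch = match.groups()
    return (int(major), int(minor), int(patch or 0))


def is_affected(version_text):
    """True if the version falls in the range reported as failing."""
    return parse_version(version_text) >= FIRST_FAILING


def installed_transformers_is_affected():
    """Check the installed transformers package; None if it is not installed."""
    try:
        return is_affected(metadata.version("transformers"))
    except metadata.PackageNotFoundError:
        return None
```

Calling `installed_transformers_is_affected()` before `export_onnx(...)` lets a script warn early or fall back to a pinned environment. The upper end of the range is left open because the report does not say whether a later release fixes the export.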

Expected behavior

ONNX export should succeed, as it does with transformers 4.42.4.

LysandreJik commented 5 hours ago

Thanks @symphonylyh! Pinging @xenova for a quick answer when he can