microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

[Documentation] Both new LLama-7B examples are now broken #19040

Closed: ricpruss closed this issue 9 months ago.

ricpruss commented 9 months ago

Describe the documentation issue

I am pretty sure these were working a few weeks ago, but now both of the documented ways of exporting Llama-7B are broken.

The README at https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/python/tools/transformers/models/llama gives:

# From source:
$ git clone https://github.com/microsoft/onnxruntime
$ cd onnxruntime/onnxruntime/python/tools/transformers/
$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b

It dies with an input-shape mismatch in the RotaryEmbedding node:

Traceback (most recent call last):
  File "/home/ubuntu/github/onnxruntime/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py", line 1007, in main
    parity_check(parity_cmd)
  File "/home/ubuntu/github/onnxruntime/onnxruntime/python/tools/transformers/models/llama/llama_parity.py", line 264, in main
    kv_cache_ortvalues = verify_parity(args, config, llama, kv_cache_ortvalues)
  File "/home/ubuntu/github/onnxruntime/onnxruntime/python/tools/transformers/models/llama/llama_parity.py", line 127, in verify_parity
    ort_outputs = ort_model.run(None, inputs)
  File "/home/ubuntu/.cache/pypoetry/virtualenvs/llm-benching-tTPj4svn-py3.10/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 220, in run
    return self._sess.run(output_names, input_feed, run_options)
onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Non-zero status code returned while running RotaryEmbedding node. Name:'RotaryEmbedding_0' Status Message: Input 'x' is expected to have 3 dimensions, got 4
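The error says the RotaryEmbedding node received a 4D input where it expected a 3D one. As a minimal sketch, assuming the 4D tensor uses the (batch, num_heads, sequence_length, head_size) layout and the 3D layout the kernel wants is (batch, sequence_length, hidden_size), the two are related by a transpose and a reshape. The helper name `bnsh_to_bsh` is hypothetical; this only illustrates the shape mismatch and is not the fix that landed in onnxruntime.

```python
import numpy as np


def bnsh_to_bsh(x: np.ndarray) -> np.ndarray:
    """Collapse a 4D (batch, num_heads, seq_len, head_size) tensor into the
    3D (batch, seq_len, num_heads * head_size) layout (assumed layouts)."""
    if x.ndim != 4:
        raise ValueError(f"expected 4 dimensions, got {x.ndim}")
    batch, num_heads, seq_len, head_size = x.shape
    # Move the heads axis next to head_size, then merge both into one hidden axis.
    return x.transpose(0, 2, 1, 3).reshape(batch, seq_len, num_heads * head_size)


# Llama-2-7B shapes: 32 heads of size 128 give a hidden size of 4096.
q = np.zeros((1, 32, 8, 128), dtype=np.float32)
print(bnsh_to_bsh(q).shape)  # (1, 8, 4096)
```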

Running the same conversion from the pip-installed onnxruntime 1.16.3 wheel instead:

# From wheel:
$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b

It dies with:

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/ubuntu/.cache/pypoetry/virtualenvs/llm-benching-tTPj4svn-py3.10/lib/python3.10/site-packages/onnxruntime/transformers/models/llama/convert_to_onnx.py", line 965, in <module>
    main()
  File "/home/ubuntu/.cache/pypoetry/virtualenvs/llm-benching-tTPj4svn-py3.10/lib/python3.10/site-packages/onnxruntime/transformers/models/llama/convert_to_onnx.py", line 802, in main
    run_torchscript_merged_export(args, l_config, llama, rank, world_size)
  File "/home/ubuntu/.cache/pypoetry/virtualenvs/llm-benching-tTPj4svn-py3.10/lib/python3.10/site-packages/onnxruntime/transformers/models/llama/convert_to_onnx.py", line 363, in run_torchscript_merged_export
    torch.onnx.export(
  File "/home/ubuntu/.cache/pypoetry/virtualenvs/llm-benching-tTPj4svn-py3.10/lib/python3.10/site-packages/torch/onnx/utils.py", line 516, in export
    _export(
  File "/home/ubuntu/.cache/pypoetry/virtualenvs/llm-benching-tTPj4svn-py3.10/lib/python3.10/site-packages/torch/onnx/utils.py", line 1613, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/home/ubuntu/.cache/pypoetry/virtualenvs/llm-benching-tTPj4svn-py3.10/lib/python3.10/site-packages/torch/onnx/utils.py", line 1139, in _model_to_graph
    graph = _optimize_graph(
  File "/home/ubuntu/.cache/pypoetry/virtualenvs/llm-benching-tTPj4svn-py3.10/lib/python3.10/site-packages/torch/onnx/utils.py", line 677, in _optimize_graph
    graph = _C._jit_pass_onnx(graph, operator_export_type)
  File "/home/ubuntu/.cache/pypoetry/virtualenvs/llm-benching-tTPj4svn-py3.10/lib/python3.10/site-packages/torch/onnx/utils.py", line 1967, in _run_symbolic_function
    raise errors.UnsupportedOperatorError(
torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::scaled_dot_product_attention' to ONNX opset version 13 is not supported. Support for this operator was added in version 14, try exporting with this version.

This is with the nightly PyTorch build for CPU and the latest onnxruntime release.
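The exporter's message states that `aten::scaled_dot_product_attention` needs opset 14 or later, so the direct fix is to pass `opset_version=14` or higher to `torch.onnx.export`. As a sketch of a more defensive approach, the wrapper below retries once with the version named in the error text quoted above. It assumes `export_fn` is a hypothetical callable that wraps `torch.onnx.export(..., opset_version=opset)`; it parses the message instead of relying on a specific exception type, and the `fake_export` stub only simulates that message.

```python
import re
from typing import Callable, TypeVar

T = TypeVar("T")

# Matches the exporter hint, e.g. "Support for this operator was added in version 14".
_ADDED_IN = re.compile(r"added in version (\d+)")


def export_with_opset_retry(export_fn: Callable[[int], T], opset: int) -> T:
    """Call export_fn(opset); if it fails with an 'added in version N' hint
    where N is higher than opset, retry once with opset N."""
    try:
        return export_fn(opset)
    except Exception as err:
        match = _ADDED_IN.search(str(err))
        if match is None or int(match.group(1)) <= opset:
            raise  # Unrelated failure or no higher opset suggested: re-raise as-is.
        return export_fn(int(match.group(1)))


def fake_export(opset: int) -> int:
    """Hypothetical stand-in for a torch.onnx.export call; returns the opset used."""
    if opset < 14:
        raise RuntimeError(
            "Exporting the operator 'aten::scaled_dot_product_attention' to ONNX "
            f"opset version {opset} is not supported. Support for this operator "
            "was added in version 14, try exporting with this version."
        )
    return opset


print(export_with_opset_retry(fake_export, 13))  # 14
```

In real use, `export_fn` would close over the model and its example inputs so that only the opset varies between attempts.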

In case it's a version issue, here is my full package list:

attrs                            23.2.0                Classes Without Boilerplate
certifi                          2023.11.17            Python package for providing Mozilla's CA Bundle.
charset-normalizer               3.3.2                 The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet.
coloredlogs                      15.0.1                Colored terminal output for Python's logging module
contextlib2                      21.6.0                Backports and enhancements for the contextlib module
contourpy                        1.2.0                 Python library for calculating contours of 2D quadrilateral grids
cycler                           0.12.1                Composable style cycles
deprecated                       1.2.14                Python @deprecated decorator to deprecate old python classes, functions or methods.
filelock                         3.13.1                A platform independent file lock.
flatbuffers                      23.5.26               The FlatBuffers serialization format for Python
fonttools                        4.47.0                Tools to manipulate font files
fsspec                           2023.12.2             File-system specification
huggingface-hub                  0.20.2                Client library to download and publish models, datasets and other repos on the huggingface.co hub
humanfriendly                    10.0                  Human friendly output for text interfaces using Python
idna                             3.6                   Internationalized Domain Names in Applications (IDNA)
intel-extension-for-transformers 1.3                   Repository of Intel® Intel Extension for Transformers
jinja2                           3.1.2                 A very fast and expressive template engine.
joblib                           1.3.2                 Lightweight pipelining with Python functions
kiwisolver                       1.4.5                 A fast implementation of the Cassowary constraint solver
markupsafe                       2.1.3                 Safely add untrusted strings to HTML/XML markup.
matplotlib                       3.8.2                 Python plotting package
more-itertools                   10.1.0                More routines for operating on iterables, beyond itertools
mpmath                           1.3.0                 Python library for arbitrary-precision floating-point arithmetic
networkx                         3.2.1                 Python package for creating and manipulating graphs and networks
neural-compressor                2.4.1                 Repository of Intel® Neural Compressor
numpy                            1.26.3                Fundamental package for array computing in Python
onnx                             1.15.0                Open Neural Network Exchange
onnxruntime                      1.16.3                ONNX Runtime is a runtime accelerator for Machine Learning models
opencv-python-headless           4.9.0.80              Wrapper package for OpenCV python bindings.
packaging                        23.2                  Core utilities for Python packages
pandas                           2.1.4                 Powerful data structures for data analysis, time series, and statistics
pillow                           10.2.0                Python Imaging Library (Fork)
pluggy                           0.13.1                plugin and hook calling mechanisms for python
prettytable                      3.9.0                 A simple Python library for easily displaying tabular data in a visually appealing ASCII table format
protobuf                         4.25.1                
psutil                           5.9.7                 Cross-platform lib for process and system monitoring in Python.
py                               1.11.0                library with cross-python path, ini-parsing, io, code, log facilities
py-cpuinfo                       9.0.0                 Get CPU info with pure Python
pycocotools                      2.0.7                 Official APIs for the MS-COCO dataset
pyparsing                        3.1.1                 pyparsing module - Classes and methods to define and execute parsing grammars
pytest                           5.4.3                 pytest: simple powerful testing with Python
python-dateutil                  2.8.2                 Extensions to the standard Python datetime module
pytz                             2023.3.post1          World timezone definitions, modern and historical
pyyaml                           6.0.1                 YAML parser and emitter for Python
regex                            2023.12.25            Alternative regular expression module, to replace re.
requests                         2.31.0                Python HTTP for Humans.
safetensors                      0.4.1                 
schema                           0.7.5                 Simple data validation library
scikit-learn                     1.3.2                 A set of python modules for machine learning and data mining
scipy                            1.11.4                Fundamental algorithms for scientific computing in Python
six                              1.16.0                Python 2 and 3 compatibility utilities
sympy                            1.12                  Computer algebra system (CAS) in Python
threadpoolctl                    3.2.0                 threadpoolctl
tokenizers                       0.15.0                
torch                            2.3.0.dev20240107+cpu Tensors and Dynamic neural networks in Python with strong GPU acceleration
tqdm                             4.66.1                Fast, Extensible Progress Meter
transformers                     4.36.2                State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow
typing-extensions                4.9.0                 Backported and Experimental Type Hints for Python 3.8+
tzdata                           2023.4                Provider of IANA time zone data
urllib3                          2.1.0                 HTTP library with thread-safe connection pooling, file post, and more.
wcwidth                          0.2.13                Measures the displayed width of unicode strings in a terminal
wrapt                            1.16.0                Module for decorators, wrappers and monkey patching.
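Only a few entries in this list are likely to matter for these failures: torch, transformers, onnx, and onnxruntime. A minimal stdlib-only sketch for collecting just those versions when reporting a similar issue; the package names are an assumption to adjust (for example, a GPU install is distributed as `onnxruntime-gpu`), and `collect_versions` is a hypothetical helper.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional

# Distribution names assumed relevant to the Llama export failures above.
RELEVANT = ("torch", "transformers", "onnx", "onnxruntime")


def collect_versions(packages: Iterable[str] = RELEVANT) -> Dict[str, Optional[str]]:
    """Return the installed version of each distribution, or None if it is missing."""
    found: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in collect_versions().items():
        print(f"{name:<14} {ver or 'not installed'}")
```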

Page / URL

https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/python/tools/transformers/models/llama

ricpruss commented 9 months ago

Okay, I found a workaround that gets me back to the setup I had when this last worked, and I managed to run the conversion successfully:

  1. Install the last torch 2.2.0 nightly:
     pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.2.0.dev20231213%2Bcpu-cp310-cp310-linux_x86_64.whl
  2. Check out the last good release branch:
     git checkout -t remotes/origin/rel-1.16.3
  3. Build the Python wheel and install it:
     ./build.sh --config Release --build_shared_lib --parallel --enable_pybind --skip_tests --build_wheel --update --build
     cd build/Linux/Release/dist/
     pip install ./onnxruntime-1.17.0-cp310-cp310-linux_x86_64.whl

Then run the conversion from the transformers directory, as in the original instructions.
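The workaround pins torch to the 2.2.0.dev20231213 nightly, while the failing environment had a 2.3.0 nightly. A stdlib-only sketch of a guard that flags when the installed torch release has moved past that pinned nightly, assuming you would pass `torch.__version__` as `installed`. `KNOWN_GOOD` and the helper names are hypothetical; the comparison looks only at major.minor.patch, so two nightlies of the same release compare equal, and it says nothing about which combinations actually work.

```python
import re
from typing import Tuple

# The torch nightly pinned in the workaround above.
KNOWN_GOOD = "2.2.0.dev20231213+cpu"


def release_tuple(ver: str) -> Tuple[int, int, int]:
    """Extract (major, minor, patch) from a version such as '2.3.0.dev20240107+cpu'."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", ver)
    if match is None:
        raise ValueError(f"unrecognized version string: {ver!r}")
    major, minor, patch = (int(part) for part in match.groups())
    return (major, minor, patch)


def newer_than_known_good(installed: str, known_good: str = KNOWN_GOOD) -> bool:
    """True if the installed release is newer than the pinned nightly's release."""
    return release_tuple(installed) > release_tuple(known_good)


print(newer_than_known_good("2.3.0.dev20240107+cpu"))  # True
print(newer_than_known_good("2.2.0.dev20231213+cpu"))  # False
```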

natke commented 9 months ago

Thank you for reporting this issue @ricpruss. The opset error is fixed in main and will be released with the next release. We are looking into the rotary embedding param mismatch, and will keep you posted.

kunal-vaishnavi commented 9 months ago

The rotary embedding error is fixed here and the reason for the opset error is explained here.