huggingface / optimum

🚀 Accelerate training and inference of 🤗 Transformers and 🤗 Diffusers with easy to use hardware optimization tools
https://huggingface.co/docs/optimum/main/
Apache License 2.0
2.46k stars 436 forks

Exporting tinyllama-1.1b using onnxruntime bf16 crashes #1807

Open mgiessing opened 5 months ago

mgiessing commented 5 months ago

System Info

System information:

Container is Debian 12 (mambaorg/micromamba)

Host is RHEL 9 / ppc64le

$ cat /etc/os-release 
PRETTY_NAME="Debian GNU/Linux 12 (bookworm)"
NAME="Debian GNU/Linux"
VERSION_ID="12"
VERSION="12 (bookworm)"
VERSION_CODENAME=bookworm
ID=debian
HOME_URL="https://www.debian.org/"
SUPPORT_URL="https://www.debian.org/support"
BUG_REPORT_URL="https://bugs.debian.org/"

$ uname -a
Linux b8e04f1032bc 5.14.0-362.18.1.el9_3.ppc64le #1 SMP Mon Jan 29 03:48:20 PST 2024 ppc64le GNU/Linux

Python, Optimum & PyTorch version:

$ python3 -V
Python 3.10.9

$ pip3 list installed | grep optimum
optimum            1.18.1

$ micromamba list pytorch
List of packages in environment: "/opt/conda"

  Name             Version  Build                  Channel 
─────────────────────────────────────────────────────────────
  _pytorch_select  1.0      cpu_2                  rocketce
  pytorch-base     2.0.1    cpu_py310_pb4.21.12_1  rocketce
  pytorch-cpu      2.0.1    py310_1                rocketce

(base) root@b8e04f1032bc:/tmp# micromamba list onnx
List of packages in environment: "/opt/conda"

  Name         Version  Build                           Channel 
──────────────────────────────────────────────────────────────────
  onnx         1.13.1   h25d5be3_py310_pb4.21.12_1      rocketce
  onnxruntime  1.15.1   hd867603_cpu_py310_pb4.21.12_1  rocketce

Who can help?

@michaelbenayoun @JingyaHuang @echarlaix

Reproduction (minimal, reproducible, runnable)

Converting to fp32 works without issues. fp16 is not possible since I'm on a CPU-only system, and bf16 throws the following error:

$ optimum-cli export onnx --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 tinyllama-1b_onnx/ --dtype bf16

[...]

In-place op on output of tensor.shape. See https://pytorch.org/docs/master/onnx.html#avoid-inplace-operations-when-using-tensor-shape-in-tracing-mode
================ Diagnostic Run torch.onnx.export version 2.0.1 ================
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================

Saving external data to one file...
Traceback (most recent call last):
  File "/opt/conda/bin/optimum-cli", line 8, in <module>
    sys.exit(main())
  File "/opt/conda/lib/python3.10/site-packages/optimum/commands/optimum_cli.py", line 163, in main
    service.run()
  File "/opt/conda/lib/python3.10/site-packages/optimum/commands/export/onnx.py", line 261, in run
    main_export(
  File "/opt/conda/lib/python3.10/site-packages/optimum/exporters/onnx/__main__.py", line 351, in main_export
    onnx_export_from_model(
  File "/opt/conda/lib/python3.10/site-packages/optimum/exporters/onnx/convert.py", line 1157, in onnx_export_from_model
    _, onnx_outputs = export_models(
  File "/opt/conda/lib/python3.10/site-packages/optimum/exporters/onnx/convert.py", line 768, in export_models
    export(
  File "/opt/conda/lib/python3.10/site-packages/optimum/exporters/onnx/convert.py", line 902, in export
    config.fix_dynamic_axes(output, device=device, input_shapes=input_shapes, dtype=dtype)
  File "/opt/conda/lib/python3.10/site-packages/optimum/exporters/onnx/base.py", line 306, in fix_dynamic_axes
    session = InferenceSession(model_path.as_posix(), providers=providers, sess_options=session_options)
  File "/opt/conda/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 383, in __init__
    self._create_inference_session(providers, provider_options, disabled_optimizers)
  File "/opt/conda/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 424, in _create_inference_session
    sess = C.InferenceSession(session_options, self._model_path, True, self._read_config_from_model)
onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph: [ONNXRuntimeError] : 10 : INVALID_GRAPH : Load model from tinyllama-1b_onnx/model.onnx failed:This is an invalid model. Type Error: Type 'tensor(bfloat16)' of input parameter (/model/Constant_34_output_0) of operator (Where) in node (/model/Where_3) is invalid.

Expected behavior

The model is converted to bf16 without errors.

fxmarty commented 5 months ago

Thank you @mgiessing! It is possible that the ONNX model is valid but ORT is missing some operators for bf16. It could also be a bug; I will have a look shortly.

mgiessing commented 5 months ago

Thank you for having a look! This also happened on my Mac M1 with a more recent ORT version (v1.17.1), and with a different model as well (deepset/roberta-base-squad2).

fxmarty commented 5 months ago

@mgiessing The Where operator (used in https://github.com/huggingface/transformers/blob/caa5c65db1f4db617cdac2ad667ba62edf94dd98/src/transformers/models/llama/modeling_llama.py#L1086) is not implemented for the BF16 dtype in ORT; see https://github.com/microsoft/onnxruntime/blob/v1.17.1/docs/OperatorKernels.md

However, it is valid in the ONNX standard: https://github.com/onnx/onnx/blob/main/docs/Operators.md#where

I suggest opening a feature request in the ONNX Runtime repository to add this support. In the meantime, we could patch the Transformers code so that it works in BF16 (avoiding the Where op in bf16).
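A Transformers-side workaround that avoids a bf16 Where could look roughly like the minimal sketch below. where_via_fp32 is a hypothetical helper, not a Transformers or Optimum function. It upcasts both branches to float32 around torch.where and casts the result back, so a traced export would contain a float32 Where surrounded by Cast nodes instead of a bfloat16 Where.

```python
import torch


def where_via_fp32(cond: torch.Tensor, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: torch.where that never selects in bfloat16.

    For bfloat16 x and y, upcasting to float32 is exact, so the values match
    torch.where(cond, x, y); only the traced graph changes
    (Cast -> Where(float) -> Cast instead of Where(bfloat16)).
    """
    out = torch.where(cond, x.to(torch.float32), y.to(torch.float32))
    return out.to(x.dtype)
```

In a patched modeling file, such a helper would replace the torch.where call linked above. This sketch does not confirm that this alone is enough for a full TinyLlama bf16 export, since other operators in the graph may also lack bf16 kernels in ORT.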

See also https://github.com/huggingface/optimum/issues/1720#issuecomment-1963838333, which is related and which you are likely to hit as well.

If you are using Optimum installed from source, the following warning is displayed:

Exporting the model LlamaForCausalLM in bfloat16 float dtype. After the export, ONNX Runtime InferenceSession with CPU/CUDA execution provider likely does not implement all operators for the bfloat16 data type, and the loading is likely to fail.

mgiessing commented 5 months ago

Thanks for having a look at that :) I'll try to open a request next week to address that issue!