huggingface / optimum

🚀 Accelerate training and inference of 🤗 Transformers and 🤗 Diffusers with easy to use hardware optimization tools
https://huggingface.co/docs/optimum/main/
Apache License 2.0

Not being able to load LongT5 checkpoint with ORTModelForSeq2SeqLM #406

Open · caffeinetoomuch opened this issue 2 years ago

caffeinetoomuch commented 2 years ago

System Info

transformers==4.22.1
optimum==1.4.0
onnx==1.12.0
onnxruntime==1.12.1

Python 3.8.10

Who can help?

@lewtun, @michaelbenayoun

Reproduction

from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForSeq2SeqLM

model_name = "google/long-t5-tglobal-xl"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# This gives an exception!
model: ORTModelForSeq2SeqLM = ORTModelForSeq2SeqLM.from_pretrained(
    model_name, from_transformers=True, save_dir="./local_temp"
)

print("Loaded!")
tokenizer.save_pretrained("local_onnx")
model.save_pretrained("local_onnx")

The above code snippet raises the following exception:

Traceback (most recent call last):
  File "optimize_with_optimum.py", line 108, in <module>
    model: ORTModelForSeq2SeqLM = ORTModelForSeq2SeqLM.from_pretrained(model_name, from_transformers=True)
  File "/home/jinkoo/.local/lib/python3.8/site-packages/optimum/modeling_base.py", line 228, in from_pretrained
    return cls._from_transformers(
  File "/home/jinkoo/.local/lib/python3.8/site-packages/optimum/onnxruntime/modeling_seq2seq.py", line 441, in _from_transformers
    return cls._from_pretrained(save_dir, **kwargs)
  File "/home/jinkoo/.local/lib/python3.8/site-packages/optimum/onnxruntime/modeling_seq2seq.py", line 316, in _from_pretrained
    model = cls.load_model(
  File "/home/jinkoo/.local/lib/python3.8/site-packages/optimum/onnxruntime/modeling_seq2seq.py", line 213, in load_model
    decoder_session = onnxruntime.InferenceSession(str(decoder_path), providers=[provider])
  File "/home/jinkoo/.local/lib/python3.8/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 347, in __init__
    self._create_inference_session(providers, provider_options, disabled_optimizers)
  File "/home/jinkoo/.local/lib/python3.8/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 395, in _create_inference_session
    sess.initialize_session(providers, provider_options, disabled_optimizers)
  onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Deserialize tensor onnx::MatMul_5263 failed.tensorprotoutils.cc:640 TensorProtoToTensor External initializer: onnx::MatMul_5263 offset: 0 size to read: 41943040 given file_length: 16777216 are out of bounds or can not be read in full.

Expected behavior

The ONNX checkpoints for the encoder, decoder, and decoder with past should be generated and load successfully.

caffeinetoomuch commented 2 years ago

It seems this does not happen with google/long-t5-tglobal-large. I was also able to load the XL checkpoint as an ORTModelForSeq2SeqLM by exporting it to ONNX myself and passing from_transformers=False. My guess is that some of the decoder's external-data files get overwritten while the decoder with past is being exported, so in my own export I used separate folders for the decoder and the decoder with past; a rough sketch of the loading step is below.
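
A minimal sketch of that loading step, assuming the three graphs (and their external-data files) have already been exported manually and gathered into one local folder under the default optimum file names; the folder layout and the encoder_file_name / decoder_file_name / decoder_with_past_file_name arguments are illustrative assumptions, not an exact copy of my script:

from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForSeq2SeqLM

# Hypothetical layout of the manual export (external-data files omitted):
#   local_onnx/
#     config.json
#     encoder_model.onnx
#     decoder_model.onnx
#     decoder_with_past_model.onnx
export_dir = "./local_onnx"

tokenizer = AutoTokenizer.from_pretrained("google/long-t5-tglobal-xl")

# Load the pre-exported ONNX files directly instead of re-exporting them.
model = ORTModelForSeq2SeqLM.from_pretrained(
    export_dir,
    from_transformers=False,
    encoder_file_name="encoder_model.onnx",
    decoder_file_name="decoder_model.onnx",
    decoder_with_past_file_name="decoder_with_past_model.onnx",
)

print("Loaded without re-exporting.")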