KarelZe opened this issue 2 months ago
We don't have encoder-decoder support yet. Let me take a deep dive into your model and see how we can support it best.
@yufenglee Thanks for your update. Would be really cool to see this happen🤗
I'll provide a build script for the ONNX conversion as a head start. Please let me know if I can help with the implementation or tests.
@yufenglee Here's the build.py script I used for the conversion. Please let me know if you have any questions. 👍
"""Build script for DONUT transformer.
Partly adapted from:
https://huggingface.co/microsoft/Phi-3-vision-128k-instruct-onnx-cuda/blob/main/onnx/builder.py
"""
import logging
import shutil
import subprocess
import sys
from pathlib import Path
import numpy as np
import onnx
import onnxruntime as ort
from onnxruntime_extensions import gen_processing_models, get_library_path
from onnxruntime_extensions.tools.pre_post_processing import (
    ChannelsLastToChannelsFirst,
    ImageBytesToFloat,
    LetterBox,
    Normalize,
    PrePostProcessor,
    Resize,
    Unsqueeze,
    create_named_value,
)
from PIL import Image
from transformers import AutoConfig, DonutProcessor, VisionEncoderDecoderModel
logger = logging.getLogger(__name__)
output_dir = Path("output/")
cache_dir = Path("cache_dir/")
path_model = Path("/path/to/model/")
path_test_img = Path("/path/to/test/img.png")

precision = "fp16"

# The PrePostProcessor pipeline in onnxruntime-extensions requires at least opset 16; 18 is recommended.
opset = 18

def export_model():
    """Export encoder + decoder.

    See:
    https://huggingface.co/docs/optimum/exporters/onnx/usage_guides/export_a_model
    """
    subprocess.run([
        "optimum-cli",
        "export",
        "onnx",
        "--model",
        path_model,
        output_dir / "model_init_export",
        "--task",
        "image-to-text-with-past",
        "--framework",
        "pt",
        "--opset",
        str(opset),
    ])

def optimize_encoder():
    """Optimize encoder."""
    filename = "encoder_model.onnx"

    temp_folder_1 = output_dir / "model_init_export"
    fpath_1 = temp_folder_1 / filename
    onnx.checker.check_model(fpath_1)
    onnx.shape_inference.infer_shapes_path(fpath_1)
    onnx_model = onnx.load_model(fpath_1, load_external_data=True)

    temp_folder_2 = output_dir / "encoder_after_export"
    temp_folder_2.mkdir(exist_ok=True)
    fpath_2 = temp_folder_2 / filename
    onnx.save_model(
        onnx_model,
        fpath_2,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=f"{filename}.data",
        size_threshold=0,
        convert_attribute=False,
    )

    temp_folder_3 = output_dir / "encoder_after_opt"
    temp_folder_3.mkdir(exist_ok=True)
    fpath_3 = temp_folder_3 / filename
    subprocess.run([
        f"{sys.executable}",
        "-m",
        "onnxruntime.transformers.optimizer",
        "--input",
        fpath_2,
        "--output",
        fpath_3,
        "--model_type",
        "swin",
        "--num_heads",
        str(0),  # config lists 4, 8, 16, 32; 0 = auto-detect from the graph
        "--hidden_size",
        str(0),  # 0 = auto-detect
        "--use_external_data_format",
        "--opt_level",
        str(0),
    ])
    shutil.rmtree(temp_folder_2)

    fpath_4 = output_dir / filename
    cmd = [
        f"{sys.executable}",
        "-m",
        "onnxruntime.quantization.matmul_4bits_quantizer",
        "--input_model",
        fpath_3,
        "--output_model",
        fpath_4,
        "--block_size",
        str(32),
    ]
    if precision == "fp32":
        cmd.extend(["--accuracy_level", str(4)])
    subprocess.run(cmd)
    shutil.rmtree(temp_folder_3)

def optimize_decoder():
    """Optimize decoder.

    Adapted from:
    https://huggingface.co/microsoft/Phi-3-vision-128k-instruct-onnx-cuda/blob/main/onnx/builder.py
    """
    filename = "decoder_model_merged.onnx"

    temp_folder_1 = output_dir / "model_init_export"
    fpath_1 = temp_folder_1 / filename
    onnx.checker.check_model(fpath_1)
    onnx.shape_inference.infer_shapes_path(fpath_1)
    onnx_model = onnx.load_model(fpath_1, load_external_data=True)

    temp_folder_2 = output_dir / "decoder_after_export"
    temp_folder_2.mkdir(exist_ok=True)
    fpath_2 = temp_folder_2 / filename
    onnx.save_model(
        onnx_model,
        fpath_2,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=f"{filename}.data",
        size_threshold=0,
        convert_attribute=False,
    )

    temp_folder_3 = output_dir / "decoder_after_opt"
    temp_folder_3.mkdir(exist_ok=True)
    fpath_3 = temp_folder_3 / filename
    subprocess.run([
        f"{sys.executable}",
        "-m",
        "onnxruntime.transformers.optimizer",
        "--input",
        fpath_2,
        "--output",
        fpath_3,
        "--model_type",
        "bart",
        "--num_heads",
        str(config.decoder.decoder_attention_heads),
        "--hidden_size",
        str(config.decoder.d_model),
        "--use_external_data_format",
        "--opt_level",
        str(0),
    ])
    shutil.rmtree(temp_folder_2)

    fpath_4 = output_dir / filename
    cmd = [
        f"{sys.executable}",
        "-m",
        "onnxruntime.quantization.matmul_4bits_quantizer",
        "--input_model",
        fpath_3,
        "--output_model",
        fpath_4,
        "--block_size",
        str(32),
    ]
    if precision == "fp32":
        cmd.extend(["--accuracy_level", str(4)])
    subprocess.run(cmd)
    shutil.rmtree(temp_folder_3)

def build_tokenizer() -> None:
    """Get the SentencePiece tokenizer from the DONUT processor."""
    onnx_tokenizer = gen_processing_models(tokenizer, opset=opset, pre_kwargs={})[0]
    fpath_tokenizer = output_dir / "tokenizer.onnx"
    with fpath_tokenizer.open(mode="wb") as f:
        f.write(onnx_tokenizer.SerializeToString())

def build_img_preprocessor() -> None:
    """Build image preprocessor.

    Adapted from:
    https://github.com/microsoft/onnxruntime-extensions/pull/478/files#diff-8f875d92e23f555946efe7bec0ccdefde80c06a4b74b595071961ec4e0f84f5d

    For the operations and their order see:
    https://github.com/huggingface/transformers/blob/v4.42.0/src/transformers/models/donut/image_processing_donut.py#L54

    Notes
    -----
    'do_thumbnail' and 'do_align_long_axis' are not yet implemented; a warning is logged.
    Resampling methods other than 2 (BILINEAR) fall back to BILINEAR with a warning.
    """
    pixel_values_in = [
        create_named_value("pixel_values", onnx.TensorProto.UINT8, ["h", "w", 3])
    ]
    pipeline = PrePostProcessor(pixel_values_in, onnx_opset=opset)

    steps = []
    size = (
        image_processor_config["size"]["height"],
        image_processor_config["size"]["width"],
    )
    if image_processor_config["do_align_long_axis"]:
        logger.warning(
            "'do_align_long_axis' is not yet implemented. This will lead to a "
            "performance degradation for rotated images."
        )
    if image_processor_config["do_resize"]:
        # 2 = BILINEAR
        # see: https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Resampling.BICUBIC
        if image_processor_config["resample"] != 2:
            logger.warning(
                "resampling method '%s' not supported. Resampling with 2=BILINEAR.",
                image_processor_config["resample"],
            )
        steps.append(
            Resize(
                size,
                layout="HWC",
                name="do_resize",
                policy="not_larger",
            )
        )
    if image_processor_config["do_thumbnail"]:
        logger.warning("'do_thumbnail' is not yet implemented.")
    if image_processor_config["do_pad"]:
        steps.append(
            LetterBox(target_shape=size, fill_value=0, name="do_pad", layout="HWC")
        )
    if image_processor_config["do_rescale"]:
        steps.append(
            ImageBytesToFloat(
                image_processor_config["rescale_factor"], name="do_rescale"
            )
        )
    if image_processor_config["do_normalize"]:
        mean_std = list(
            zip(
                image_processor_config["image_mean"],
                image_processor_config["image_std"],
            )
        )
        steps.append(Normalize(mean_std, layout="HWC", name="do_normalize"))
    steps.extend([
        ChannelsLastToChannelsFirst(name="RGBImageCHW"),  # HWC to CHW
        Unsqueeze([0], name="unsqueeze"),  # add batch dimension, CHW -> 1CHW
    ])
    pipeline.add_pre_processing(steps)

    pixel_values_out = [
        onnx.helper.make_tensor_value_info(
            "pixel_values_out", onnx.TensorProto.FLOAT, [1, 3, *size]
        )
    ]
    # Minimal Identity-only graph; the pipeline prepends the preprocessing steps to it.
    g = onnx.helper.make_graph(
        [onnx.helper.make_node("Identity", ["pixel_values"], ["pixel_values_out"])],
        "empty",
        pixel_values_in,
        pixel_values_out,
    )
    onnx_import = onnx.helper.make_operatorsetid("", opset)
    ir_version = onnx.helper.find_min_ir_version_for([onnx_import])
    model = onnx.helper.make_model_gen_version(
        g, opset_imports=[onnx_import], ir_version=ir_version
    )
    new_model = pipeline.run(model)
    new_model.doc_string = "Donut-like image pre-processor."
    new_model.graph.doc_string = ""

    temp_folder_1 = output_dir / "img_processor_after_export"
    temp_folder_1.mkdir(exist_ok=True)
    filename = "preprocessor.onnx"
    fpath_1 = temp_folder_1 / filename
    onnx.save_model(new_model, fpath_1)
    onnx.checker.check_model(fpath_1)
    onnx.shape_inference.infer_shapes_path(fpath_1)
    subprocess.run([
        f"{sys.executable}",
        "-m",
        "onnxoptimizer",
        fpath_1,
        output_dir / filename,
    ])

def test_components():
    """Test tokenizer, image processor, and fused image processor with encoder."""
    input_text = "<s_name>YOUR_NAME</s_name>"

    sess_options = ort.SessionOptions()
    sess_options.register_custom_ops_library(get_library_path())

    session = ort.InferenceSession(
        output_dir / "tokenizer.onnx",
        sess_options=sess_options,
        providers=["CPUExecutionProvider"],
    )
    input_feed = {"inputs": np.asarray([input_text])}
    outputs = session.run(["token_indices"], input_feed)
    print("token_ids", outputs[0])
    print(tokenizer(input_text))

    session = ort.InferenceSession(
        output_dir / "preprocessor.onnx",
        sess_options,
        providers=["CPUExecutionProvider"],
    )
    outputs = session.run(
        ["pixel_values_out"], {"pixel_values": np.array(Image.open(path_test_img))}
    )
    # remove batch dim, CHW -> HWC
    # [-1, 1] -> [0, 1] * 255 for visualization
    img = np.squeeze(outputs[0]).transpose((1, 2, 0))
    img = np.uint8((img - img.min()) / (img.max() - img.min()) * 255)
    Image.fromarray(img).save(
        output_dir / "test_img_from_onnx_processor.png", format="PNG"
    )

    session = ort.InferenceSession(
        output_dir / "encoder_with_img_processor.onnx",
        sess_options,
        providers=["CPUExecutionProvider"],
    )
    outputs = session.run(
        ["last_hidden_state"], {"pixel_values": np.array(Image.open(path_test_img))}
    )
    print("last_hidden_state", outputs[0])
    print("last_hidden_state (shape)", outputs[0].shape)

def merge_img_processor_encoder():
    """Merge image processor and encoder into a single graph."""
    model_preprocessor = onnx.load(output_dir / "preprocessor.onnx")
    model_encoder = onnx.load(output_dir / "encoder_model.onnx")

    output_name_from_img_processor = model_preprocessor.graph.output[0].name
    input_name_of_encoder = model_encoder.graph.input[0].name
    merged_model = onnx.compose.merge_models(
        model_preprocessor,
        model_encoder,
        io_map=[(output_name_from_img_processor, input_name_of_encoder)],
    )

    filename = "encoder_with_img_processor.onnx"
    onnx.save(
        merged_model,
        output_dir / filename,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=f"{filename}.data",
        size_threshold=0,
        convert_attribute=False,
    )

if __name__ == "__main__":
    config = AutoConfig.from_pretrained(path_model)
    processor = DonutProcessor.from_pretrained(path_model)
    image_processor = processor.image_processor
    image_processor_config = image_processor.to_dict()
    tokenizer = processor.tokenizer
    model = VisionEncoderDecoderModel.from_pretrained(path_model)

    export_model()
    build_tokenizer()
    optimize_encoder()
    optimize_decoder()
    build_img_preprocessor()
    merge_img_processor_encoder()
    test_components()
@yufenglee Is there any update? Could you give me any guidance on how I could help make this feature happen?
Would it not be possible in your case to load the input beforehand by using something like this:
onnxruntime_genai.GeneratorParams.set_model_input(name: str, value: [])
and not include the actual input name in the generation config?
I had a custom model of my own that loaded some of the inputs before starting the generation. However, if you have a model that needs to continuously load and update the inputs, this might not be supported yet. Perhaps consider splitting the model into more parts.
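Roughly, I mean something like this. This is an untested sketch: I'm assuming set_model_input accepts a NumPy array and that your decoder's extra input is literally named encoder_hidden_states.

import numpy as np
import onnxruntime as ort
import onnxruntime_genai as og
from onnxruntime_extensions import get_library_path
from PIL import Image

# Run the merged image-processor + encoder once with plain onnxruntime
# (the custom ops from onnxruntime-extensions are needed for the preprocessing nodes).
sess_options = ort.SessionOptions()
sess_options.register_custom_ops_library(get_library_path())
encoder = ort.InferenceSession(
    "output/encoder_with_img_processor.onnx",
    sess_options,
    providers=["CPUExecutionProvider"],
)
hidden_states = encoder.run(
    ["last_hidden_state"], {"pixel_values": np.array(Image.open("img.png"))}
)[0]

# Hand the cached encoder output to the decoder as a fixed model input, so it does
# not have to be declared among the generated inputs in genai_config.json.
model = og.Model("output/")  # folder containing genai_config.json and the decoder
params = og.GeneratorParams(model)
params.set_model_input("encoder_hidden_states", hidden_states.astype(np.float32))
generator = og.Generator(model, params)
# ...then run the usual token-generation loop with the generator.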
background
My question is about executing encoder-decoder models with the onnxruntime-genai runtime. My goal is to convert the DONUT transformer (https://arxiv.org/abs/2111.15664), a sequence-to-sequence transformer for document understanding with a Swin encoder and an mBART decoder, to ONNX and run it using onnxruntime-genai.
I managed to convert the individual components to ONNX. Now I'm stuck at writing a genai_config.json suitable for encoder-decoder models.

steps
I started with the Hugging Face implementation of DONUT (https://huggingface.co/docs/transformers/v4.42.0/en/model_doc/donut#overview) and converted the encoder and the merged decoder (with KV cache) to ONNX using Optimum (https://huggingface.co/docs/optimum/index). I converted the DONUT processor, which consists of an image processor and a SentencePiece tokenizer, to ONNX using onnxruntime-extensions (https://github.com/microsoft/onnxruntime-extensions/blob/main/onnxruntime_extensions/tools/). I then merged the image processor and the Swin encoder into a single graph. I can provide the conversion scripts if needed.
My components have the following input/output shapes (see the snippet below for how to read them from the exported graphs):
tokenizer:
encoder with image processor:
decoder: (more past keys + values)
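The names and shapes can be dumped directly from the exported graphs, for example with a small helper like this (paths assume the output/ layout produced by the build script above):

import onnx

for name in [
    "tokenizer.onnx",
    "encoder_with_img_processor.onnx",
    "decoder_model_merged.onnx",
]:
    m = onnx.load(f"output/{name}", load_external_data=False)
    for kind, values in [("input", m.graph.input), ("output", m.graph.output)]:
        for v in values:
            # symbolic dims show up as names (e.g. "batch_size"), static dims as ints
            dims = [d.dim_param or d.dim_value for d in v.type.tensor_type.shape.dim]
            print(name, kind, v.name, dims)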
question
However, I am stuck at manually writing/loading a suitable genai_config.json, one that feeds the encoder's hidden states into the decoder's cross-attention. I'm aware of https://onnxruntime.ai/docs/genai/reference/config.html for writing configs, but it seemingly focuses on decoder-only models. I'm also aware of make_genai_config in https://github.com/microsoft/onnxruntime-genai/blob/c7eba3c63a454edd6662eb007ff397d1146cc081/src/python/py/models/builder.py for auto-generating configs for supported models. I tried the following config:
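For reference, here is a reconstructed sketch of the decoder section I was aiming for, written out as a Python dict. This is not my exact file; the field names are copied from the decoder-only examples in the config reference above, so treat them as assumptions.

import json

# Hypothetical reconstruction of the relevant part of genai_config.json. The extra
# "encoder_hidden_states" entry under "inputs" is what triggers the
# "Unknown value: encoder_hidden_states" error quoted below.
decoder_section = {
    "filename": "decoder_model_merged.onnx",
    "inputs": {
        "input_ids": "input_ids",
        "attention_mask": "attention_mask",
        # extra input expected by the exported decoder graph, unknown to the config schema:
        "encoder_hidden_states": "encoder_hidden_states",
        "past_key_names": "past_key_values.%d.key",
        "past_value_names": "past_key_values.%d.value",
    },
    "outputs": {
        "logits": "logits",
        "present_key_names": "present.%d.key",
        "present_value_names": "present.%d.value",
    },
}
print(json.dumps({"model": {"type": "bart", "decoder": decoder_section}}, indent=2))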
When loading the config I receive the error:
RuntimeError: Error encountered while parsing 'output/genai_config.json' JSON Error: Unknown value: encoder_hidden_states at line 15 index 64
As far as I can tell from https://github.com/microsoft/onnxruntime-genai/blob/c7eba3c63a454edd6662eb007ff397d1146cc081/src/python/py/models/builder.py#L71, it is currently not possible to pass encoder hidden states to the model as inputs. Do you have any plans to extend onnxruntime-genai to encoder-decoder models? Could you please give me a hint or some advice on how to work around this?
Thank you for your assistance.