fadi212 opened this issue 2 years ago
I believe @NielsRogge can help out here
I'm not an ONNX expert, however. Pinging @michaelbenayoun for this.
@michaelbenayoun can you please help here.
I think it might have to do with the fact that your dummy inputs don't have the image field, so the inputs might be off?
It seems to come from the LayoutLMv2Tokenizer, which takes boxes (bbox) as inputs.
Here you are calling super().generate_dummy_inputs, which uses the tokenizer to create dummy inputs, but this does not provide the boxes to the tokenizer, hence the error.
There are two ways of solving this issue:
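For illustration only (this is a sketch of my own, not necessarily one of the approaches meant above): one option is to override generate_dummy_inputs so the LayoutLMv2 tokenizer also receives dummy boxes, instead of relying on the text-only default from OnnxConfig. The class name, batch size and box values here are arbitrary.

from typing import Any, Mapping, Optional

from transformers import TensorType
from transformers.onnx import OnnxConfig


class LayoutLMv2OnnxConfigSketch(OnnxConfig):  # hypothetical name, for illustration
    def generate_dummy_inputs(
        self,
        tokenizer,
        batch_size: int = 2,
        seq_length: int = 8,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        # Build words and dummy (0-1000 normalized) boxes ourselves so the
        # LayoutLMv2 tokenizer gets the `boxes` argument it requires.
        words = [["dummy"] * seq_length] * batch_size
        boxes = [[[0, 0, 10, 10]] * seq_length] * batch_size
        encoding = tokenizer(words, boxes=boxes, return_tensors=framework)
        # Note: the model also expects an `image` input, which would still have
        # to be added to this dict (e.g. via LayoutLMv2Processor).
        return dict(encoding)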
Hi @michaelbenayoun , I have made the recommended changes in the LayoutLMv2 config file.
# coding=utf-8
# Copyright Microsoft Research and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" LayoutLMv2 model configuration """
from ...configuration_utils import PretrainedConfig
from ...file_utils import is_detectron2_available
from ...utils import logging
from ...onnx import OnnxConfig, PatchingSpec
from typing import Any, List, Mapping, Optional
from transformers import TensorType
from transformers import LayoutLMv2Processor
from datasets import load_dataset
from PIL import Image
from ... import is_torch_available
from collections import OrderedDict
logger = logging.get_logger(__name__)
LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"layoutlmv2-base-uncased": "https://huggingface.co/microsoft/layoutlmv2-base-uncased/resolve/main/config.json",
"layoutlmv2-large-uncased": "https://huggingface.co/microsoft/layoutlmv2-large-uncased/resolve/main/config.json",
# See all LayoutLMv2 models at https://huggingface.co/models?filter=layoutlmv2
}
# soft dependency
if is_detectron2_available():
import detectron2
class LayoutLMv2Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a :class:`~transformers.LayoutLMv2Model`. It is used
to instantiate an LayoutLMv2 model according to the specified arguments, defining the model architecture.
Instantiating a configuration with the defaults will yield a similar configuration to that of the LayoutLMv2
`microsoft/layoutlmv2-base-uncased <https://huggingface.co/microsoft/layoutlmv2-base-uncased>`__ architecture.
Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
Args:
vocab_size (:obj:`int`, `optional`, defaults to 30522):
Vocabulary size of the LayoutLMv2 model. Defines the number of different tokens that can be represented by
the :obj:`inputs_ids` passed when calling :class:`~transformers.LayoutLMv2Model` or
:class:`~transformers.TFLayoutLMv2Model`.
hidden_size (:obj:`int`, `optional`, defaults to 768):
Dimension of the encoder layers and the pooler layer.
num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (:obj:`int`, `optional`, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (:obj:`int`, `optional`, defaults to 3072):
Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string,
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported.
hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
The dropout ratio for the attention probabilities.
max_position_embeddings (:obj:`int`, `optional`, defaults to 512):
The maximum sequence length that this model might ever be used with. Typically set this to something large
just in case (e.g., 512 or 1024 or 2048).
type_vocab_size (:obj:`int`, `optional`, defaults to 2):
The vocabulary size of the :obj:`token_type_ids` passed when calling :class:`~transformers.LayoutLMv2Model`
or :class:`~transformers.TFLayoutLMv2Model`.
initializer_range (:obj:`float`, `optional`, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
The epsilon used by the layer normalization layers.
max_2d_position_embeddings (:obj:`int`, `optional`, defaults to 1024):
The maximum value that the 2D position embedding might ever be used with. Typically set this to something
large just in case (e.g., 1024).
max_rel_pos (:obj:`int`, `optional`, defaults to 128):
The maximum number of relative positions to be used in the self-attention mechanism.
rel_pos_bins (:obj:`int`, `optional`, defaults to 32):
The number of relative position bins to be used in the self-attention mechanism.
fast_qkv (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to use a single matrix for the queries, keys, values in the self-attention layers.
max_rel_2d_pos (:obj:`int`, `optional`, defaults to 256):
The maximum number of relative 2D positions in the self-attention mechanism.
rel_2d_pos_bins (:obj:`int`, `optional`, defaults to 64):
The number of 2D relative position bins in the self-attention mechanism.
image_feature_pool_shape (:obj:`List[int]`, `optional`, defaults to [7, 7, 256]):
The shape of the average-pooled feature map.
coordinate_size (:obj:`int`, `optional`, defaults to 128):
Dimension of the coordinate embeddings.
shape_size (:obj:`int`, `optional`, defaults to 128):
Dimension of the width and height embeddings.
has_relative_attention_bias (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to use a relative attention bias in the self-attention mechanism.
has_spatial_attention_bias (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to use a spatial attention bias in the self-attention mechanism.
has_visual_segment_embedding (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to add visual segment embeddings.
detectron2_config_args (:obj:`dict`, `optional`):
Dictionary containing the configuration arguments of the Detectron2 visual backbone. Refer to `this file
<https://github.com/microsoft/unilm/blob/master/layoutlmft/layoutlmft/models/layoutlmv2/detectron2_config.py>`__
for details regarding default values.
Example::
>>> from transformers import LayoutLMv2Model, LayoutLMv2Config
>>> # Initializing a LayoutLMv2 microsoft/layoutlmv2-base-uncased style configuration
>>> configuration = LayoutLMv2Config()
>>> # Initializing a model from the microsoft/layoutlmv2-base-uncased style configuration
>>> model = LayoutLMv2Model(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
"""
model_type = "layoutlmv2"
def __init__(
self,
vocab_size=30522,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
layer_norm_eps=1e-12,
pad_token_id=0,
max_2d_position_embeddings=1024,
max_rel_pos=128,
rel_pos_bins=32,
fast_qkv=True,
max_rel_2d_pos=256,
rel_2d_pos_bins=64,
convert_sync_batchnorm=True,
image_feature_pool_shape=[7, 7, 256],
coordinate_size=128,
shape_size=128,
has_relative_attention_bias=True,
has_spatial_attention_bias=True,
has_visual_segment_embedding=False,
detectron2_config_args=None,
**kwargs
):
super().__init__(
vocab_size=vocab_size,
hidden_size=hidden_size,
num_hidden_layers=num_hidden_layers,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
hidden_act=hidden_act,
hidden_dropout_prob=hidden_dropout_prob,
attention_probs_dropout_prob=attention_probs_dropout_prob,
max_position_embeddings=max_position_embeddings,
type_vocab_size=type_vocab_size,
initializer_range=initializer_range,
layer_norm_eps=layer_norm_eps,
pad_token_id=pad_token_id,
**kwargs,
)
self.max_2d_position_embeddings = max_2d_position_embeddings
self.max_rel_pos = max_rel_pos
self.rel_pos_bins = rel_pos_bins
self.fast_qkv = fast_qkv
self.max_rel_2d_pos = max_rel_2d_pos
self.rel_2d_pos_bins = rel_2d_pos_bins
self.convert_sync_batchnorm = convert_sync_batchnorm
self.image_feature_pool_shape = image_feature_pool_shape
self.coordinate_size = coordinate_size
self.shape_size = shape_size
self.has_relative_attention_bias = has_relative_attention_bias
self.has_spatial_attention_bias = has_spatial_attention_bias
self.has_visual_segment_embedding = has_visual_segment_embedding
self.detectron2_config_args = (
detectron2_config_args if detectron2_config_args is not None else self.get_default_detectron2_config()
)
@classmethod
def get_default_detectron2_config(self):
return {
"MODEL.MASK_ON": True,
"MODEL.PIXEL_STD": [57.375, 57.120, 58.395],
"MODEL.BACKBONE.NAME": "build_resnet_fpn_backbone",
"MODEL.FPN.IN_FEATURES": ["res2", "res3", "res4", "res5"],
"MODEL.ANCHOR_GENERATOR.SIZES": [[32], [64], [128], [256], [512]],
"MODEL.RPN.IN_FEATURES": ["p2", "p3", "p4", "p5", "p6"],
"MODEL.RPN.PRE_NMS_TOPK_TRAIN": 2000,
"MODEL.RPN.PRE_NMS_TOPK_TEST": 1000,
"MODEL.RPN.POST_NMS_TOPK_TRAIN": 1000,
"MODEL.POST_NMS_TOPK_TEST": 1000,
"MODEL.ROI_HEADS.NAME": "StandardROIHeads",
"MODEL.ROI_HEADS.NUM_CLASSES": 5,
"MODEL.ROI_HEADS.IN_FEATURES": ["p2", "p3", "p4", "p5"],
"MODEL.ROI_BOX_HEAD.NAME": "FastRCNNConvFCHead",
"MODEL.ROI_BOX_HEAD.NUM_FC": 2,
"MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION": 14,
"MODEL.ROI_MASK_HEAD.NAME": "MaskRCNNConvUpsampleHead",
"MODEL.ROI_MASK_HEAD.NUM_CONV": 4,
"MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION": 7,
"MODEL.RESNETS.DEPTH": 101,
"MODEL.RESNETS.SIZES": [[32], [64], [128], [256], [512]],
"MODEL.RESNETS.ASPECT_RATIOS": [[0.5, 1.0, 2.0]],
"MODEL.RESNETS.OUT_FEATURES": ["res2", "res3", "res4", "res5"],
"MODEL.RESNETS.NUM_GROUPS": 32,
"MODEL.RESNETS.WIDTH_PER_GROUP": 8,
"MODEL.RESNETS.STRIDE_IN_1X1": False,
}
def get_detectron2_config(self):
detectron2_config = detectron2.config.get_cfg()
for k, v in self.detectron2_config_args.items():
attributes = k.split(".")
to_set = detectron2_config
for attribute in attributes[:-1]:
to_set = getattr(to_set, attribute)
setattr(to_set, attributes[-1], v)
return detectron2_config
class LayoutLMv2OnnxConfig(OnnxConfig):
def __init__(
self,
config: PretrainedConfig,
task: str = "default",
patching_specs: List[PatchingSpec] = None,
):
super().__init__(config, task=task, patching_specs=patching_specs)
self.max_2d_positions = config.max_2d_position_embeddings - 1
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict(
[
("input_ids", {0: "batch", 1: "sequence"}),
("bbox", {0: "batch", 1: "sequence"}),
("image", {0:"batch"}),
("attention_mask", {0: "batch", 1: "sequence"}),
("token_type_ids", {0: "batch", 1: "sequence"}),
]
)
def generate_dummy_inputs(
self,
processor: LayoutLMv2Processor,
batch_size: int = -1,
seq_length: int = -1,
is_pair: bool = False,
framework: Optional[TensorType] = None,
) -> Mapping[str, Any]:
"""
Generate inputs to provide to the ONNX exporter for the specific framework
Args:
tokenizer: The tokenizer associated with this model configuration
batch_size: The batch size (int) to export the model for (-1 means dynamic axis)
seq_length: The sequence length (int) to export the model for (-1 means dynamic axis)
is_pair: Indicate if the input is a pair (sentence 1, sentence 2)
framework: The framework (optional) the tokenizer will generate tensor for
Returns:
Mapping[str, Tensor] holding the kwargs to provide to the model's forward function
"""
datasets = load_dataset("nielsr/funsd")
labels = datasets['train'].features['ner_tags'].feature.names
example = datasets["test"][0]
# print(example.keys())
image = Image.open(example['image_path'])
image = image.convert("RGB")
if not framework == TensorType.PYTORCH:
raise NotImplementedError("Exporting LayoutLM to ONNX is currently only supported for PyTorch.")
if not is_torch_available():
raise ValueError("Cannot generate dummy inputs without PyTorch installed.")
import torch
input_dict = processor(image, example['words'], boxes=example['bboxes'], word_labels=example['ner_tags'],
return_tensors=framework)
axis = 0
for key_i in input_dict.data.keys():
input_dict.data[key_i] = torch.cat((input_dict.data[key_i], input_dict.data[key_i]), axis)
return input_dict.data
Now when I am trying to run the below code,
processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased", revision="no_ocr")
model = LayoutLMv2ForTokenClassification.from_pretrained("microsoft/layoutlmv2-base-uncased", torchscript=True)
onnx_config = LayoutLMv2OnnxConfig(model.config)
export(tokenizer=processor, model=model, config=onnx_config, opset=13, output=Path('onnx/layout.onnx'))
I am facing the below error.
Traceback (most recent call last):
File "/home/muhammad/PycharmProjects/js_labs
/Layoutv2/convert_lmv2.py", line 11, in <module>
export(tokenizer=processor, model=model, config=onnx_config, opset=9, output=Path('onnx/layout.onnx'))
File "/home/muhammad/PycharmProjects/js_labs
/anaconda3/envs/onnx-env/lib/python3.7/site-packages/transformers/onnx/convert.py", line 125, in export
opset_version=opset,
File "/home/muhammad/PycharmProjects/js_labs
/anaconda3/envs/onnx-env/lib/python3.7/site-packages/torch/onnx/_init_.py", line 320, in export
custom_opsets, enable_onnx_checker, use_external_data_format)
File "/home/muhammad/PycharmProjects/js_labs
/anaconda3/envs/onnx-env/lib/python3.7/site-packages/torch/onnx/utils.py", line 111, in export
custom_opsets=custom_opsets, use_external_data_format=use_external_data_format)
File "/home/muhammad/PycharmProjects/js_labs
/anaconda3/envs/onnx-env/lib/python3.7/site-packages/torch/onnx/utils.py", line 740, in _export
val_add_node_names, val_use_external_data_format, model_file_location)
RuntimeError: ONNX export failed: Couldn't export operator aten::adaptive_avg_pool2d
One more thing: for the dummy inputs I have provided the image as ("image", {0: "batch"}). Is this mapping right, or do we have to provide the image in a different manner?
+1
+1
Hi,
Would be great if you could Google the errors before pinging us (because we at Hugging Face are pretty busy). E.g. in this case, you can find the answer in the first result on Google: https://github.com/onnx/tutorials/issues/63#issuecomment-559007498
=> The reason is that LayoutLMv2 uses a visual backbone, which includes layers like AdaptiveAvgPool2d that aren't supported natively by ONNX.
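For reference, a minimal sketch of my own along the lines of what is done later in this thread: swap the adaptive pooling module of the visual backbone for a fixed-size AvgPool2d before export. The 8x8 kernel/stride mirrors the replacement used further down and assumes the default image_feature_pool_shape of [7, 7, 256]; it is not guaranteed to be the right size for every checkpoint.

import torch
from torch import nn
from transformers import LayoutLMv2ForTokenClassification

class FixedAvgPool(nn.Module):
    # Stand-in for the AdaptiveAvgPool2d inside the visual backbone; a fixed
    # kernel/stride is exportable to ONNX, unlike the adaptive variant.
    def __init__(self):
        super().__init__()
        self.pool = nn.AvgPool2d(kernel_size=8, stride=8)

    def forward(self, x):
        return self.pool(x)

model = LayoutLMv2ForTokenClassification.from_pretrained("microsoft/layoutlmv2-base-uncased")
model.layoutlmv2.visual.pool = FixedAvgPool()  # replace before calling torch.onnx.export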
Hi @NielsRogge , I followed your guide and made the required changes. I updated the pooling layer and now I am faced with the below error. I had googled the previous issue as well but was not kind of sure where to make pooling layer changes. This time I had searched for the subjected issue but to no avail as I am kind of new to to onnx.
Would you please point out where I am making error in the code below.
from transformers.onnx import OnnxConfig, PatchingSpec
from transformers.configuration_utils import PretrainedConfig
from typing import Any, List, Mapping, Optional, Tuple, Union, Iterable
from collections import OrderedDict
from transformers import LayoutLMv2Processor
from datasets import load_dataset
from PIL import Image
import torch
from transformers import PreTrainedModel, TensorType
from torch.onnx import export
from transformers.file_utils import torch_version, is_torch_onnx_dict_inputs_support_available
from pathlib import Path
from transformers.utils import logging
from inspect import signature
from itertools import chain
from transformers import LayoutLMv2ForTokenClassification
from torch import nn
from torch.onnx import OperatorExportTypes
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class LayoutLMv2OnnxConfig(OnnxConfig):
def __init__(
self,
config: PretrainedConfig,
task: str = "default",
patching_specs: List[PatchingSpec] = None,
):
super().__init__(config, task=task, patching_specs=patching_specs)
self.max_2d_positions = config.max_2d_position_embeddings - 1
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict(
[
("input_ids", {0: "batch", 1: "sequence"}),
("bbox", {0: "batch", 1: "sequence"}),
("image", {0: "batch"}),
("attention_mask", {0: "batch", 1: "sequence"}),
("token_type_ids", {0: "batch", 1: "sequence"}),
]
)
def generate_dummy_inputs(
self,
processor: LayoutLMv2Processor,
batch_size: int = -1,
seq_length: int = -1,
is_pair: bool = False,
framework: Optional[TensorType] = None,
) -> Mapping[str, Any]:
datasets = load_dataset("nielsr/funsd")
example = datasets["test"][0]
image = Image.open(example['image_path'])
image = image.convert("RGB")
if not framework == TensorType.PYTORCH:
raise NotImplementedError("Exporting LayoutLM to ONNX is currently only supported for PyTorch.")
input_dict = processor(image, example['words'], boxes=example['bboxes'], word_labels=example['ner_tags'],
return_tensors=framework)
axis = 0
for key_i in input_dict.data.keys():
input_dict.data[key_i] = torch.cat((input_dict.data[key_i], input_dict.data[key_i]), axis)
return input_dict.data
class pool_layer(nn.Module):
def __init__(self):
super(pool_layer, self).__init__()
self.fc = nn.AvgPool2d(kernel_size=[8, 8], stride=[8, 8])
def forward(self, x):
output = self.fc(x)
return output
def ensure_model_and_config_inputs_match(
model: PreTrainedModel, model_inputs: Iterable[str]
) -> Tuple[bool, List[str]]:
"""
:param model:
:param model_inputs:
:return:
"""
forward_parameters = signature(model.forward).parameters
model_inputs_set = set(model_inputs)
# We are fine if config_inputs has more keys than model_inputs
forward_inputs_set = set(forward_parameters.keys())
is_ok = model_inputs_set.issubset(forward_inputs_set)
# Make sure the input order match (VERY IMPORTANT !!!!)
matching_inputs = forward_inputs_set.intersection(model_inputs_set)
ordered_inputs = [parameter for parameter in forward_parameters.keys() if parameter in matching_inputs]
return is_ok, ordered_inputs
def export_model(
processor: LayoutLMv2Processor, model: PreTrainedModel, config: LayoutLMv2OnnxConfig, opset: int, output: Path
) -> Tuple[List[str], List[str]]:
"""
Export a PyTorch backed pipeline to ONNX Intermediate Representation (IR)
Args:
processor:
model:
config:
opset:
output:
Returns:
"""
if not is_torch_onnx_dict_inputs_support_available():
raise AssertionError(f"Unsupported PyTorch version, minimum required is 1.8.0, got: {torch_version}")
logger.info(f"Using framework PyTorch: {torch.__version__}")
with torch.no_grad():
model.config.return_dict = True
model.eval()
# Check if we need to override certain configuration item
if config.values_override is not None:
logger.info(f"Overriding {len(config.values_override)} configuration item(s)")
for override_config_key, override_config_value in config.values_override.items():
logger.info(f"\t- {override_config_key} -> {override_config_value}")
setattr(model.config, override_config_key, override_config_value)
model_inputs = config.generate_dummy_inputs(processor, framework=TensorType.PYTORCH)
inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
print(matched_inputs)
onnx_outputs = list(config.outputs.keys())
if not inputs_match:
raise ValueError("Model and config inputs doesn't match")
config.patch_ops()
model_inputs.pop("labels")
export(
model,
(model_inputs,),
f=output.as_posix(),
input_names=list(config.inputs.keys()),
output_names=onnx_outputs,
dynamic_axes={name: axes for name, axes in chain(config.inputs.items(), config.outputs.items())},
do_constant_folding=True,
use_external_data_format=config.use_external_data_format(model.num_parameters()),
enable_onnx_checker=True,
opset_version=opset,
# operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK
)
config.restore_ops()
return matched_inputs, onnx_outputs
if __name__ == '__main__':
processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased", revision="no_ocr")
model = LayoutLMv2ForTokenClassification.from_pretrained("microsoft/layoutlmv2-base-uncased", torchscript = True)
model.layoutlmv2.visual.pool = torch.nn.Sequential(pool_layer())
onnx_config = LayoutLMv2OnnxConfig(model.config)
export_model(processor=processor, model=model, config=onnx_config, opset=13, output=Path('onnx/layout.onnx'))
Running the above code raises the below error:
RuntimeError Traceback (most recent call last)
<ipython-input-6-134631b21e61> in <module>()
168 model.layoutlmv2.visual.pool = torch.nn.Sequential(pool_layer())
169 onnx_config = LayoutLMv2OnnxConfig(model.config)
--> 170 export_model(processor=processor, model=model, config=onnx_config, opset=13, output=Path('onnx/layout.onnx'))
4 frames
<ipython-input-6-134631b21e61> in export_model(processor, model, config, opset, output)
154 use_external_data_format=config.use_external_data_format(model.num_parameters()),
155 enable_onnx_checker=True,
--> 156 opset_version=opset,
157 # operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK
158 )
/usr/local/lib/python3.7/dist-packages/torch/onnx/__init__.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, custom_opsets, enable_onnx_checker, use_external_data_format)
274 do_constant_folding, example_outputs,
275 strip_doc_string, dynamic_axes, keep_initializers_as_inputs,
--> 276 custom_opsets, enable_onnx_checker, use_external_data_format)
277
278
/usr/local/lib/python3.7/dist-packages/torch/onnx/utils.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, custom_opsets, enable_onnx_checker, use_external_data_format)
92 dynamic_axes=dynamic_axes, keep_initializers_as_inputs=keep_initializers_as_inputs,
93 custom_opsets=custom_opsets, enable_onnx_checker=enable_onnx_checker,
---> 94 use_external_data_format=use_external_data_format)
95
96
/usr/local/lib/python3.7/dist-packages/torch/onnx/utils.py in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, example_outputs, opset_version, _retain_param_name, do_constant_folding, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size, custom_opsets, add_node_names, enable_onnx_checker, use_external_data_format, onnx_shape_inference, use_new_jit_passes)
696 training=training,
697 use_new_jit_passes=use_new_jit_passes,
--> 698 dynamic_axes=dynamic_axes)
699
700 # TODO: Don't allocate a in-memory string for the protobuf
/usr/local/lib/python3.7/dist-packages/torch/onnx/utils.py in _model_to_graph(model, args, verbose, input_names, output_names, operator_export_type, example_outputs, _retain_param_name, do_constant_folding, _disable_torch_constant_prop, fixed_batch_size, training, use_new_jit_passes, dynamic_axes)
498 if do_constant_folding and _export_onnx_opset_version in torch.onnx.constant_folding_opset_versions:
499 params_dict = torch._C._jit_pass_onnx_constant_fold(graph, params_dict,
--> 500 _export_onnx_opset_version)
501 torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
502
RuntimeError: Tensors must have same number of dimensions: got 2 and 1
@fadi212 Have you tried using another opset version, such as 11?
Speaking from complete ignorance here, but maybe worth a try :)
My model is converted to ONNX, but at the time of loading it into onnxruntime I am getting the below error: Type Error: Type parameter (T) bound to different types (tensor(double) and tensor(float) in node ()
@michaelbenayoun @wilbry @fadi212
Hi,
Can you check out the solution provided here?
Also, if you managed to convert the model to ONNX, feel free to open a PR which we can review, it will benefit the community a lot.
Thanks!
Hi @lalitr994, I did not face this error. I was able to convert my model to ONNX, and it loads and predicts correctly. I am working on creating a PR but facing some issues, as the conversion process for this model is a bit different from other models.
@fadi212 can you share your repo? How have you converted and loaded the ONNX model into ONNX Runtime? I am stuck at loading the model into the runtime.
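(Not from this thread, but for reference: loading an exported model with onnxruntime usually looks like the minimal sketch below. The input names, shapes and dtypes here are assumptions based on the dummy inputs used for export further down; check session.get_inputs() for the real ones, and cast any float64 arrays to float32 to avoid the tensor(double)/tensor(float) mismatch mentioned above.)

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("onnx/layout.onnx")
print([(i.name, i.shape, i.type) for i in session.get_inputs()])  # inspect expected feeds

# Dummy feeds only; real values would come from LayoutLMv2Processor.
feeds = {
    "input_ids": np.zeros((2, 8), dtype=np.int64),
    "bbox": np.zeros((2, 8, 4), dtype=np.int64),
    "image": np.zeros((2, 3, 224, 224), dtype=np.float32),
    "attention_mask": np.ones((2, 8), dtype=np.int64),
    "token_type_ids": np.zeros((2, 8), dtype=np.int64),
}
# Guard against accidental float64 inputs, which trigger the double/float type error.
feeds = {k: v.astype(np.float32) if v.dtype == np.float64 else v for k, v in feeds.items()}

logits = session.run(None, feeds)[0]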
Hi @lalitr994, you can use this script to convert the model for now.

from transformers.onnx import OnnxConfig, PatchingSpec
from transformers.configuration_utils import PretrainedConfig
from typing import Any, List, Mapping, Optional, Tuple, Iterable
from collections import OrderedDict
from transformers import LayoutLMv2Processor
from datasets import load_dataset
from PIL import Image
import torch
from transformers import PreTrainedModel, TensorType
from torch.onnx import export
from transformers.file_utils import torch_version, is_torch_onnx_dict_inputs_support_available
from pathlib import Path
from transformers.utils import logging
from inspect import signature
from itertools import chain
from transformers import LayoutLMv2ForTokenClassification
from torch import nn
from torch.onnx import OperatorExportTypes
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
class LayoutLMv2OnnxConfig(OnnxConfig):
    def __init__(
        self,
        config: PretrainedConfig,
        task: str = "default",
        patching_specs: List[PatchingSpec] = None,
    ):
        super().__init__(config, task=task, patching_specs=patching_specs)
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict(
[
("input_ids", {0: "batch", 1: "sequence"}),
("bbox", {0: "batch", 1: "sequence"}),
("image", {0: "batch"}),
("attention_mask", {0: "batch", 1: "sequence"}),
("token_type_ids", {0: "batch", 1: "sequence"}),
]
)
@property
def outputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict(
[
# ("loss", {}),
("logits", {0: "batch", 1: "sequence"}),
# ("hidden_states", {}),
# ("attentions", {})
]
)
def generate_dummy_inputs(
self,
processor: LayoutLMv2Processor,
batch_size: int = -1,
seq_length: int = -1,
is_pair: bool = False,
framework: Optional[TensorType] = None,
) -> Mapping[str, Any]:
# datasets = load_dataset("nielsr/funsd")
# example = datasets["test"][0]
# image = Image.open(example['image_path'])
# image = image.convert("RGB")
if not framework == TensorType.PYTORCH:
raise NotImplementedError("Exporting LayoutLM to ONNX is currently only supported for PyTorch.")
# input_dict = processor(image, example['words'], boxes=example['bboxes'], word_labels=example['ner_tags'],
# return_tensors=framework)
# axis = 0
# for key_i in input_dict.data.keys():
# input_dict.data[key_i] = torch.cat((input_dict.data[key_i], input_dict.data[key_i]), axis)
return dict(
input_ids=torch.zeros((2, 8), dtype=torch.int64),
token_type_ids=torch.zeros((2, 8), dtype=torch.int64),
attention_mask=torch.zeros((2, 8), dtype=torch.float),
bbox=torch.zeros((2, 8, 4), dtype=torch.int64),
labels=torch.zeros((2, 8), dtype=torch.int64),
image=torch.zeros((2, 3, 224, 224), dtype=torch.int64),
)
class pool_layer(nn.Module):
    def __init__(self):
        super(pool_layer, self).__init__()
        self.pool = nn.AvgPool2d(kernel_size=[8, 8], stride=[8, 8])

    def forward(self, x):
        output = self.pool(x)
        return output
def ensure_model_and_config_inputs_match(
    model: PreTrainedModel, model_inputs: Iterable[str]
) -> Tuple[bool, List[str]]:
    """
:param model:
:param model_inputs:
:return:
"""
forward_parameters = signature(model.forward).parameters
model_inputs_set = set(model_inputs)
# We are fine if config_inputs has more keys than model_inputs
forward_inputs_set = set(forward_parameters.keys())
is_ok = model_inputs_set.issubset(forward_inputs_set)
# Make sure the input order match (VERY IMPORTANT !!!!)
matching_inputs = forward_inputs_set.intersection(model_inputs_set)
ordered_inputs = [parameter for parameter in forward_parameters.keys() if parameter in matching_inputs]
return is_ok, ordered_inputs
def export_model(
    processor: LayoutLMv2Processor, model: PreTrainedModel, config: LayoutLMv2OnnxConfig, opset: int, output: Path
) -> Tuple[List[str], List[str]]:
    """
    Export a PyTorch backed pipeline to ONNX Intermediate Representation (IR)
Args:
processor:
model:
config:
opset:
output:
Returns:
"""
if not is_torch_onnx_dict_inputs_support_available():
raise AssertionError(f"Unsupported PyTorch version, minimum required is 1.8.0, got: {torch_version}")
# logger.info(f"Using framework PyTorch: {torch.__version__}")
with torch.no_grad():
model.config.return_dict = True
model.eval()
# Check if we need to override certain configuration item
if config.values_override is not None:
logger.info(f"Overriding {len(config.values_override)} configuration item(s)")
for override_config_key, override_config_value in config.values_override.items():
logger.info(f"\t- {override_config_key} -> {override_config_value}")
setattr(model.config, override_config_key, override_config_value)
model_inputs = config.generate_dummy_inputs(processor, framework=TensorType.PYTORCH)
inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
onnx_outputs = list(config.outputs.keys())
if not inputs_match:
raise ValueError("Model and config inputs doesn't match")
model_inputs.pop("labels")
config.patch_ops()
export(
model,
(model_inputs,),
f=output.as_posix(),
input_names=list(config.inputs.keys()),
output_names=onnx_outputs,
dynamic_axes={name: axes for name, axes in chain(config.inputs.items(), config.outputs.items())},
do_constant_folding=True,
use_external_data_format=config.use_external_data_format(model.num_parameters()),
enable_onnx_checker=True,
opset_version=opset,
verbose=True
# operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK
)
config.restore_ops()
return matched_inputs, onnx_outputs
if __name__ == '__main__':
    processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased", revision="no_ocr")
    model = LayoutLMv2ForTokenClassification.from_pretrained("microsoft/layoutlmv2-base-uncased", num_labels=7)
    model.layoutlmv2.visual.pool = pool_layer()
    onnx_config = LayoutLMv2OnnxConfig(model.config)
    export_model(processor=processor, model=model, config=onnx_config, opset=13, output=Path('onnx/layout2.onnx'))
Also, you will have to change these lines in modeling_layoutlmv2 in the transformers library:
visual_shape = deepcopy(list(input_shape)) #line 859
visual_shape[1] = self.config.image_feature_pool_shape[0] * self.config.image_feature_pool_shape[1]
visual_shape = torch.Size(visual_shape)
final_shape = deepcopy(list(input_shape)) #line 862
final_shape[1] += visual_shape[1]
final_shape = torch.Size(final_shape)
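Once the export goes through, a quick structural check with the onnx package (my own addition, not part of the script above) can catch an invalid graph before trying to load it into a runtime:

import onnx

onnx_model = onnx.load("onnx/layout2.onnx")
onnx.checker.check_model(onnx_model)  # raises if the exported graph is malformed
print([inp.name for inp in onnx_model.graph.input])  # confirm expected input names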
Hi @fadi212! Thanks for your script!
I'm having some trouble exporting the microsoft/layoutlmv2-base-uncased model (just testing that it works OK before exporting my own model). I have ruled out errors in your code, as it works perfectly, but the export ends up failing with a segmentation fault deep in some PyTorch C bindings.
May I ask which versions of the libraries you have installed, in particular pytorch and onnx?
Just for the record, the segfault happens consistently at line 218 of the picture, which is located inside an optimization routine called _optimize_graph in torch.onnx.utils.
Interestingly, by explicitly setting the operator_export_type to OperatorExportTypes.ONNX_ATEN on the export function, it manages to get through that line, but fails again a little further down, at line 238 of the picture, albeit without a segfault (just a regular Python exception traceback).
I think I have narrowed the problem down to the generation of invalid ONNX code (in particular, some UNKNOWN_SCALARs), most likely due to some unsupported operation similar to the AdaptiveAvgPool2d -> AvgPool2d issue.
Hi @viantirreau @lalitr994 , You can take a look at this PR and convert your model with this branch. https://github.com/huggingface/transformers/pull/14555
Thanks @fadi212, I will try my model with this branch.
During the ONNX conversion I got warnings like:
/torch/onnx/symbolic_helper.py:258: UserWarning: ONNX export failed on adaptive_avg_pool2d because input size not accessible not supported warnings.warn("ONNX export failed on " + op + " because " + msg + " not supported")
Warning: Shape inference does not support models with experimental operators: ATen
When running inference with the model I got the error below:
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from onnx_model.onnx failed: Fatal error: adaptive_avg_pool2d is not a registered function/op
@fadi212 I followed your code but am facing this issue.
Hi @riqui-puig, I have created a PR to add support for LayoutLMv2; you can use that: https://github.com/huggingface/transformers/pull/14555
The code is not merged yet, but you can install that particular branch and then convert your model using the command line.
Hi @fadi212, I have tried to convert a LayoutLMv2 question-answering model to ONNX, but it is still showing errors. Could you please guide me here? Thanks in advance.
Command: !python -m transformers.onnx --model=microsoft/layoutlmv2-base-uncased onnx/
Error log: Some weights of the model checkpoint at microsoft/layoutlmv2-base-uncased were not used when initializing LayoutLMv2Model: ['layoutlmv2.visual.backbone.bottom_up.res4.15.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.0.shortcut.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.10.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.2.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.17.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res5.2.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.22.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.14.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.8.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.17.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.14.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.19.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res2.0.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res2.2.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.7.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res3.0.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.12.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.9.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res2.1.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res5.0.shortcut.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.2.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.11.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res3.3.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.14.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res5.1.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.13.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.2.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.15.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.12.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.5.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res3.1.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res3.1.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res5.1.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res2.0.shortcut.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res5.2.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.16.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res5.0.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.8.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.21.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res3.1.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res5.0.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res5.1.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.18.conv1.norm.num_batches_tracked', 
'layoutlmv2.visual.backbone.bottom_up.res4.3.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.15.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.10.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.3.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.21.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.10.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.7.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.9.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.4.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.0.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.1.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.3.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.20.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res2.0.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.0.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.11.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res3.2.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.20.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res2.2.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.stem.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.8.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.5.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res3.0.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.6.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.16.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.22.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.18.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.19.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res3.3.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.4.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.13.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.17.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res3.3.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.1.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.6.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res3.2.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.19.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.0.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res3.2.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.12.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.5.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.18.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.11.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res2.0.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res2.1.conv3.norm.num_batches_tracked', 
'layoutlmv2.visual.backbone.bottom_up.res4.1.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.13.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res3.0.shortcut.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res5.2.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.20.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.9.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.4.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.22.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res3.0.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.6.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.16.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res2.1.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res5.0.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.7.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res2.2.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.21.conv2.norm.num_batches_tracked']
Has this issue been resolved? If not then I would like to work on it.
@avisinghal6 ONNX support is now handled by the optimum library; model configs can be found here
I am trying to export a LayoutLMv2 model to ONNX, but there is no support for that available in the transformers library. I have tried to follow the method available for LayoutLM, but that is not working. Here is the config class for LayoutLMv2.
Running the export line raises this error: