microsoft / onnxruntime-extensions

onnxruntime-extensions: A specialized pre- and post-processing library for ONNX Runtime
MIT License

exported huggingface tokenizer generates different results #623

Open patricianing opened 9 months ago

patricianing commented 9 months ago

Certain fields of the tokenizer are not checked when exporting with the onnxruntime-extensions pnp module, causing a mismatch for cls_token and sep_token.

Code showing the difference:

    import numpy as np
    import onnxruntime
    import onnxruntime_extensions
    from onnxruntime_extensions import pnp
    from transformers import AutoTokenizer

    def map_token_output(input_ids, attention_mask, token_type_ids):
        return input_ids.unsqueeze(0), token_type_ids.unsqueeze(0), attention_mask.unsqueeze(0)

    model_name = 'sentence-transformers/all-mpnet-base-v2'
    output_name = 'all-mpnet-base-v2-aug.onnx'

    symbolic = {0: 'batch_size', 1: 'sequence_length'}

    # Wrap the Hugging Face tokenizer in the pnp pre-processing step.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    bert_tokenizer = pnp.PreHuggingFaceBert(hf_tok=tokenizer)

    augmented_model = pnp.SequentialProcessingModule(bert_tokenizer, map_token_output)

    test_input = ["This is a test sentence"]

    # Export the tokenizer-only model to ONNX.
    augmented_model = pnp.export(augmented_model, test_input, opset_version=12,
                                 input_names=['input'],
                                 output_names=['input_ids', 'attention_mask', 'token_type_ids'],
                                 output_path=output_name,
                                 dynamic_axes={'input_ids': symbolic,
                                               'attention_mask': symbolic,
                                               'token_type_ids': symbolic})

    # Run the exported tokenizer with the custom-ops library registered.
    session_options = onnxruntime.SessionOptions()
    session_options.register_custom_ops_library(onnxruntime_extensions.get_library_path())
    session = onnxruntime.InferenceSession(output_name, session_options)
    results = session.run([], {"input": test_input})

    # Compare against the original Hugging Face tokenizer output -- this assertion fails.
    encoded_input = tokenizer(test_input, padding=True, truncation=True, return_tensors='pt')
    np.testing.assert_allclose(encoded_input.get('input_ids'), results[0], rtol=1e-04, atol=1e-05)
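The mismatch can also be seen on the Hugging Face side alone. A minimal check (assuming only the transformers package) that prints this model's actual special tokens, which differ from the BERT defaults the export path falls back to (see the discussion below):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')
    # For this model, cls_token/sep_token are '<s>' and '</s>',
    # not the BERT defaults '[CLS]' and '[SEP]'.
    print(tok.cls_token, tok.sep_token, tok.pad_token)
    print(tok.cls_token_id, tok.sep_token_id, tok.pad_token_id)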

patricianing commented 9 months ago

Also tried passing the vocab file; no difference:

    onnx_tokenizer = pnp.PreHuggingFaceBert(vocab_file='./all-mpnet-base-v2/vocab.txt')

Craigacp commented 9 months ago

I think this could be fixed if the tokenizer constructor were modified here to pull out the following variables:

            self.onnx_bert_tokenizer = create_op_function('BertTokenizer', bert_tokenizer,
                                                          hf_tok=hf_tok,
                                                          sep_token=hf_tok.eos_token,
                                                          cls_token=hf_tok.bos_token,
                                                          pad_token=hf_tok.pad_token)

At the moment it defaults all of those fields here, but in the case of all-mpnet-base-v2 those defaults are wrong.
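For illustration, this is roughly the lookup being suggested, written as a standalone sketch with a hypothetical helper name (special_tokens_from_hf is not part of the library) and using only standard Hugging Face tokenizer attributes:

    def special_tokens_from_hf(hf_tok):
        # Hypothetical helper: read the special tokens from the Hugging Face
        # tokenizer instead of assuming the BERT defaults, falling back to
        # bos/eos for models that only define those.
        return {
            'cls_token': hf_tok.cls_token or hf_tok.bos_token,
            'sep_token': hf_tok.sep_token or hf_tok.eos_token,
            'pad_token': hf_tok.pad_token,
        }

For all-mpnet-base-v2 this would yield '<s>', '</s>' and '<pad>' rather than the BERT defaults.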