Open patricianing opened 9 months ago
I also tried passing the vocab file directly, but it made no difference: onnx_tokenizer = pnp.PreHuggingFaceBert(vocab_file='./all-mpnet-base-v2/vocab.txt')
I think this could be fixed if the tokenizer constructor were modified here to pull out the following variables:
# Proposed change: forward the Hugging Face tokenizer's own special tokens
# (separator, classifier and padding) to the BertTokenizer op explicitly.
# NOTE(review): sep/cls are read from eos_token/bos_token here; confirm these
# match hf_tok.sep_token / hf_tok.cls_token for the target model.
self.onnx_bert_tokenizer = create_op_function('BertTokenizer', bert_tokenizer,
hf_tok=hf_tok,
sep_token=hf_tok.eos_token,
cls_token=hf_tok.bos_token,
pad_token=hf_tok.pad_token)
The constructor currently falls back to default values for all of those fields here, but in the case of all-mpnet-base-v2
those defaults are wrong.
Certain fields in the tokenizer were not checked when exporting with the onnxruntime-extensions pnp module, causing a mismatch for cls_token and sep_token.
Code showing the difference:
import onnxruntime
import onnxruntime_extensions
import numpy as np

from onnxruntime_extensions import pnp
from transformers import AutoTokenizer


def map_token_output(input_ids, attention_mask, token_type_ids):
    """Add a leading batch dimension to each of the three tokenizer outputs.

    The tensors are returned in the order (input_ids, token_type_ids,
    attention_mask).
    NOTE(review): that order differs from both the parameter order and the
    ``output_names`` passed to ``pnp.export`` below -- confirm it is intended.
    """
    batched_ids = input_ids.unsqueeze(0)
    batched_type_ids = token_type_ids.unsqueeze(0)
    batched_mask = attention_mask.unsqueeze(0)
    return batched_ids, batched_type_ids, batched_mask


model_name = 'sentence-transformers/all-mpnet-base-v2'
output_name = 'all-mpnet-base-v2-aug.onnx'

# Both axes of every exported output are dynamic.
symbolic = {0: 'batch_size', 1: 'sequence_length'}

# Build the ONNX pre-processing pipeline from the Hugging Face tokenizer.
tokenizer = AutoTokenizer.from_pretrained(model_name)
bert_tokenizer = pnp.PreHuggingFaceBert(hf_tok=tokenizer)

augmented_model = pnp.SequentialProcessingModule(bert_tokenizer, map_token_output)

test_input = ["This is a test sentence"]

# Export the tokenizer pipeline to an ONNX file.
augmented_model = pnp.export(
    augmented_model,
    test_input,
    opset_version=12,
    input_names=['input'],
    output_names=['input_ids', 'attention_mask', 'token_type_ids'],
    output_path=output_name,
    dynamic_axes={
        'input_ids': symbolic,
        'attention_mask': symbolic,
        'token_type_ids': symbolic,
    },
)

# Run the exported model; the custom-ops library must be registered so the
# session can resolve the tokenizer op.
session_options = onnxruntime.SessionOptions()
session_options.register_custom_ops_library(onnxruntime_extensions.get_library_path())
session = onnxruntime.InferenceSession(output_name, session_options)
results = session.run([], {"input": test_input})

# Compare the reference Hugging Face token ids with the first ONNX output.
encoded_input = tokenizer(test_input, padding=True, truncation=True, return_tensors='pt')
np.testing.assert_allclose(
    encoded_input.get('input_ids'),
    results[0],
    rtol=1e-04,
    atol=1e-05,
)