Closed nirmal2k closed 2 years ago
I tried running the same code as in the README, using `facebook/bart-large-mnli`, and I am getting `KeyError: 'decoder_input_ids'`.
    from optimum.onnxruntime.configuration import OptimizationConfig

    optimization_config = OptimizationConfig(optimization_level=99, optimize_for_gpu=True)

    from optimum.onnxruntime import ORTOptimizer

    model_name = "facebook/bart-large-mnli"
    optimizer = ORTOptimizer.from_pretrained(
        model_name,
        feature="sequence-classification",
    )
    optimizer.export(
        onnx_model_path="op_fbbart.onnx",
        onnx_optimized_model_output_path="op_fbbart-optimized.onnx",
        optimization_config=optimization_config,
    )

    from optimum.onnxruntime import ORTModel
    from functools import partial
    from datasets import Dataset

    ort_model = ORTModel("op_fbbart-optimized.onnx", optimizer._onnx_config)
    ds = Dataset.from_dict({"sentence": ["I love burritos!"]})

    def preprocess_fn(ex, tokenizer):
        return tokenizer(ex["sentence"])

    tokenized_ds = ds.map(partial(preprocess_fn, tokenizer=optimizer.tokenizer))
    ort_outputs = ort_model.evaluation_loop(tokenized_ds)
    ort_outputs.predictions  # KeyError
Traceback:
    KeyError                                  Traceback (most recent call last)
    Input In [13], in <cell line: 2>()
          1 tokenized_ds = ds.map(partial(preprocess_fn, tokenizer=optimizer.tokenizer))
    ----> 2 ort_outputs = ort_model.evaluation_loop(tokenized_ds)
          3 # Extract logits!
          4 ort_outputs.predictions

    File ~/miniconda3/envs/optimum/lib/python3.9/site-packages/optimum/onnxruntime/model.py:93, in ORTModel.evaluation_loop(self, dataset)
         91 else:
         92     labels = None
    ---> 93 onnx_inputs = {key: np.array([inputs[key]]) for key in self.onnx_config.inputs}
         94 preds = session.run(self.onnx_named_outputs, onnx_inputs)
         95 if len(preds) == 1:

    File ~/miniconda3/envs/optimum/lib/python3.9/site-packages/optimum/onnxruntime/model.py:93, in <dictcomp>(.0)
         91 else:
         92     labels = None
    ---> 93 onnx_inputs = {key: np.array([inputs[key]]) for key in self.onnx_config.inputs}
         94 preds = session.run(self.onnx_named_outputs, onnx_inputs)
         95 if len(preds) == 1:

    KeyError: 'decoder_input_ids'
Hi @nirmal2k, the issue comes from the fact that the key `'decoder_input_ids'`, which is present in `self.onnx_config.inputs`, does not exist in the given dataset. This was fixed in PR #126. Thanks for reporting it.
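To illustrate the mismatch described above, here is a minimal sketch in plain Python that needs neither optimum nor an ONNX model. The names `expected_inputs`, `row`, `build_inputs_strict`, and `build_inputs_lenient` are hypothetical stand-ins, and the input list is an assumed example rather than the actual BART ONNX config. The lenient variant is only one possible way to avoid the error and is not necessarily what the fix in the PR does.

```python
# Hypothetical stand-in for self.onnx_config.inputs (assumed names, for illustration only).
expected_inputs = [
    "input_ids",
    "attention_mask",
    "decoder_input_ids",
    "decoder_attention_mask",
]

# Hypothetical tokenized row: the tokenizer produced encoder inputs only.
row = {"input_ids": [0, 100, 657, 2], "attention_mask": [1, 1, 1, 1]}


def build_inputs_strict(row, expected_inputs):
    """Mirror the dict comprehension from the traceback: index every expected key."""
    return {key: [row[key]] for key in expected_inputs}


def build_inputs_lenient(row, expected_inputs):
    """Keep only the expected keys that the dataset row actually contains."""
    return {key: [row[key]] for key in expected_inputs if key in row}


try:
    build_inputs_strict(row, expected_inputs)
except KeyError as err:
    # The first expected key missing from the row triggers the error.
    print(f"KeyError: {err}")

print(sorted(build_inputs_lenient(row, expected_inputs)))
```

Note that skipping missing keys only helps if the exported graph does not actually require them; if the ONNX model declares `decoder_input_ids` as an input, the dataset would need to supply it instead.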