huggingface / optimum

🚀 Accelerate training and inference of 🤗 Transformers and 🤗 Diffusers with easy to use hardware optimization tools
https://huggingface.co/docs/optimum/main/
Apache License 2.0

How to optimize bart models? #123

Closed. nirmal2k closed this issue 2 years ago

nirmal2k commented 2 years ago

I tried running the code from the README with facebook/bart-large-mnli, and I am getting KeyError: 'decoder_input_ids'.

from optimum.onnxruntime.configuration import OptimizationConfig
optimization_config = OptimizationConfig(optimization_level=99, optimize_for_gpu=True)

from optimum.onnxruntime import ORTOptimizer
model_name = "facebook/bart-large-mnli"
optimizer = ORTOptimizer.from_pretrained(
    model_name,
    feature="sequence-classification",
)

optimizer.export(
    onnx_model_path="op_fbbart.onnx",
    onnx_optimized_model_output_path="op_fbbart-optimized.onnx",
    optimization_config=optimization_config,
)

from optimum.onnxruntime import ORTModel
from functools import partial
from datasets import Dataset

ort_model = ORTModel("op_fbbart-optimized.onnx", optimizer._onnx_config)
ds = Dataset.from_dict({"sentence": ["I love burritos!"]})


def preprocess_fn(ex, tokenizer):
    return tokenizer(ex["sentence"])


tokenized_ds = ds.map(partial(preprocess_fn, tokenizer=optimizer.tokenizer))
ort_outputs = ort_model.evaluation_loop(tokenized_ds)  # raises KeyError: 'decoder_input_ids'
ort_outputs.predictions

Traceback:

KeyError                                  Traceback (most recent call last)
Input In [13], in <cell line: 2>()
      1 tokenized_ds = ds.map(partial(preprocess_fn, tokenizer=optimizer.tokenizer))
----> 2 ort_outputs = ort_model.evaluation_loop(tokenized_ds)
      3 # Extract logits!
      4 ort_outputs.predictions

File ~/miniconda3/envs/optimum/lib/python3.9/site-packages/optimum/onnxruntime/model.py:93, in ORTModel.evaluation_loop(self, dataset)
     91 else:
     92     labels = None
---> 93 onnx_inputs = {key: np.array([inputs[key]]) for key in self.onnx_config.inputs}
     94 preds = session.run(self.onnx_named_outputs, onnx_inputs)
     95 if len(preds) == 1:

File ~/miniconda3/envs/optimum/lib/python3.9/site-packages/optimum/onnxruntime/model.py:93, in <dictcomp>(.0)
     91 else:
     92     labels = None
---> 93 onnx_inputs = {key: np.array([inputs[key]]) for key in self.onnx_config.inputs}
     94 preds = session.run(self.onnx_named_outputs, onnx_inputs)
     95 if len(preds) == 1:

KeyError: 'decoder_input_ids'

echarlaix commented 2 years ago

Hi @nirmal2k, the issue comes from the fact that the key 'decoder_input_ids', which is present in self.onnx_config.inputs, does not exist in the given dataset. This was fixed in PR #126, thanks for reporting it.
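
For readers hitting the same traceback, here is a minimal sketch of the failure mode in plain Python. It is not the change made in the PR mentioned above. The names `expected_inputs`, `build_inputs_strict` and `missing_inputs` are hypothetical, and the list of expected inputs and the token ids are assumptions made for illustration only. It shows how indexing every expected key unconditionally raises the KeyError, and how the missing keys can be listed before running inference.

```python
# Hypothetical stand-in for self.onnx_config.inputs of an encoder-decoder model.
# The exact list is an assumption made for this illustration.
expected_inputs = ["input_ids", "attention_mask", "decoder_input_ids"]

# Roughly what a tokenizer returns for one sentence (token ids are made up).
example = {"input_ids": [0, 100, 657, 2], "attention_mask": [1, 1, 1, 1]}


def build_inputs_strict(example, expected_inputs):
    """Mirror the failing pattern: index every expected key unconditionally."""
    return {key: [example[key]] for key in expected_inputs}


def missing_inputs(example, expected_inputs):
    """Return the expected keys that the dataset row does not provide."""
    return [key for key in expected_inputs if key not in example]


try:
    build_inputs_strict(example, expected_inputs)
    error = None
except KeyError as exc:
    # The tokenizer output has no decoder inputs, so the lookup fails here.
    error = exc.args[0]

print(error)                                     # decoder_input_ids
print(missing_inputs(example, expected_inputs))  # ['decoder_input_ids']
```

Running a check like `missing_inputs` on one row of the tokenized dataset before calling the evaluation loop makes the mismatch between the dataset columns and the expected model inputs visible up front.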