microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

ONNX-exported BART model performance is worse than native PyTorch on T4 #7796

Open anshoomehra opened 3 years ago

anshoomehra commented 3 years ago

Describe the bug: @hariharans29, creating a new issue as you suggested.

Stemming from a previous issue: after enabling CUDAExecutionProvider, inference performance appears to have degraded following ONNX conversion.

Native PyTorch model: 676 ms
ONNX model: 2.78 s
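When comparing numbers like these, it helps to time both runtimes the same way. Below is a minimal sketch of a hypothetical helper, `median_latency_ms` (not part of ONNX Runtime or PyTorch), that discards a few warm-up calls, since early GPU runs can include CUDA initialization and kernel setup, and then reports the median of the timed calls. The default warm-up and iteration counts are arbitrary assumptions.

```python
import statistics
import time


def median_latency_ms(run, warmup=3, iters=10):
    """Call run() `warmup` times untimed, then return the median of `iters` timed calls in ms."""
    if iters < 1:
        raise ValueError("iters must be at least 1")
    # Untimed warm-up: first GPU calls may include one-time setup costs.
    for _ in range(warmup):
        run()
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        run()
        samples.append((time.perf_counter() - start) * 1000.0)
    # The median is less sensitive to occasional slow outliers than the mean.
    return statistics.median(samples)
```

For example, with an `onnxruntime.InferenceSession` named `session` and a feed dict `feeds` (both hypothetical here), you might time `lambda: session.run(None, feeds)`, after checking that `session.get_providers()` actually lists `CUDAExecutionProvider`. Because PyTorch launches CUDA work asynchronously, a PyTorch callable should end with `torch.cuda.synchronize()` so the timer measures completed work.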

Urgency: Our project went live this weekend, and the performance is hurting us badly.

System information

To Reproduce: Attached is the full script/Jupyter notebook to reproduce and analyze the issue. Please look at cell #7 onwards: bart_onnx-am.ipynb.zip

Expected behavior: Performance should be significantly better than native PyTorch.

wangyems commented 3 years ago

Thanks @anshoomehra for providing the repro! I'll take a look

wangyems commented 3 years ago

Hi @anshoomehra, based on your code, there are two things you can do to improve the performance:

  1. Eliminate the multiple rounds of host-to-device data copies when feeding the decoder_inputs. You can take a look at https://www.onnxruntime.ai/python/api_summary#iobinding
  2. Apply the optimizer to each of the ONNX models (https://github.com/microsoft/onnxruntime/tree/master/onnxruntime/python/tools/transformers#model-optimizer), using 'bert' as model_type.

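A minimal sketch of suggestion #2, assuming the BART model was exported as separate encoder and decoder ONNX files (file names such as `encoder.onnx` are hypothetical). It calls `optimizer.optimize_model` from `onnxruntime.transformers` with `model_type="bert"` as suggested above; `num_heads` and `hidden_size` must match your checkpoint (16 and 1024 for bart-large, 12 and 768 for bart-base). The helper names `optimized_path` and `optimize_parts` are illustrative only, not library API. As the follow-up in this thread shows, 'bert' did not work out of the box for this BART model, so treat this as a starting point; later onnxruntime releases may also accept `model_type="bart"`, so check the model types supported by your installed version.

```python
import os


def optimized_path(path: str, suffix: str = "_opt") -> str:
    """Derive an output filename, e.g. 'encoder.onnx' -> 'encoder_opt.onnx'."""
    root, ext = os.path.splitext(path)
    return f"{root}{suffix}{ext}"


def optimize_parts(paths, num_heads, hidden_size, model_type="bert"):
    """Run the ONNX Runtime transformer optimizer on each exported ONNX file."""
    # Imported lazily so optimized_path() works without onnxruntime installed.
    from onnxruntime.transformers import optimizer

    outputs = []
    for path in paths:
        opt_model = optimizer.optimize_model(
            path,
            model_type=model_type,
            num_heads=num_heads,      # e.g. 16 for bart-large
            hidden_size=hidden_size,  # e.g. 1024 for bart-large
        )
        out = optimized_path(path)
        opt_model.save_model_to_file(out)
        outputs.append(out)
    return outputs
```

A call might look like `optimize_parts(["encoder.onnx", "decoder.onnx"], num_heads=16, hidden_size=1024)`, after which the `_opt` files are loaded into new inference sessions in place of the originals.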
tianleiwu commented 3 years ago

@anshoomehra

For #1, you can refer to the GPT-2 example, which keeps inputs and outputs on the GPU during text generation (all intermediate tensors should stay on the GPU and not be copied to the CPU): https://github.com/microsoft/onnxruntime/blob/a41255c280b06a7de748f211af42abccc724a225/onnxruntime/python/tools/transformers/gpt2_helper.py#L486-L502
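A minimal sketch of the idea behind #1 using ONNX Runtime's IOBinding API, assuming an encoder ONNX file with input `input_ids` and output `last_hidden_state`, and a decoder ONNX file with inputs `input_ids` and `encoder_hidden_states` and output `logits`. All of these names are hypothetical (use the names from your own export; an `attention_mask` input may also be required). The encoder runs once with its output bound to CUDA, and the resulting `OrtValue` is bound directly as a decoder input at every step, so the large encoder tensor is never copied between host and device. Unlike the GPT-2 helper, this sketch does no past key/value caching and copies each step's logits to the CPU for a greedy argmax.

```python
import numpy as np


def greedy_next_token(logits: np.ndarray) -> int:
    """Pick the highest-scoring token at the last position of batch item 0."""
    # logits is assumed to have shape [batch, seq_len, vocab_size].
    return int(np.argmax(logits[0, -1]))


def greedy_decode(encoder_path, decoder_path, input_ids, start_id, eos_id, max_len=64):
    # Imported lazily so greedy_next_token() works without onnxruntime installed.
    import onnxruntime as ort

    providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    encoder = ort.InferenceSession(encoder_path, providers=providers)
    decoder = ort.InferenceSession(decoder_path, providers=providers)

    # Run the encoder once and leave its output on the GPU.
    enc_io = encoder.io_binding()
    enc_io.bind_cpu_input("input_ids", input_ids)  # int64, shape [1, seq_len]
    enc_io.bind_output("last_hidden_state", "cuda")
    encoder.run_with_iobinding(enc_io)
    encoder_hidden = enc_io.get_outputs()[0]  # OrtValue resident on the GPU

    tokens = [start_id]
    for _ in range(max_len):
        dec_io = decoder.io_binding()
        # Reuse the GPU-resident encoder output: no host-to-device copy here.
        dec_io.bind_ortvalue_input("encoder_hidden_states", encoder_hidden)
        # Only the small decoder token array is copied from the host each step.
        dec_io.bind_cpu_input("input_ids", np.array([tokens], dtype=np.int64))
        dec_io.bind_output("logits", "cuda")
        decoder.run_with_iobinding(dec_io)
        logits = dec_io.copy_outputs_to_cpu()[0]
        token = greedy_next_token(logits)
        tokens.append(token)
        if token == eos_id:
            break
    return tokens
```

Here `start_id` is the decoder start token id and `eos_id` the end-of-sequence id; take both from your model config or tokenizer rather than hard-coding them.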

Example change for #2 (optimizer): test_bart.zip

Let us know whether the issue can be resolved using these two changes.

anshoomehra commented 3 years ago

@wangyems & @tianleiwu, I truly appreciate the inputs and the revised code. I'm working on changing the variable bindings and understanding the optimized code. When I ran the code with the optimizer model_type set to 'bert', the model failed with the error below. The model we're working with is BART (text generation, producing summaries), not BERT. Is 'bert' the right choice? If not, is there support for BART? If not, could we instead use gpt-x as a close match?

(error screenshot)

sam-writer commented 2 years ago

I'm curious about this as well; we're using T5, which is similar to BART.

tianleiwu commented 2 years ago

@sam-writer, the merged PR #8698 should improve BART performance. For T5, the DecoderAttention operator needs a slight change.

stale[bot] commented 2 years ago

This issue has been automatically marked as stale due to inactivity and will be closed in 7 days if no further activity occurs. If further support is needed, please provide an update and/or more details.