Open · PhilippShemetov opened this issue 1 year ago
@yufenglee any insights?
I forgot to mention computer specifications:
Also! I did another experiment on a different device (laptop). Laptop specifications:
Results:

Performance for the first model (1 Conv):
- `ORT::ENABLE_BASIC` or `ORT::ENABLE_EXTENDED`: average - 3.38 ms, median - 3.32 ms, std - 0.34 ms
- `ORT::ENABLE_ALL`: average - 2.86 ms, median - 2.95 ms, std - 0.46 ms

Performance for the second model (1 Conv + Max-Pool):
- `ORT::ENABLE_BASIC` or `ORT::ENABLE_EXTENDED`: average - 4.05 ms, median - 3.88 ms, std - 0.38 ms
- `ORT::ENABLE_ALL`: average - 2.55 ms, median - 2.52 ms, std - 0.33 ms

@PhilippShemetov I would suggest collecting more data using the following methods; maybe this can provide more insights.
- `session_options.SetOptimizedModelFilePath`, as shown in this doc. This will give you a chance to inspect the final model that is executed on your machine.
- `session_options.EnableProfiling`, as shown in this doc. You can then visualize the profiling data (a JSON file) with various tools, e.g. edge://tracing in the Edge browser. This will give you the exact execution duration of each operator. Note that the operators here match the operators in the optimized graph, not the original model graph. (A minimal sketch of both settings appears at the end of this comment.)

If I have to guess, the reason Conv+MaxPool is faster is probably that the result tensor in your Conv+MaxPool model is much smaller than in the Conv-only model. In fact it is only 1/4 as large, because you are using a `torch.nn.MaxPool2d` with a 2x2 kernel, which halves both spatial dimensions. There is a hidden operation after your Conv+MaxPool: ONNX Runtime needs to convert the result tensor from the NCHWc format back to the normal NCHW format. This memory layout conversion is faster in the Conv+MaxPool model because it operates on a smaller tensor. There is also a memory layout conversion before the first Conv operator, but that one operates on the same tensor (the input data) in both models, so it does not contribute to the performance difference you observed.
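To make the 1/4 figure concrete, here is a small illustration; the channel counts, kernel size, and input resolution are hypothetical, since the model code is not shown in the issue:

```python
import torch

# Hypothetical shapes -- the real layer parameters are not in the issue.
conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1)
pool = torch.nn.MaxPool2d(kernel_size=2)  # 2x2 kernel halves H and W

x = torch.randn(1, 3, 224, 224)
conv_out = conv(x)       # shape (1, 16, 224, 224)
pooled = pool(conv_out)  # shape (1, 16, 112, 112)

# Size of the tensor that must be converted from NCHWc back to NCHW:
print(conv_out.numel())  # 802816 elements (Conv-only model)
print(pooled.numel())    # 200704 elements (Conv + Max-Pool) -- exactly 1/4
```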
My guess could be wrong as I don't have the optimized graph or the profiling data of your models. You can collect the additional data and verify it.
Disclaimer: I am not a developer of ONNXRuntime.
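For reference, here is a minimal sketch of both suggestions using the onnxruntime Python API; the equivalent C++ `Ort::SessionOptions` methods are the ones named above, and the file names are placeholders:

```python
import onnxruntime as ort

sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
# Dump the graph after optimization so it can be inspected (e.g. in Netron).
sess_options.optimized_model_filepath = "model_optimized.onnx"  # placeholder
# Record per-operator execution times into a JSON trace file.
sess_options.enable_profiling = True

session = ort.InferenceSession("model.onnx", sess_options,
                               providers=["CPUExecutionProvider"])
# ... run inference here ...
profile_file = session.end_profiling()  # returns the JSON trace file name
print(profile_file)  # open with edge://tracing or chrome://tracing
```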
Describe the issue
I am using ONNX Runtime for inference on two different models (as an example). The first model consists of just one convolution layer, while the second model has one convolution layer followed by a max-pooling layer. I noticed that inference for the second model (Conv + Max-Pool) is faster than for the first model (Conv) when I use `ORT::ENABLE_ALL`. With `ORT::DISABLE_ALL`, `ORT::ENABLE_BASIC`, or `ORT::ENABLE_EXTENDED`, the models show the opposite: Conv is faster than Conv + Max-Pool. I would like to understand why the second model (Conv + Max-Pool) has a faster inference time than the first model (Conv) under `ORT::ENABLE_ALL`, even though it should generally require more FLOPs than the model with only one convolution layer. Why does changing NCHW to NCHWc affect performance? Is there any specific reason why the additional max-pooling layer leads to better performance in this scenario?

Any insights or explanations about this behavior would be appreciated. Thank you.
The models were converted to ONNX using `torch.onnx.export` with opset 13.

Test data:
Performance for the first model (1 Conv):
- `ORT::ENABLE_BASIC` or `ORT::ENABLE_EXTENDED`: average - 3.03 ms, median - 2.93 ms, std - 0.1 ms
- `ORT::ENABLE_ALL`: average - 2.30 ms, median - 2.30 ms, std - 0.1 ms

Performance for the second model (1 Conv + Max-Pool):
- `ORT::ENABLE_BASIC` or `ORT::ENABLE_EXTENDED`: average - 3.54 ms, median - 3.54 ms, std - 0.1 ms
- `ORT::ENABLE_ALL`: average - 2.14 ms, median - 2.13 ms, std - 0.15 ms

To reproduce
First model:
Second model:
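The original model snippets were not captured above. As a stand-in, here is a minimal sketch of how the two models could be defined, exported with opset 13, and benchmarked at the different optimization levels; the layer shapes, input size, and run count are assumptions, not the issue author's actual values:

```python
import time

import numpy as np
import onnxruntime as ort
import torch

# Hypothetical shapes -- the issue does not include the original
# layer parameters or the input resolution.
IN_SHAPE = (1, 3, 224, 224)

class ConvOnly(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)

class ConvPool(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.pool = torch.nn.MaxPool2d(kernel_size=2)

    def forward(self, x):
        return self.pool(self.conv(x))

def export(model, path):
    model.eval()
    torch.onnx.export(model, torch.randn(*IN_SHAPE), path, opset_version=13,
                      input_names=["input"], output_names=["output"])

def benchmark(path, level, runs=1000):
    so = ort.SessionOptions()
    so.graph_optimization_level = level
    sess = ort.InferenceSession(path, so, providers=["CPUExecutionProvider"])
    x = np.random.randn(*IN_SHAPE).astype(np.float32)
    times = []
    for _ in range(runs):
        start = time.perf_counter()
        sess.run(None, {"input": x})
        times.append((time.perf_counter() - start) * 1000.0)  # milliseconds
    return np.mean(times), np.median(times), np.std(times)

export(ConvOnly(), "conv.onnx")
export(ConvPool(), "conv_pool.onnx")
for path in ("conv.onnx", "conv_pool.onnx"):
    for level in (ort.GraphOptimizationLevel.ORT_ENABLE_BASIC,
                  ort.GraphOptimizationLevel.ORT_ENABLE_ALL):
        print(path, level, benchmark(path, level))
```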
Urgency
No response
Platform
Linux
OS Version
Ubuntu 20.04
ONNX Runtime Installation
Built from Source
ONNX Runtime Version or Commit ID
1.14.0
ONNX Runtime API
C++
Architecture
X64
Execution Provider
Default CPU
Execution Provider Library Version
No response
Model File
No response
Is this a quantized model?
No