microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

[Performance] ONNX Runtime doesn't parallelize operations in CPU models #16158

Open nick-griaznov opened 1 year ago

nick-griaznov commented 1 year ago

Describe the issue

Since the 1.14 release, ONNX Runtime appears to have stopped parallelizing operations in graphs on CPU. SessionOptions parameters such as inter_op_num_threads and execution_mode no longer affect performance on highly parallel graphs at all. This is in contrast to ONNX Runtime 1.13, where the inter_op_num_threads value strongly correlated with performance. As a result, the new ORT version can be several times slower, depending on the number of threads and the degree of graph parallelism.

I suspect the cause of this problem is the multi-stream Execution Provider refactoring that landed in 1.14. Prior to that refactoring there was a ParallelExecutor that could run independent operations in a graph concurrently, but the current implementation appears to merge all CPU nodes into one sequential chunk.

I have attached a sample model and a program that executes a fairly simple but highly parallel graph. Depending on the ONNX Runtime version (before 1.14 or after), the execution time with 10 threads differs significantly (a 4x difference on my machine).

To reproduce

import numpy as np
import onnxruntime as ort
import timeit

data = np.ones(1000000, np.int64)

def measure_time(n_threads, iters=30):
    # Ask for inter-op (graph-level) parallelism and pin intra-op
    # (per-operator) parallelism to one thread, so the timing reflects
    # only how well independent nodes run concurrently.
    sess_opts = ort.SessionOptions()
    sess_opts.enable_cpu_mem_arena = True
    sess_opts.inter_op_num_threads = n_threads
    sess_opts.intra_op_num_threads = 1
    sess_opts.execution_mode = ort.ExecutionMode.ORT_PARALLEL
    sess = ort.InferenceSession("parallel_onnx_model.onnx", sess_opts)
    total_time = timeit.timeit(
        lambda: sess.run(["out"], {"x": data}), number=iters
    )
    avg_time = total_time / iters
    print(f"Average time with {n_threads} threads: {avg_time}")

measure_time(n_threads=1)
measure_time(n_threads=10)
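
With ONNX Runtime 1.13 and earlier, the 10-thread run should report a noticeably lower average time than the single-thread run; with 1.14 and later, both runs take roughly the same time (about a 4x gap between versions at 10 threads on the machine described above).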

Urgency

This is a clear functionality regression. We use ONNX Runtime primarily for executing graphs on CPU, and we are now experiencing a major performance degradation on our models.

Platform

Linux

OS Version

Ubuntu 20.04.6 LTS

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.14.1

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Default CPU

Execution Provider Library Version

No response

Model File

parallel_onnx_model.onnx.zip

Is this a quantized model?

No

pranavsharma commented 1 year ago

Yes, you're absolutely right. We got rid of the parallel executor when we introduced the multi-stream change because we received reports of deadlocks and other issues. Does your graph have many independent paths that lend themselves to parallel execution?

nick-griaznov commented 1 year ago

We have models with multiple independent heads, so they can be completely parallelized. I think these heads have parallel paths within them as well.

Regarding deadlocks: we previously maintained a fork of ONNX Runtime in which we modified ParallelExecutor to use atomics instead of mutexes wherever possible. This implementation was faster but initially caused deadlocks. It turned out the problem was not in the modified ParallelExecutor but in the current implementation of EigenNonBlockingThreadPool: there is a bug in the synchronization between threads, probably introduced when the code was taken from the Eigen repo and modified. We fixed this bug in our fork and have not experienced deadlocks since. If you have any plans to bring back ParallelExecutor, we would be happy to share our patches, which improve ParallelExecutor's performance and fix the synchronization bug in EigenNonBlockingThreadPool.

adityagoel4512 commented 1 year ago

We would also appreciate reintroducing the ParallelExecutor.

RitheshKumar commented 6 months ago

Hi, checking in again. Is parallel execution still disabled? I'm using the Java ONNX Runtime API to run a model on an Android device.

enzosagretti commented 5 months ago

Hi! Any news about this? Running an ONNX model with one session and multiple requests in parallel is slower than doing it sequentially...

Thanks.
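
For reference, a minimal sketch of the pattern described above: a single InferenceSession shared by several request threads, each calling run() concurrently (ORT sessions can serve inference from multiple threads, and the Python binding releases the GIL while native code runs). The model file and input/output names are reused from the repro earlier in this issue; the request count of 8 is an arbitrary choice for illustration.

import time
import numpy as np
import onnxruntime as ort
from concurrent.futures import ThreadPoolExecutor

data = np.ones(1000000, np.int64)

# One shared session; intra-op threads capped at 1 so the request
# threads are the only source of parallelism in this comparison.
sess_opts = ort.SessionOptions()
sess_opts.intra_op_num_threads = 1
sess = ort.InferenceSession("parallel_onnx_model.onnx", sess_opts)

def run_once(_):
    return sess.run(["out"], {"x": data})

# Sequential baseline: 8 requests one after another.
t0 = time.perf_counter()
for i in range(8):
    run_once(i)
print("sequential:", time.perf_counter() - t0)

# Parallel: the same 8 requests issued from 8 threads against the
# same shared session.
t0 = time.perf_counter()
with ThreadPoolExecutor(max_workers=8) as pool:
    list(pool.map(run_once, range(8)))
print("parallel:  ", time.perf_counter() - t0)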

antoninononooono commented 2 months ago

Can you please share your patches, as announced above?