Status: Open · fxmarty opened this issue 1 year ago
Updated: Actually, I can reproduce the issue. I added profiling:
```python
import copy
import time

import onnxruntime as ort
import torch
from transformers import AutoTokenizer


def test(use_ort_in_loop):
    batch_size = 4
    sequence_length = 32
    tokenizer = AutoTokenizer.from_pretrained("t5-small")

    total_infsession = []
    total_no_infsession = []
    n_loop = 100

    inp = {
        "input_ids": torch.randint(tokenizer.vocab_size - 1, (batch_size, 5), dtype=torch.int64),
        "encoder_attention_mask": torch.ones(batch_size, sequence_length, dtype=torch.int64),
        "encoder_hidden_states": torch.rand((batch_size, sequence_length, 512), dtype=torch.float32),
    }

    ort_inp = {}
    for key, value in inp.items():
        ort_inp[key] = inp[key].detach().cpu().numpy()

    session = ort.InferenceSession("decoder_model.onnx", providers=["CPUExecutionProvider"])

    M = torch.zeros(32128, 2)

    def get_next_token_logits(ort_inp):
        res = session.run(None, ort_inp)
        out = torch.from_numpy(res[0])
        next_token_logits = out[:, -1, :]
        # forcing contiguous does not help :-(
        # next_token_logits = next_token_logits.contiguous()
        return next_token_logits

    next_token_logits = get_next_token_logits(ort_inp)

    time.sleep(2)

    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ]
    ) as p:
        for i in range(n_loop):
            if use_ort_in_loop:
                next_token_logits = get_next_token_logits(ort_inp)

            start = time.time()
            torch.matmul(next_token_logits, M)
            total_no_infsession += [time.time() - start]

    print(f"use_ort_in_loop={use_ort_in_loop}, operation took {sum(total_no_infsession) / n_loop * 1000} ms over {n_loop} runs")
    print(p.key_averages().table(sort_by="self_cpu_time_total", row_limit=-1))


with torch.no_grad():
    test(True)
    test(False)
```
The output:
```
use_ort_in_loop=True, operation took 32.16004133224487 ms over 100 runs
--------------------------- ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls
--------------------------- ------------ ------------ ------------ ------------ ------------ ------------
aten::mm 93.75% 3.214s 93.75% 3.214s 32.140ms 100
cudaDeviceSynchronize 6.07% 208.037ms 6.07% 208.037ms 208.037ms 1
cudaGetDeviceProperties 0.06% 2.213ms 0.06% 2.213ms 2.213ms 1
aten::slice 0.05% 1.772ms 0.06% 2.084ms 10.420us 200
aten::select 0.02% 723.000us 0.02% 823.000us 8.230us 100
aten::matmul 0.02% 515.000us 93.76% 3.215s 32.145ms 100
aten::lift_fresh 0.01% 441.000us 0.01% 441.000us 4.410us 100
aten::as_strided 0.01% 412.000us 0.01% 412.000us 1.373us 300
cudaGetDeviceCount 0.01% 189.000us 0.01% 189.000us 189.000us 1
aten::resolve_conj 0.00% 2.000us 0.00% 2.000us 0.007us 300
--------------------------- ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 3.428s
STAGE:2022-12-01 22:37:36 2951695:2951695 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
STAGE:2022-12-01 22:37:36 2951695:2951695 ActivityProfilerController.cpp:300] Completed Stage: Collection
use_ort_in_loop=False, operation took 0.2479982376098633 ms over 100 runs
------------------------- ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls
------------------------- ------------ ------------ ------------ ------------ ------------ ------------
aten::mm 94.07% 24.398ms 94.07% 24.399ms 243.990us 100
aten::matmul 5.88% 1.525ms 94.46% 24.498ms 244.980us 100
cudaDeviceSynchronize 0.05% 12.000us 0.05% 12.000us 12.000us 1
aten::resolve_conj 0.00% 1.000us 0.00% 1.000us 0.003us 300
------------------------- ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 25.936ms
```
Right now, I am not sure why there is extra code being called (like aten::slice).
One possible cause: the ORT session has a thread pool, and torch uses threads for matmul. In some way, the ORT threads slow down the torch threads.
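As a quick sanity check of that hypothesis, a sketch like the following can show how many threads each library brings into the process (assuming the decoder_model.onnx from this issue; exact counts depend on the machine and on ORT's defaults):

```python
import onnxruntime as ort
import psutil
import torch

physical_cores = psutil.cpu_count(logical=False)
proc = psutil.Process()

print("physical cores:", physical_cores)
print("torch intra-op threads:", torch.get_num_threads())

threads_before = proc.num_threads()
session = ort.InferenceSession("decoder_model.onnx", providers=["CPUExecutionProvider"])
threads_after = proc.num_threads()

# ORT creates its intra-op thread pool together with the session, so the delta
# gives a rough idea of how many threads it adds on top of PyTorch's own pool.
print("threads added by the ORT session:", threads_after - threads_before)
```

If the sum of the two pools exceeds the number of physical cores, both libraries end up competing for the same cores.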
I'll try again tomorrow; I had the issue with contiguous as well. Edit: interestingly, on my home laptop I can't reproduce the x50 slowdown, only x2 (both contiguous and non-contiguous).
I did an experiment which shows that reducing ORT threads helps: 31 ms => 0.3 ms. It confirms that the cause is thread contention between ORT and Torch.
```python
import time

import onnxruntime as ort
import psutil
import torch
from transformers import AutoTokenizer


def test(
    use_ort_in_loop,
    num_thread_ort,
    num_thread_torch,
    contiguous=False,
    n_loop=100,
    disable_ort_spinning=False,
    profiling=False,
):
    batch_size = 4
    sequence_length = 32
    tokenizer = AutoTokenizer.from_pretrained("t5-small")

    inp = {
        "input_ids": torch.randint(tokenizer.vocab_size - 1, (batch_size, 5), dtype=torch.int64),
        "encoder_attention_mask": torch.ones(batch_size, sequence_length, dtype=torch.int64),
        "encoder_hidden_states": torch.rand((batch_size, sequence_length, 512), dtype=torch.float32),
    }

    ort_inp = {}
    for key in inp:
        ort_inp[key] = inp[key].detach().cpu().numpy()

    sess_options = ort.SessionOptions()
    # limit ORT's intra-op thread pool so that it does not default to all cores
    sess_options.intra_op_num_threads = num_thread_ort
    if disable_ort_spinning:
        sess_options.add_session_config_entry("session.intra_op.allow_spinning", "0")

    session = ort.InferenceSession("decoder_model.onnx", sess_options=sess_options, providers=["CPUExecutionProvider"])

    M = torch.zeros(32128, 2)

    def get_next_token_logits(ort_inp):
        res = session.run(None, ort_inp)
        out = torch.from_numpy(res[0])
        next_token_logits = out[:, -1, :]
        return next_token_logits.contiguous() if contiguous else next_token_logits

    next_token_logits = get_next_token_logits(ort_inp)

    time.sleep(15)

    torch.set_num_threads(num_thread_torch)

    def run_perf_test(next_token_logits):
        latency = []
        for _ in range(n_loop):
            if use_ort_in_loop:
                next_token_logits = get_next_token_logits(ort_inp)

            start = time.time()
            torch.matmul(next_token_logits, M)
            latency += [time.time() - start]
        return latency

    if profiling:
        with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ]
        ) as p:
            latency = run_perf_test(next_token_logits)
        print(p.key_averages().table(sort_by="self_cpu_time_total", row_limit=-1))
    else:
        latency = run_perf_test(next_token_logits)

    print(
        f"use_ort_in_loop={use_ort_in_loop}, disable_ort_spinning={disable_ort_spinning} ort_threads={num_thread_ort}, torch_threads={num_thread_torch}, operation took {sum(latency) / n_loop * 1000:.1f} ms over {n_loop} runs"
    )


with torch.no_grad():
    cpu_count = psutil.cpu_count(logical=False)
    for disable_spinning in [False, True]:
        test(True, int(cpu_count), int(cpu_count), disable_ort_spinning=disable_spinning)
        test(True, int(cpu_count / 2), int(cpu_count / 2), disable_ort_spinning=disable_spinning)
        test(False, int(cpu_count), int(cpu_count), disable_ort_spinning=disable_spinning)
        test(False, int(cpu_count / 2), int(cpu_count / 2), disable_ort_spinning=disable_spinning)
```
The output:
```
use_ort_in_loop=True, ort_threads=4, torch_threads=4, operation took 21.24567747116089 ms over 100 runs
--------------------------- ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls
--------------------------- ------------ ------------ ------------ ------------ ------------ ------------
aten::mm 90.60% 2.123s 90.60% 2.123s 21.227ms 100
cudaDeviceSynchronize 9.13% 214.017ms 9.13% 214.017ms 214.017ms 1
cudaGetDeviceProperties 0.09% 2.170ms 0.09% 2.170ms 2.170ms 1
aten::slice 0.07% 1.638ms 0.08% 1.962ms 9.810us 200
aten::select 0.03% 737.000us 0.04% 837.000us 8.370us 100
aten::matmul 0.03% 621.000us 90.62% 2.123s 21.233ms 100
aten::lift_fresh 0.02% 504.000us 0.02% 504.000us 5.040us 100
aten::as_strided 0.02% 424.000us 0.02% 424.000us 1.413us 300
cudaGetDeviceCount 0.01% 197.000us 0.01% 197.000us 197.000us 1
aten::resolve_conj 0.00% 6.000us 0.00% 6.000us 0.020us 300
--------------------------- ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 2.343s
STAGE:2022-12-01 23:38:03 2962600:2962600 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
STAGE:2022-12-01 23:38:05 2962600:2962600 ActivityProfilerController.cpp:300] Completed Stage: Collection
use_ort_in_loop=True, ort_threads=2, torch_threads=2, operation took 0.32085180282592773 ms over 100 runs
------------------------- ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls
------------------------- ------------ ------------ ------------ ------------ ------------ ------------
aten::mm 89.22% 30.525ms 89.22% 30.526ms 305.260us 100
aten::slice 4.70% 1.609ms 5.64% 1.928ms 9.640us 200
aten::select 2.00% 685.000us 2.31% 790.000us 7.900us 100
aten::matmul 1.48% 507.000us 90.71% 31.033ms 310.330us 100
aten::lift_fresh 1.31% 447.000us 1.31% 447.000us 4.470us 100
aten::as_strided 1.24% 424.000us 1.24% 424.000us 1.413us 300
cudaDeviceSynchronize 0.04% 15.000us 0.04% 15.000us 15.000us 1
aten::resolve_conj 0.00% 1.000us 0.00% 1.000us 0.003us 300
------------------------- ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 34.213ms
STAGE:2022-12-01 23:38:21 2962600:2962600 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
STAGE:2022-12-01 23:38:21 2962600:2962600 ActivityProfilerController.cpp:300] Completed Stage: Collection
use_ort_in_loop=False, ort_threads=4, torch_threads=4, operation took 0.8952760696411133 ms over 100 runs
------------------------- ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls
------------------------- ------------ ------------ ------------ ------------ ------------ ------------
aten::mm 99.33% 89.091ms 99.33% 89.092ms 890.920us 100
aten::matmul 0.65% 587.000us 99.46% 89.203ms 892.030us 100
cudaDeviceSynchronize 0.01% 12.000us 0.01% 12.000us 12.000us 1
aten::resolve_conj 0.00% 1.000us 0.00% 1.000us 0.003us 300
------------------------- ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 89.691ms
STAGE:2022-12-01 23:38:38 2962600:2962600 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
STAGE:2022-12-01 23:38:38 2962600:2962600 ActivityProfilerController.cpp:300] Completed Stage: Collection
use_ort_in_loop=False, ort_threads=2, torch_threads=2, operation took 0.2782249450683594 ms over 100 runs
------------------------- ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls
------------------------- ------------ ------------ ------------ ------------ ------------ ------------
aten::mm 96.15% 27.365ms 96.15% 27.366ms 273.660us 100
aten::matmul 3.81% 1.084ms 96.57% 27.485ms 274.850us 100
cudaDeviceSynchronize 0.04% 12.000us 0.04% 12.000us 12.000us 1
aten::resolve_conj 0.00% 1.000us 0.00% 1.000us 0.003us 300
------------------------- ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 28.462ms
```
Thanks a lot for the follow-up and better profiling! Do you see it as being a bug/issue in ONNX Runtime?
I do not think it is a bug. It is expected that when the total number of threads in a process is larger than the number of CPU cores, performance will be impacted due to thread context switching.
To achieve better performance, I suggest experimenting with the thread settings of ORT, PyTorch, and NumPy: https://pytorch.org/docs/stable/notes/cpu_threading_torchscript_inference.html https://stackoverflow.com/questions/30791550/limit-number-of-threads-in-numpy
The ORT Python API has two settings (inter_op_num_threads and intra_op_num_threads): https://onnxruntime.ai/docs/api/python/api_summary.html#onnxruntime.SessionOptions.inter_op_num_threads
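For example, a minimal sketch of splitting the cores between the two libraries (the thread counts below are placeholders, not a recommendation):

```python
import onnxruntime as ort
import torch

sess_options = ort.SessionOptions()
sess_options.intra_op_num_threads = 2  # threads used inside a single ORT operator
sess_options.inter_op_num_threads = 1  # threads used to run independent ORT nodes

session = ort.InferenceSession(
    "decoder_model.onnx", sess_options=sess_options, providers=["CPUExecutionProvider"]
)

torch.set_num_threads(2)  # cap PyTorch's intra-op thread pool as well
```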
@pranavsharma, any suggestions on this topic?
A couple of options: have you considered `del session`? That'll cause the session, and hence the ORT thread pool, to be destructed.
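A sketch of that option, with made-up inputs for the decoder_model.onnx from this issue (the logits are copied into a PyTorch-owned tensor before the session is dropped):

```python
import numpy as np
import onnxruntime as ort
import torch

batch_size, sequence_length = 4, 32
ort_inp = {
    "input_ids": np.random.randint(0, 32100, (batch_size, 5), dtype=np.int64),
    "encoder_attention_mask": np.ones((batch_size, sequence_length), dtype=np.int64),
    "encoder_hidden_states": np.random.rand(batch_size, sequence_length, 512).astype(np.float32),
}
M = torch.zeros(32128, 2)

session = ort.InferenceSession("decoder_model.onnx", providers=["CPUExecutionProvider"])
res = session.run(None, ort_inp)

# Copy the logits into a tensor PyTorch owns, then drop the session so that its
# intra-op thread pool is destroyed before the PyTorch-heavy part of the code runs.
next_token_logits = torch.from_numpy(res[0])[:, -1, :].clone()
del session

torch.matmul(next_token_logits, M)
```

This of course only helps when the session is no longer needed afterwards.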
Thanks for your answers and suggestions! Will try it out. I will also check whether PyTorch + NumPy have the same issue or not.
Does this mean that torchdynamo + ORT as an execution backend will be very bad (by default) on CPU in case the whole graph is not traced (edit: spoiler: IMO yes 😅)? Or does dynamo work differently than spawning InferenceSessions?
```python
import torch._dynamo as dynamo

dynamo_model = dynamo.optimize("onnxrt")(model)
```
Relevant torchdynamo backend: https://github.com/pytorch/pytorch/blob/1ee189ce8ebf392b4a9c026f040b14f7145ca5e6/torch/_dynamo/optimizations/backends.py#L135-L176 (looks weird, there's IOBinding even on CPU @ezyang)
@fxmarty, adding the following setting seems to help:

```python
sess_options.add_session_config_entry("session.intra_op.allow_spinning", "0")
```
I updated the previous script inline. Here is the result:
```
use_ort_in_loop=True, disable_ort_spinning=False ort_threads=4, torch_threads=4, operation took 23.15556049346924 ms over 100 runs
use_ort_in_loop=True, disable_ort_spinning=False ort_threads=2, torch_threads=2, operation took 4.832360744476318 ms over 100 runs
use_ort_in_loop=False, disable_ort_spinning=False ort_threads=4, torch_threads=4, operation took 0.7350826263427734 ms over 100 runs
use_ort_in_loop=False, disable_ort_spinning=False ort_threads=2, torch_threads=2, operation took 0.24795770645141602 ms over 100 runs
use_ort_in_loop=True, disable_ort_spinning=True ort_threads=4, torch_threads=4, operation took 0.8697772026062012 ms over 100 runs
use_ort_in_loop=True, disable_ort_spinning=True ort_threads=2, torch_threads=2, operation took 0.33701181411743164 ms over 100 runs
use_ort_in_loop=False, disable_ort_spinning=True ort_threads=4, torch_threads=4, operation took 0.7239651679992676 ms over 100 runs
use_ort_in_loop=False, disable_ort_spinning=True ort_threads=2, torch_threads=2, operation took 0.24130821228027344 ms over 100 runs
```
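For reference, the session-side change in isolation looks roughly like this (same config key as in the updated script above):

```python
import onnxruntime as ort

sess_options = ort.SessionOptions()
# Stop ORT's intra-op threads from busy-waiting (spinning) after session.run()
# returns, so the cores are handed back to PyTorch right away.
sess_options.add_session_config_entry("session.intra_op.allow_spinning", "0")

session = ort.InferenceSession(
    "decoder_model.onnx", sess_options=sess_options, providers=["CPUExecutionProvider"]
)
```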
We haven't really heavily tested the ORT backend, so there may be flagrant performance problems for no good reason. One thing is that you should imagine what the performance of your program would be if you replaced your model with several calls into the ORT backend. That should give a sense of what the perf is.
Here's an updated script:
TL;DR, for this specific model and specific PyTorch operation:
@tianleiwu On which hardware/CPU did you reproduce?
Note: the issue can be reproduced on Intel(R) Xeon(R) Platinum 8255C CPU @ 2.50GHz
Describe the issue
Doing PyTorch/NumPy operations on tensors obtained by `InferenceSession.run()` is 50x slower than doing these operations from dummy inputs. Doing `time.sleep()` after running with the InferenceSession solves the problem. I am surprised by this as I don't think `InferenceSession.run()` is asynchronous.
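In short, the pattern is the following (an illustrative sketch with random inputs, not the full reproduction script referenced under "To reproduce" below):

```python
import time

import numpy as np
import onnxruntime as ort
import torch

batch_size, sequence_length = 4, 32
ort_inp = {
    "input_ids": np.random.randint(0, 32100, (batch_size, 5), dtype=np.int64),
    "encoder_attention_mask": np.ones((batch_size, sequence_length), dtype=np.int64),
    "encoder_hidden_states": np.random.rand(batch_size, sequence_length, 512).astype(np.float32),
}
M = torch.zeros(32128, 2)

session = ort.InferenceSession("decoder_model.onnx", providers=["CPUExecutionProvider"])
logits = torch.from_numpy(session.run(None, ort_inp)[0])[:, -1, :]

start = time.time()
torch.matmul(logits, M)  # slow right after session.run()
print(f"right after session.run(): {(time.time() - start) * 1000:.2f} ms")

time.sleep(2)  # give ORT's threads time to go idle

start = time.time()
torch.matmul(logits, M)  # back to the expected latency
print(f"after a 2 s sleep:         {(time.time() - start) * 1000:.2f} ms")
```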
Related issue: https://github.com/huggingface/optimum/issues/524
To reproduce
And run:
Output:
Reproduction using `np.matmul` instead of `torch.matmul`, same issue:

Urgency
high
Platform
Linux
OS Version
Ubuntu 22.04 jammy, x86_64 Linux 5.15.0-53-generic
ONNX Runtime Installation
Released Package
ONNX Runtime Version or Commit ID
1.13.1
ONNX Runtime API
Python
Architecture
X64
Execution Provider
Default CPU
Execution Provider Library Version
No response
Model File
https://huggingface.co/fxmarty/t5-small-onnx/resolve/main/decoder_model.onnx
Is this a quantized model?
No