microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

[CPUExecutionProvider] PyTorch/Numpy operations following InferenceSession.run() are 50x slower compared to using dummy inputs #13808

Open fxmarty opened 1 year ago

fxmarty commented 1 year ago

Describe the issue

Doing PyTorch/NumPy operations on tensors obtained from InferenceSession.run() is 50x slower than doing the same operations on dummy inputs.

Calling time.sleep() after running the InferenceSession solves the problem.

I am surprised by this as I don't think InferenceSession.run() is asynchronous.

Related issue: https://github.com/huggingface/optimum/issues/524

To reproduce

pip install onnxruntime torch transformers
mkdir t5_onnx && cd t5_onnx
wget https://huggingface.co/fxmarty/t5-small-onnx/resolve/main/decoder_model.onnx

And run:

import onnxruntime as ort

from transformers import AutoTokenizer
import torch
import time

batch_size = 4
sequence_length = 32

tokenizer = AutoTokenizer.from_pretrained("t5-small")

total_infsession = 0
total_no_infsession = 0
n_loop = 10

inp = {
    "input_ids": torch.randint(tokenizer.vocab_size - 1, (batch_size, 5), dtype=torch.int64),
    "encoder_attention_mask": torch.ones(batch_size, sequence_length, dtype=torch.int64),
    "encoder_hidden_states": torch.rand((batch_size, sequence_length, 512), dtype=torch.float32)
}

ort_inp = {}
for key, value in inp.items():
    ort_inp[key] = inp[key].detach().cpu().numpy()

session = ort.InferenceSession("decoder_model.onnx", providers=["CPUExecutionProvider"])  # run from the t5_onnx directory

M = torch.zeros(32128, 2)

def get_next_token_logits(ort_inp):
    res = session.run(None, ort_inp)
    out = torch.from_numpy(res[0])
    next_token_logits = out[:, -1, :]

    # forcing contiguous does not help :-(
    # next_token_logits = next_token_logits.contiguous()

    return next_token_logits

for i in range(n_loop):
    next_token_logits = get_next_token_logits(ort_inp)

    # del session
    # time.sleep(1)

    print(next_token_logits[0][0], "contiguous:", next_token_logits.is_contiguous())
    start = time.time()
    torch.matmul(next_token_logits, M)
    total_infsession += time.time() - start

next_token_logits = get_next_token_logits(ort_inp)
time.sleep(2)

for i in range(n_loop):
    print(next_token_logits[0][0], "contiguous:", next_token_logits.is_contiguous())
    start = time.time()
    torch.matmul(next_token_logits, M)
    total_no_infsession += time.time() - start

print(f"WITH inference session: operation took {total_infsession * 1e3:.2f} ms over {n_loop} runs")
print(f"WITHOUT inference session: operation took {total_no_infsession * 1e3:.2f} ms over {n_loop} runs")

Output:

tensor(-20.8427) contiguous: False
tensor(-20.8427) contiguous: False
tensor(-20.8427) contiguous: False
tensor(-20.8427) contiguous: False
tensor(-20.8427) contiguous: False
tensor(-20.8427) contiguous: False
tensor(-20.8427) contiguous: False
tensor(-20.8427) contiguous: False
tensor(-20.8427) contiguous: False
tensor(-20.8427) contiguous: False
tensor(-20.8427) contiguous: False
tensor(-20.8427) contiguous: False
tensor(-20.8427) contiguous: False
tensor(-20.8427) contiguous: False
tensor(-20.8427) contiguous: False
tensor(-20.8427) contiguous: False
tensor(-20.8427) contiguous: False
tensor(-20.8427) contiguous: False
tensor(-20.8427) contiguous: False
tensor(-20.8427) contiguous: False
WITH inference session: operation took 187.24 ms over 10 runs
WITHOUT inference session: operation took 6.19 ms over 10 runs
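The repro above times each call with time.time(), a wall-clock reading. A minimal sketch of a steadier harness, using only the standard library: time.perf_counter() (a monotonic clock meant for short intervals), a few untimed warm-up calls, and the median reported next to the mean. The helper name bench is hypothetical and not part of ONNX Runtime or PyTorch.

```python
import statistics
import time
from typing import Callable


def bench(fn: Callable[[], object], n_loop: int = 10, warmup: int = 2) -> dict:
    """Call fn() `warmup` times untimed, then time n_loop calls; results in milliseconds."""
    for _ in range(warmup):
        fn()  # untimed: absorbs one-off setup cost such as lazy initialisation
    samples = []
    for _ in range(n_loop):
        start = time.perf_counter()  # monotonic, high-resolution clock
        fn()
        samples.append((time.perf_counter() - start) * 1e3)
    return {
        "total_ms": sum(samples),
        "mean_ms": statistics.mean(samples),
        "median_ms": statistics.median(samples),
        "n": n_loop,
    }


if __name__ == "__main__":
    # Stand-in workload; in the repro this would be something like
    # lambda: torch.matmul(next_token_logits, M)
    stats = bench(lambda: sum(range(10_000)), n_loop=20)
    print(f"{stats['mean_ms']:.3f} ms mean, {stats['median_ms']:.3f} ms median over {stats['n']} runs")
```

The median is less sensitive than the total to a few slow iterations, which helps when the question is whether every call is slow or only some of them.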

Reproduction using np.matmul instead of torch.matmul, same issue:

import onnxruntime as ort

from transformers import AutoTokenizer
import torch
import time

import numpy as np

batch_size = 4
sequence_length = 32

tokenizer = AutoTokenizer.from_pretrained("t5-small")

total_infsession = 0
total_no_infsession = 0
n_loop = 20

inp = {
    "input_ids": torch.randint(tokenizer.vocab_size - 1, (batch_size, 5), dtype=torch.int64),
    "encoder_attention_mask": torch.ones(batch_size, sequence_length, dtype=torch.int64),
    "encoder_hidden_states": torch.rand((batch_size, sequence_length, 512), dtype=torch.float32)
}

ort_inp = {}
for key, value in inp.items():
    ort_inp[key] = inp[key].detach().cpu().numpy()

session = ort.InferenceSession("decoder_model.onnx", providers=["CPUExecutionProvider"])  # run from the t5_onnx directory

M = np.zeros((32128, 2))

def get_next_token_logits(ort_inp):
    res = session.run(None, ort_inp)
    next_token_logits = res[0][:, -1, :]

    return next_token_logits

for i in range(n_loop):
    next_token_logits = get_next_token_logits(ort_inp)

    # del session
    # time.sleep(1)

    print(next_token_logits[0][0], "contiguous:", next_token_logits.data.contiguous)
    start = time.time()
    np.matmul(next_token_logits, M)
    total_infsession += time.time() - start

next_token_logits = get_next_token_logits(ort_inp)
time.sleep(2)

for i in range(n_loop):    
    print(next_token_logits[0][0], "contiguous:", next_token_logits.data.contiguous)
    start = time.time()
    _ = np.matmul(next_token_logits, M)
    total_no_infsession += time.time() - start

print(f"WITH inference session: operation took {total_infsession * 1e3:.2f} ms over {n_loop} runs")
print(f"WITHOUT inference session: operation took {total_no_infsession * 1e3:.2f} ms over {n_loop} runs")
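Both repros slice the last time step out of the decoder output, which gives a strided view rather than a dense array. That is why "contiguous: False" is printed. The sketch below uses only NumPy, with a zero array standing in for res[0]. The shape (4, 5, 32128) is an assumption taken from batch_size, the 5 decoder input tokens and the 32128-row M used above. It shows how to check and force contiguity. Per the comment in the first repro, forcing contiguity did not remove the slowdown, so this only rules that factor out.

```python
import numpy as np

# Zero-filled stand-in for res[0]: (batch, decoder_seq_len, vocab); shape assumed.
logits = np.zeros((4, 5, 32128), dtype=np.float32)

last = logits[:, -1, :]  # view of the last time step; no data is copied
print(last.shape)                   # (4, 32128)
print(last.flags["C_CONTIGUOUS"])   # False: consecutive rows are 5 * 32128 floats apart
print(last.data.contiguous)         # False: the memoryview check used in the repro

dense = np.ascontiguousarray(last)  # copies into a dense row-major buffer
print(dense.flags["C_CONTIGUOUS"])  # True
print(np.shares_memory(last, logits), np.shares_memory(dense, logits))  # True False
```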

Urgency

high

Platform

Linux

OS Version

Ubuntu 22.04 jammy, x86_64 Linux 5.15.0-53-generic

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.13.1

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Default CPU

Execution Provider Library Version

No response

Model File

https://huggingface.co/fxmarty/t5-small-onnx/resolve/main/decoder_model.onnx

Is this a quantized model?

No

tianleiwu commented 1 year ago

Updated:

Actually, I can reproduce the issue. Here is the script with profiling added:

import time

import onnxruntime as ort
import torch
from transformers import AutoTokenizer

def test(use_ort_in_loop):
    batch_size = 4
    sequence_length = 32

    tokenizer = AutoTokenizer.from_pretrained("t5-small")

    total_infsession = []
    total_no_infsession = []
    n_loop = 100

    inp = {
        "input_ids": torch.randint(tokenizer.vocab_size - 1, (batch_size, 5), dtype=torch.int64),
        "encoder_attention_mask": torch.ones(batch_size, sequence_length, dtype=torch.int64),
        "encoder_hidden_states": torch.rand((batch_size, sequence_length, 512), dtype=torch.float32)
    }

    ort_inp = {}
    for key, value in inp.items():
        ort_inp[key] = inp[key].detach().cpu().numpy()

    session = ort.InferenceSession("decoder_model.onnx", providers=["CPUExecutionProvider"])

    M = torch.zeros(32128, 2)

    def get_next_token_logits(ort_inp):
        res = session.run(None, ort_inp)
        out = torch.from_numpy(res[0])
        next_token_logits = out[:, -1, :]

        # forcing contiguous does not help :-(
        # next_token_logits = next_token_logits.contiguous()

        return next_token_logits

    next_token_logits = get_next_token_logits(ort_inp)
    time.sleep(2)
    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ]
    ) as p:
        for i in range(n_loop):
            if use_ort_in_loop:
                next_token_logits = get_next_token_logits(ort_inp)
            start = time.time()
            torch.matmul(next_token_logits, M)
            total_no_infsession += [time.time() - start]
    print(f"use_ort_in_loop={use_ort_in_loop}, operation took {sum(total_no_infsession) / n_loop * 1000} ms over {n_loop} runs")
    print(p.key_averages().table( sort_by="self_cpu_time_total", row_limit=-1))

with torch.no_grad():
    test(True)
    test(False)

The output:

use_ort_in_loop=True, operation took 32.16004133224487 ms over 100 runs
---------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                       Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
---------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                   aten::mm        93.75%        3.214s        93.75%        3.214s      32.140ms           100
      cudaDeviceSynchronize         6.07%     208.037ms         6.07%     208.037ms     208.037ms             1
    cudaGetDeviceProperties         0.06%       2.213ms         0.06%       2.213ms       2.213ms             1
                aten::slice         0.05%       1.772ms         0.06%       2.084ms      10.420us           200
               aten::select         0.02%     723.000us         0.02%     823.000us       8.230us           100
               aten::matmul         0.02%     515.000us        93.76%        3.215s      32.145ms           100
           aten::lift_fresh         0.01%     441.000us         0.01%     441.000us       4.410us           100
           aten::as_strided         0.01%     412.000us         0.01%     412.000us       1.373us           300
         cudaGetDeviceCount         0.01%     189.000us         0.01%     189.000us     189.000us             1
         aten::resolve_conj         0.00%       2.000us         0.00%       2.000us       0.007us           300
---------------------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 3.428s

STAGE:2022-12-01 22:37:36 2951695:2951695 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
STAGE:2022-12-01 22:37:36 2951695:2951695 ActivityProfilerController.cpp:300] Completed Stage: Collection
use_ort_in_loop=False, operation took 0.2479982376098633 ms over 100 runs
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                     Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                 aten::mm        94.07%      24.398ms        94.07%      24.399ms     243.990us           100
             aten::matmul         5.88%       1.525ms        94.46%      24.498ms     244.980us           100
    cudaDeviceSynchronize         0.05%      12.000us         0.05%      12.000us      12.000us             1
       aten::resolve_conj         0.00%       1.000us         0.00%       1.000us       0.003us           300
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 25.936ms

Right now, I am not sure why extra ops (such as aten::slice) are called.

One possible cause is that the ORT session has its own thread pool, and torch also uses threads for matmul. Somehow, the ORT threads slow down the torch threads.
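One way to test the thread-pool hypothesis is to count the native threads in the process before and after the InferenceSession is created. threading.active_count() only counts threads known to Python's threading module, so this sketch reads /proc/self/task instead. That directory lists every OS thread of the current process, so the sketch only works on Linux (the platform in this report) and returns None elsewhere. The function name native_thread_count is hypothetical.

```python
import os
from typing import Optional


def native_thread_count() -> Optional[int]:
    """Number of OS threads in this process (Linux only), or None if /proc is unavailable."""
    task_dir = "/proc/self/task"  # one entry per kernel thread of the current process
    if not os.path.isdir(task_dir):
        return None
    return len(os.listdir(task_dir))


if __name__ == "__main__":
    before = native_thread_count()
    # Hypothetical usage in the repro:
    #   session = ort.InferenceSession("decoder_model.onnx", providers=["CPUExecutionProvider"])
    #   torch.matmul(next_token_logits, M)  # may start PyTorch's intra-op pool
    after = native_thread_count()
    print(f"threads before: {before}, after: {after}")
```

If ORT's intra-op pool and PyTorch's pool together push the count well past the number of physical cores, that is consistent with the contention explanation.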

fxmarty commented 1 year ago

I'll try again tomorrow; I had the issue with contiguous tensors as well. Edit: interestingly, on my home laptop I can't reproduce the 50x slowdown, only 2x (with both contiguous and non-contiguous tensors).

tianleiwu commented 1 year ago

I did an experiment which shows that reducing the thread count helps: 31 ms => 0.3 ms. It confirms that the cause is thread contention between ORT and Torch.

import time

import onnxruntime as ort
import psutil
import torch
from transformers import AutoTokenizer

def test(
    use_ort_in_loop,
    num_thread_ort,
    num_thread_torch,
    contiguous=False,
    n_loop=100,
    disable_ort_spinning=False,
    profiling=False,
):
    batch_size = 4
    sequence_length = 32

    tokenizer = AutoTokenizer.from_pretrained("t5-small")

    inp = {
        "input_ids": torch.randint(tokenizer.vocab_size - 1, (batch_size, 5), dtype=torch.int64),
        "encoder_attention_mask": torch.ones(batch_size, sequence_length, dtype=torch.int64),
        "encoder_hidden_states": torch.rand((batch_size, sequence_length, 512), dtype=torch.float32),
    }

    ort_inp = {}
    for key in inp:
        ort_inp[key] = inp[key].detach().cpu().numpy()

    sess_options = ort.SessionOptions()
    sess_options.intra_op_num_threads = num_thread_ort  # apply the requested ORT thread count
    if disable_ort_spinning:
        sess_options.add_session_config_entry("session.intra_op.allow_spinning", "0")

    session = ort.InferenceSession("decoder_model.onnx", sess_options=sess_options, providers=["CPUExecutionProvider"])

    M = torch.zeros(32128, 2)

    def get_next_token_logits(ort_inp):
        res = session.run(None, ort_inp)
        out = torch.from_numpy(res[0])
        next_token_logits = out[:, -1, :]
        return next_token_logits.contiguous() if contiguous else next_token_logits

    next_token_logits = get_next_token_logits(ort_inp)
    time.sleep(15)

    torch.set_num_threads(num_thread_torch)

    def run_perf_test(next_token_logits):
        latency = []
        for _ in range(n_loop):
            if use_ort_in_loop:
                next_token_logits = get_next_token_logits(ort_inp)
            start = time.time()
            torch.matmul(next_token_logits, M)
            latency += [time.time() - start]
        return latency

    if profiling:
        with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ]
        ) as p:
            latency = run_perf_test(next_token_logits)
        print(p.key_averages().table(sort_by="self_cpu_time_total", row_limit=-1))
    else:
        latency = run_perf_test(next_token_logits)

    print(
        f"use_ort_in_loop={use_ort_in_loop}, disable_ort_spinning={disable_ort_spinning} ort_threads={num_thread_ort}, torch_threads={num_thread_torch}, operation took {sum(latency) / n_loop * 1000:.1f} ms over {n_loop} runs"
    )

with torch.no_grad():
    cpu_count = psutil.cpu_count(logical=False)
    for disable_spinning in [False, True]:
        test(True, int(cpu_count), int(cpu_count), disable_ort_spinning=disable_spinning)
        test(True, int(cpu_count / 2), int(cpu_count / 2), disable_ort_spinning=disable_spinning)
        test(False, int(cpu_count), int(cpu_count), disable_ort_spinning=disable_spinning)
        test(False, int(cpu_count / 2), int(cpu_count / 2), disable_ort_spinning=disable_spinning)

The output:

use_ort_in_loop=True, ort_threads=4, torch_threads=4, operation took 21.24567747116089 ms over 100 runs
---------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                       Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
---------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                   aten::mm        90.60%        2.123s        90.60%        2.123s      21.227ms           100
      cudaDeviceSynchronize         9.13%     214.017ms         9.13%     214.017ms     214.017ms             1
    cudaGetDeviceProperties         0.09%       2.170ms         0.09%       2.170ms       2.170ms             1
                aten::slice         0.07%       1.638ms         0.08%       1.962ms       9.810us           200
               aten::select         0.03%     737.000us         0.04%     837.000us       8.370us           100
               aten::matmul         0.03%     621.000us        90.62%        2.123s      21.233ms           100
           aten::lift_fresh         0.02%     504.000us         0.02%     504.000us       5.040us           100
           aten::as_strided         0.02%     424.000us         0.02%     424.000us       1.413us           300
         cudaGetDeviceCount         0.01%     197.000us         0.01%     197.000us     197.000us             1
         aten::resolve_conj         0.00%       6.000us         0.00%       6.000us       0.020us           300
---------------------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 2.343s

STAGE:2022-12-01 23:38:03 2962600:2962600 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
STAGE:2022-12-01 23:38:05 2962600:2962600 ActivityProfilerController.cpp:300] Completed Stage: Collection
use_ort_in_loop=True, ort_threads=2, torch_threads=2, operation took 0.32085180282592773 ms over 100 runs
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                     Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                 aten::mm        89.22%      30.525ms        89.22%      30.526ms     305.260us           100
              aten::slice         4.70%       1.609ms         5.64%       1.928ms       9.640us           200
             aten::select         2.00%     685.000us         2.31%     790.000us       7.900us           100
             aten::matmul         1.48%     507.000us        90.71%      31.033ms     310.330us           100
         aten::lift_fresh         1.31%     447.000us         1.31%     447.000us       4.470us           100
         aten::as_strided         1.24%     424.000us         1.24%     424.000us       1.413us           300
    cudaDeviceSynchronize         0.04%      15.000us         0.04%      15.000us      15.000us             1
       aten::resolve_conj         0.00%       1.000us         0.00%       1.000us       0.003us           300
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 34.213ms

STAGE:2022-12-01 23:38:21 2962600:2962600 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
STAGE:2022-12-01 23:38:21 2962600:2962600 ActivityProfilerController.cpp:300] Completed Stage: Collection
use_ort_in_loop=False, ort_threads=4, torch_threads=4, operation took 0.8952760696411133 ms over 100 runs
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                     Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                 aten::mm        99.33%      89.091ms        99.33%      89.092ms     890.920us           100
             aten::matmul         0.65%     587.000us        99.46%      89.203ms     892.030us           100
    cudaDeviceSynchronize         0.01%      12.000us         0.01%      12.000us      12.000us             1
       aten::resolve_conj         0.00%       1.000us         0.00%       1.000us       0.003us           300
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 89.691ms

STAGE:2022-12-01 23:38:38 2962600:2962600 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
STAGE:2022-12-01 23:38:38 2962600:2962600 ActivityProfilerController.cpp:300] Completed Stage: Collection
use_ort_in_loop=False, ort_threads=2, torch_threads=2, operation took 0.2782249450683594 ms over 100 runs
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                     Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                 aten::mm        96.15%      27.365ms        96.15%      27.366ms     273.660us           100
             aten::matmul         3.81%       1.084ms        96.57%      27.485ms     274.850us           100
    cudaDeviceSynchronize         0.04%      12.000us         0.04%      12.000us      12.000us             1
       aten::resolve_conj         0.00%       1.000us         0.00%       1.000us       0.003us           300
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 28.462ms

fxmarty commented 1 year ago

Thanks a lot for the followup and better profiling! Do you see it as being a bug/issue in ONNX Runtime?

tianleiwu commented 1 year ago

I do not think it is a bug. When the total number of threads in a process exceeds the number of CPU cores, performance is expected to suffer from thread context switching.

To achieve better performance, I suggest experimenting with the thread settings of ORT, PyTorch and NumPy: https://pytorch.org/docs/stable/notes/cpu_threading_torchscript_inference.html https://stackoverflow.com/questions/30791550/limit-number-of-threads-in-numpy
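For NumPy, the approach in the linked Stack Overflow question is to cap the BLAS backend's thread pool through environment variables. They must be set before NumPy (or PyTorch) is first imported, because the libraries read them when they load. The same OMP_NUM_THREADS and MKL_NUM_THREADS variables are also described in the linked PyTorch CPU threading note. Which variable actually matters depends on the BLAS build (OpenMP-based, OpenBLAS or MKL), so this minimal sketch sets all three; the helper name limit_native_threads is hypothetical.

```python
import os

# Thread-count variables read by OpenMP, OpenBLAS and MKL when those libraries load.
_THREAD_VARS = ("OMP_NUM_THREADS", "OPENBLAS_NUM_THREADS", "MKL_NUM_THREADS")


def limit_native_threads(n: int) -> None:
    """Cap BLAS/OpenMP pools; only effective if called before numpy/torch are imported."""
    if n < 1:
        raise ValueError("thread count must be >= 1")
    for var in _THREAD_VARS:
        os.environ[var] = str(n)


limit_native_threads(2)
# import numpy as np  # must come after the call above to pick up the limit
```

If NumPy is already imported, the threadpoolctl package's threadpool_limits context manager is an alternative that adjusts the BLAS pools at runtime.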

The ORT Python API has two settings, inter_op_num_threads and intra_op_num_threads: https://onnxruntime.ai/docs/api/python/api_summary.html#onnxruntime.SessionOptions.inter_op_num_threads
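Following the suggestion above, one way to avoid oversubscription is to split the physical cores between the two runtimes, so that ORT's intra-op pool plus PyTorch's pool does not exceed the core count. This is a sketch under that assumption. split_threads and configure_threads are hypothetical helpers, and the 50/50 default is arbitrary. SessionOptions.intra_op_num_threads, SessionOptions.inter_op_num_threads and torch.set_num_threads are the real settings referenced in this thread. The onnxruntime and torch imports are deferred, so the arithmetic can be used on its own.

```python
from typing import Tuple


def split_threads(physical_cores: int, ort_share: float = 0.5) -> Tuple[int, int]:
    """Split physical cores between ORT and PyTorch; each side gets at least one thread."""
    if physical_cores < 1:
        raise ValueError("need at least one core")
    ort_threads = max(1, round(physical_cores * ort_share))
    # If ORT's share rounds to every core (or there is only one core), PyTorch still
    # gets one thread, so the total exceeds the core count by one.
    torch_threads = max(1, physical_cores - ort_threads)
    return ort_threads, torch_threads


def configure_threads(ort_threads: int, torch_threads: int):
    """Apply the budget to PyTorch and return SessionOptions for ort.InferenceSession."""
    import onnxruntime as ort  # deferred: only needed when the budget is applied
    import torch

    torch.set_num_threads(torch_threads)
    sess_options = ort.SessionOptions()
    sess_options.intra_op_num_threads = ort_threads
    # The inter-op pool is only used with ExecutionMode.ORT_PARALLEL.
    sess_options.inter_op_num_threads = 1
    return sess_options
```

For example, split_threads(4) returns (2, 2). The core count can come from psutil.cpu_count(logical=False) as in the scripts in this thread, keeping in mind that it may return None.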

@pranavsharma, any suggestions on this topic?

pranavsharma commented 1 year ago

Couple of options:

  1. Is it possible to delete the session object with del session once you're done with inference? That will destroy the session and hence the ORT thread pool.
  2. Another option is to pin PyTorch and ORT threads to different cores, but we don't support that in ORT today (it's a work in progress).
  3. You can also try turning off spinning using this config.
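Options 1 and 3 can be combined in a few lines. This is a sketch. The config key session.intra_op.allow_spinning is the one used elsewhere in this thread, while spinning_entries and make_cpu_session are hypothetical helpers. With spinning turned off, idle ORT worker threads block instead of busy-waiting for new work, so they should stop competing with PyTorch for the cores between run() calls. Dropping the last reference to the session (option 1) lets CPython destroy it, and its thread pool with it.

```python
from typing import Dict

SPINNING_KEY = "session.intra_op.allow_spinning"  # key quoted in this thread


def spinning_entries(allow_spinning: bool) -> Dict[str, str]:
    """Session config entries; add_session_config_entry takes string values "1"/"0"."""
    return {SPINNING_KEY: "1" if allow_spinning else "0"}


def make_cpu_session(model_path: str, allow_spinning: bool = False):
    import onnxruntime as ort  # deferred so spinning_entries() works without ORT installed

    sess_options = ort.SessionOptions()
    for key, value in spinning_entries(allow_spinning).items():
        sess_options.add_session_config_entry(key, value)
    return ort.InferenceSession(model_path, sess_options=sess_options, providers=["CPUExecutionProvider"])


# Hypothetical usage with the model from the report:
# session = make_cpu_session("decoder_model.onnx")
# res = session.run(None, ort_inp)
# del session  # option 1: with no other references left, the session and its pool are freed
```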

fxmarty commented 1 year ago

Thanks for your answers and suggestions! I will try them out, and will check whether PyTorch + NumPy has the same issue.

Does this mean that torchdynamo with ORT as the execution backend will perform very badly (by default) on CPU when the whole graph is not traced (edit: spoiler: IMO yes 😅)? Or does dynamo work differently than spawning InferenceSessions?

import torch._dynamo as dynamo

dynamo_model = dynamo.optimize("onnxrt")(model)

Relevant torchdynamo backend: https://github.com/pytorch/pytorch/blob/1ee189ce8ebf392b4a9c026f040b14f7145ca5e6/torch/_dynamo/optimizations/backends.py#L135-L176 (looks weird, there's IOBinding even on CPU @ezyang)

tianleiwu commented 1 year ago

@fxmarty, adding the following setting seems to help:

sess_options.add_session_config_entry("session.intra_op.allow_spinning", "0")

I updated the previous script inline. Here are the results:

use_ort_in_loop=True, disable_ort_spinning=False ort_threads=4, torch_threads=4, operation took 23.15556049346924 ms over 100 runs
use_ort_in_loop=True, disable_ort_spinning=False ort_threads=2, torch_threads=2, operation took 4.832360744476318 ms over 100 runs
use_ort_in_loop=False, disable_ort_spinning=False ort_threads=4, torch_threads=4, operation took 0.7350826263427734 ms over 100 runs
use_ort_in_loop=False, disable_ort_spinning=False ort_threads=2, torch_threads=2, operation took 0.24795770645141602 ms over 100 runs
use_ort_in_loop=True, disable_ort_spinning=True ort_threads=4, torch_threads=4, operation took 0.8697772026062012 ms over 100 runs
use_ort_in_loop=True, disable_ort_spinning=True ort_threads=2, torch_threads=2, operation took 0.33701181411743164 ms over 100 runs
use_ort_in_loop=False, disable_ort_spinning=True ort_threads=4, torch_threads=4, operation took 0.7239651679992676 ms over 100 runs
use_ort_in_loop=False, disable_ort_spinning=True ort_threads=2, torch_threads=2, operation took 0.24130821228027344 ms over 100 runs
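To compare runs like these side by side, a small parser over the printed summary lines can tabulate latency per configuration. The sketch below uses only the standard library. Its regular expression is written against the exact print format of the updated script and would need adjusting for any other format. The two sample lines are copied from the results above. The helper name parse_results is hypothetical.

```python
import re
from typing import Dict, Tuple

LINE_RE = re.compile(
    r"use_ort_in_loop=(?P<in_loop>True|False), "
    r"disable_ort_spinning=(?P<no_spin>True|False) "
    r"ort_threads=(?P<ort>\d+), torch_threads=(?P<torch>\d+), "
    r"operation took (?P<ms>[\d.]+) ms"
)


def parse_results(text: str) -> Dict[Tuple[bool, bool, int, int], float]:
    """Map (use_ort_in_loop, disable_spinning, ort_threads, torch_threads) -> ms per run."""
    results = {}
    for match in LINE_RE.finditer(text):
        key = (
            match["in_loop"] == "True",
            match["no_spin"] == "True",
            int(match["ort"]),
            int(match["torch"]),
        )
        results[key] = float(match["ms"])
    return results


sample = """\
use_ort_in_loop=True, disable_ort_spinning=False ort_threads=4, torch_threads=4, operation took 23.15556049346924 ms over 100 runs
use_ort_in_loop=True, disable_ort_spinning=True ort_threads=4, torch_threads=4, operation took 0.8697772026062012 ms over 100 runs
"""
res = parse_results(sample)
speedup = res[(True, False, 4, 4)] / res[(True, True, 4, 4)]
print(f"disabling spinning: {speedup:.1f}x faster")  # disabling spinning: 26.6x faster
```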

ezyang commented 1 year ago

We haven't really tested the ORT backend heavily, so there may be flagrant performance problems for no good reason. One thing you can do is imagine what the performance of your program would be if you replaced your model with several calls into the ORT backend. That should give a sense of what the perf is.

fxmarty commented 1 year ago

Here's an updated script:

Script:

```python
import time

import onnxruntime as ort
import psutil
import torch
from transformers import AutoTokenizer

def test(
    use_ort_in_loop,
    use_pt_in_loop,
    num_thread_ort,
    num_thread_torch,
    contiguous=True,
    n_loop=200,
    disable_ort_spinning=False,
    profiling=True,
):
    batch_size = 4
    sequence_length = 32

    tokenizer = AutoTokenizer.from_pretrained("t5-small")

    inp = {
        "input_ids": torch.randint(tokenizer.vocab_size - 1, (batch_size, 5), dtype=torch.int64),
        "encoder_attention_mask": torch.ones(batch_size, sequence_length, dtype=torch.int64),
        "encoder_hidden_states": torch.rand((batch_size, sequence_length, 512), dtype=torch.float32),
    }

    ort_inp = {}
    for key in inp:
        ort_inp[key] = inp[key].detach().cpu().numpy()

    sess_options = ort.SessionOptions()
    sess_options.intra_op_num_threads = num_thread_ort
    if disable_ort_spinning:
        sess_options.add_session_config_entry("session.intra_op.allow_spinning", "0")

    session = ort.InferenceSession("decoder_model.onnx", sess_options=sess_options, providers=["CPUExecutionProvider"])

    M = torch.zeros(32128, 100)

    def get_next_token_logits(ort_inp):
        res = session.run(None, ort_inp)
        next_token_logits = torch.from_numpy(res[0])
        return next_token_logits.contiguous() if contiguous else next_token_logits

    next_token_logits = get_next_token_logits(ort_inp)
    time.sleep(10)

    torch.set_num_threads(num_thread_torch)

    def run_perf_test(next_token_logits):
        latency = [0]
        ort_latency = [0]
        for _ in range(n_loop):
            if use_ort_in_loop:
                start = time.time()
                next_token_logits = get_next_token_logits(ort_inp)
                ort_latency += [time.time() - start]
            if use_pt_in_loop:
                start = time.time()
                torch.matmul(next_token_logits, M)
                latency += [time.time() - start]
        return latency, ort_latency

    if profiling:
        with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ]
        ) as p:
            latency, ort_latency = run_perf_test(next_token_logits)
        print(p.key_averages().table(sort_by="self_cpu_time_total", row_limit=-1))
    else:
        latency, ort_latency = run_perf_test(next_token_logits)

    print(
        f"use_ort={use_ort_in_loop}, use_pt={use_pt_in_loop}, disable_ort_spinning={disable_ort_spinning} ort_threads={num_thread_ort}, torch_threads={num_thread_torch}, PT operation took {sum(latency) * 1e3:.1f} ms, ORT operation took {sum(ort_latency) * 1e3:.1f} ms over {n_loop} runs"
    )
    print("------------")

with torch.no_grad():
    cpu_count = psutil.cpu_count(logical=False)
    for disable_spinning in [False, True]:
        test(True, True, int(cpu_count), int(cpu_count), disable_ort_spinning=disable_spinning)
        test(True, True, int(cpu_count) // 2, int(cpu_count) // 2, disable_ort_spinning=disable_spinning)
```
Results on AWS EC2 c6i instance (with a piece of Platinum 8375C CPU, 4 physical cores):

```
STAGE:2022-12-05 11:21:41 5513:5513 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
STAGE:2022-12-05 11:21:43 5513:5513 ActivityProfilerController.cpp:300] Completed Stage: Collection
------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                    Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                aten::mm        97.13%     242.705ms        97.13%     242.709ms       1.214ms           200
            aten::matmul         1.04%       2.611ms        99.72%     249.184ms       1.246ms           200
           aten::reshape         0.64%       1.597ms         1.08%       2.700ms      13.500us           200
    aten::_reshape_alias         0.49%       1.228ms         0.49%       1.228ms       6.140us           200
      aten::_unsafe_view         0.42%       1.039ms         0.42%       1.039ms       5.195us           200
        aten::lift_fresh         0.28%     701.000us         0.28%     701.000us       3.505us           200
      aten::resolve_conj         0.00%       4.000us         0.00%       4.000us       0.007us           600
------------------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 249.885ms

use_ort=True, use_pt=True, disable_ort_spinning=False ort_threads=4, torch_threads=4, PT operation took 251.7 ms, ORT operation took 2005.3 ms over 200 runs
------------
STAGE:2022-12-05 11:21:53 5513:5513 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
STAGE:2022-12-05 11:21:57 5513:5513 ActivityProfilerController.cpp:300] Completed Stage: Collection
------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                    Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                aten::mm        98.37%     459.505ms        98.37%     459.507ms       2.298ms           200
            aten::matmul         0.60%       2.815ms        99.83%     466.335ms       2.332ms           200
           aten::reshape         0.35%       1.627ms         0.61%       2.843ms      14.215us           200
    aten::_reshape_alias         0.29%       1.333ms         0.29%       1.333ms       6.665us           200
      aten::_unsafe_view         0.23%       1.053ms         0.23%       1.053ms       5.265us           200
        aten::lift_fresh         0.17%     776.000us         0.17%     776.000us       3.880us           200
      aten::resolve_conj         0.00%       2.000us         0.00%       2.000us       0.003us           600
------------------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 467.111ms

use_ort=True, use_pt=True, disable_ort_spinning=False ort_threads=2, torch_threads=2, PT operation took 469.0 ms, ORT operation took 3139.4 ms over 200 runs
------------
STAGE:2022-12-05 11:22:08 5513:5513 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
STAGE:2022-12-05 11:22:11 5513:5513 ActivityProfilerController.cpp:300] Completed Stage: Collection
------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                    Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                aten::mm        96.50%     199.427ms        96.50%     199.431ms     997.155us           200
            aten::matmul         1.22%       2.513ms        99.59%     205.809ms       1.029ms           200
           aten::reshape         0.77%       1.592ms         1.41%       2.905ms      14.525us           200
    aten::_reshape_alias         0.66%       1.369ms         0.66%       1.369ms       6.845us           200
      aten::_unsafe_view         0.44%     904.000us         0.44%     904.000us       4.520us           200
        aten::lift_fresh         0.41%     846.000us         0.41%     846.000us       4.230us           200
      aten::resolve_conj         0.00%       4.000us         0.00%       4.000us       0.007us           600
------------------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 206.655ms

use_ort=True, use_pt=True, disable_ort_spinning=True ort_threads=4, torch_threads=4, PT operation took 208.3 ms, ORT operation took 2891.3 ms over 200 runs
------------
STAGE:2022-12-05 11:22:21 5513:5513 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
STAGE:2022-12-05 11:22:25 5513:5513 ActivityProfilerController.cpp:300] Completed Stage: Collection
------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                    Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                aten::mm        97.57%     282.228ms        97.57%     282.231ms       1.411ms           200
            aten::matmul         0.88%       2.538ms        99.74%     288.486ms       1.442ms           200
           aten::reshape         0.54%       1.575ms         0.94%       2.722ms      13.610us           200
    aten::_reshape_alias         0.42%       1.227ms         0.42%       1.227ms       6.135us           200
      aten::_unsafe_view         0.32%     915.000us         0.32%     915.000us       4.575us           200
        aten::lift_fresh         0.26%     765.000us         0.26%     765.000us       3.825us           200
      aten::resolve_conj         0.00%       3.000us         0.00%       3.000us       0.005us           600
------------------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 289.251ms

use_ort=True, use_pt=True, disable_ort_spinning=True ort_threads=2, torch_threads=2, PT operation took 291.0 ms, ORT operation took 3259.3 ms over 200 runs
------------
```
Results on my laptop (with i7-1280P, that has both E-cores and P-cores):

```
---------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                       Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
---------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                   aten::mm        97.80%        6.796s        97.80%        6.796s      33.979ms           200
      cudaDeviceSynchronize         1.74%     121.127ms         1.74%     121.127ms     121.127ms             1
               aten::matmul         0.36%      25.196ms        98.24%        6.826s      34.132ms           200
              aten::reshape         0.03%       2.187ms         0.06%       3.918ms      19.590us           200
       aten::_reshape_alias         0.03%       1.744ms         0.03%       1.744ms       8.720us           200
         aten::_unsafe_view         0.02%       1.420ms         0.02%       1.420ms       7.100us           200
           aten::lift_fresh         0.01%       1.006ms         0.01%       1.006ms       5.030us           200
         cudaGetDeviceCount         0.00%     163.000us         0.00%     163.000us     163.000us             1
    cudaGetDeviceProperties         0.00%      95.000us         0.00%      95.000us      95.000us             1
         aten::resolve_conj         0.00%      36.000us         0.00%      36.000us       0.060us           600
---------------------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 6.949s

use_ort=True, use_pt=True, disable_ort_spinning=False ort_threads=14, torch_threads=14, PT operation took 6830.3 ms, ORT operation took 2642.2 ms over 200 runs
------------
STAGE:2022-12-05 12:19:48 22172:22172 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
STAGE:2022-12-05 12:19:51 22172:22172 ActivityProfilerController.cpp:300] Completed Stage: Collection
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                     Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                 aten::mm        95.21%     193.453ms        95.22%     193.474ms     967.370us           200
             aten::matmul         1.85%       3.758ms        99.47%     202.094ms       1.010ms           200
            aten::reshape         0.99%       2.003ms         1.81%       3.674ms      18.370us           200
     aten::_reshape_alias         0.84%       1.698ms         0.84%       1.698ms       8.490us           200
       aten::_unsafe_view         0.57%       1.161ms         0.57%       1.161ms       5.805us           200
         aten::lift_fresh         0.53%       1.072ms         0.53%       1.072ms       5.360us           200
       aten::resolve_conj         0.01%      21.000us         0.01%      21.000us       0.035us           600
    cudaDeviceSynchronize         0.01%      14.000us         0.01%      14.000us      14.000us             1
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 203.180ms

use_ort=True, use_pt=True, disable_ort_spinning=False ort_threads=7, torch_threads=7, PT operation took 205.2 ms, ORT operation took 2778.5 ms over 200 runs
------------
STAGE:2022-12-05 12:20:03 22172:22172 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
STAGE:2022-12-05 12:20:05 22172:22172 ActivityProfilerController.cpp:300] Completed Stage: Collection
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                     Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                 aten::mm        98.37%     244.519ms        98.37%     244.522ms       1.223ms           200
             aten::matmul         0.64%       1.580ms        99.85%     248.196ms       1.241ms           200
            aten::reshape         0.41%       1.018ms         0.68%       1.680ms       8.400us           200
     aten::_reshape_alias         0.28%     696.000us         0.28%     696.000us       3.480us           200
       aten::_unsafe_view         0.15%     380.000us         0.15%     380.000us       1.900us           200
         aten::lift_fresh         0.15%     361.000us         0.15%     361.000us       1.805us           200
    cudaDeviceSynchronize         0.00%      11.000us         0.00%      11.000us      11.000us             1
       aten::resolve_conj         0.00%       3.000us         0.00%       3.000us       0.005us           600
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 248.568ms

use_ort=True, use_pt=True, disable_ort_spinning=True ort_threads=14, torch_threads=14, PT operation took 249.8 ms, ORT operation took 2117.6 ms over 200 runs
------------
STAGE:2022-12-05 12:20:17 22172:22172 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
STAGE:2022-12-05 12:20:20 22172:22172 ActivityProfilerController.cpp:300] Completed Stage: Collection
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                     Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                 aten::mm        98.12%     175.821ms        98.12%     175.822ms     879.110us           200
             aten::matmul         0.83%       1.492ms        99.87%     178.952ms     894.760us           200
            aten::reshape         0.44%     787.000us         0.70%       1.263ms       6.315us           200
     aten::_reshape_alias         0.30%     529.000us         0.30%     529.000us       2.645us           200
       aten::_unsafe_view         0.18%     322.000us         0.18%     322.000us       1.610us           200
         aten::lift_fresh         0.13%     229.000us         0.13%     229.000us       1.145us           200
    cudaDeviceSynchronize         0.01%      11.000us         0.01%      11.000us      11.000us             1
       aten::resolve_conj         0.00%       1.000us         0.00%       1.000us       0.002us           600
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 179.192ms

use_ort=True, use_pt=True, disable_ort_spinning=True ort_threads=7, torch_threads=7, PT operation took 180.4 ms, ORT operation took 2496.5 ms over 200 runs
------------
```

TL;DR, for this specific model and this specific PyTorch operation:

fxmarty commented 1 year ago

@tianleiwu On which hardware/CPU did you reproduce?

Note: the issue can be reproduced on Intel(R) Xeon(R) Platinum 8255C CPU @ 2.50GHz