Update: the result is available in an internal spreadsheet. However, the result is different from what we get from `torchbench.py`. It could be because we compile contexts every time using `with torchdynamo.optimize(<backend>)`, while `torchbench.py` uses `torchdynamo.run()`, which caches compiled code and avoids re-compilation overhead. We are still looking into the possible root causes.
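For reference, the two calling conventions being compared look roughly like this (a minimal sketch, not the actual harness; the tiny model and inputs are placeholders, not from torchbench):

```python
import torch
import torchdynamo

model = torch.nn.Linear(8, 8).cuda().eval()
example_inputs = (torch.randn(4, 8, device="cuda"),)

# torchbench.py style: compile frames once under optimize(), then measure
# under run(), which only reuses already-compiled frames.
with torchdynamo.optimize("eager"):
    model(*example_inputs)  # frames are captured/compiled here
with torchdynamo.run():
    model(*example_inputs)  # cached code is reused, no compile path

# Our harness style: every measured call goes through optimize(<backend>).
with torchdynamo.optimize("eager"):
    model(*example_inputs)
```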
@xuzhao9 are you sure `.run()` is the difference? I would expect `optimize()` and `run()` to do the exact same thing unless there are recompiles (which there shouldn't be in most benchmarks).
@jansel You're right. Looks like the difference is not caused by `optimize()` vs. `run()`. Please ignore my previous update.
I managed to create a script to reproduce the problem on A100, please help verify:
import torch
import time
import gc
import numpy as np
from torchbenchmark import load_model_by_name
import argparse

def synchronize():
    torch.cuda.synchronize()

def timed(model, example_inputs, times=1, dynamo=False):
    synchronize()
    gc.collect()
    torch.manual_seed(1337)
    t0 = time.time_ns()
    # Dont collect outputs to correctly measure timing
    if dynamo:
        with torchdynamo.run():
            result = model(*example_inputs)
    else:
        result = model(*example_inputs)
    synchronize()
    t1 = time.time_ns()
    return (t1 - t0) / 1_000_000

def speedup_experiment(model, example_inputs, dynamo=False):
    repeat = 100
    timings = np.zeros((repeat, 2), np.float64)
    for rep in range(repeat):
        # interleave the runs to handle frequency scaling and load changes
        timings[rep, 0] = timed(model, example_inputs)
        if dynamo:
            timings[rep, 1] = timed(model, example_inputs, dynamo=True)
    median = np.median(timings, axis=0)
    print(f"Eager Latency: {median[0]} ms")
    if dynamo:
        print(f"TorchDynamo Eager latency: {median[1]} ms")
        print(f"speedup: {median[0]/median[1]} ")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--torchdynamo", action="store_true", help="load torchdynamo library")
    args = parser.parse_args()
    if args.torchdynamo:
        import torchdynamo
        optimize_ctx = torchdynamo.optimize("eager")
        with optimize_ctx:
            pass
    Model = load_model_by_name("alexnet")
    m = Model(device="cuda", test="eval", jit=False)
    model, example_inputs = m.get_module()
    speedup_experiment(model, example_inputs, dynamo=args.torchdynamo)
Usage: place the script in the `torchbench` directory as `runx.py`, and run two commands.

`python runx.py`, my output:

Eager Latency: 0.934742 ms

`python runx.py --torchdynamo`, my output:

Eager Latency: 1.1478039999999998 ms
TorchDynamo Eager latency: 1.1662355 ms
speedup: 0.9841957306221598

This shows that when we run the script with torchdynamo, the eager-mode latency also increases (from 0.9347 ms to 1.1478 ms).
Is the GPU frequency scaling due to load? Perhaps it is more lightly loaded when you only run 1 experiment compared to 2. That is why I interleave the runs in my test, so if the GPU frequency scales it will affect both sides of the experiment the same.
Running your script on an RTX 3090 I get (with the fan set to max):
(torchdynamo) empire:~/torchbenchmark$ for X in `seq 10`; do python runx.py; done
Eager Latency: 1.228479 ms
Eager Latency: 1.2281875 ms
Eager Latency: 1.2263385 ms
Eager Latency: 1.2311455 ms
Eager Latency: 1.228942 ms
Eager Latency: 1.2223555 ms
Eager Latency: 1.22241 ms
Eager Latency: 1.226303 ms
Eager Latency: 1.2278235 ms
Eager Latency: 1.2260115 ms
(torchdynamo) empire:~/torchbenchmark$ for X in `seq 10`; do python runx.py --torchdynamo; done
Eager Latency: 1.2290904999999999 ms
TorchDynamo Eager latency: 1.2378339999999999 ms
speedup: 0.992936451899043
Eager Latency: 1.2249225 ms
TorchDynamo Eager latency: 1.240448 ms
speedup: 0.987483957408936
Eager Latency: 1.2259004999999998 ms
TorchDynamo Eager latency: 1.2362005 ms
speedup: 0.9916680182543203
Eager Latency: 1.2307625 ms
TorchDynamo Eager latency: 1.23834 ms
speedup: 0.9938809212332639
Eager Latency: 1.2269335 ms
TorchDynamo Eager latency: 1.240305 ms
speedup: 0.9892191839910344
Eager Latency: 1.2323985 ms
TorchDynamo Eager latency: 1.244964 ms
speedup: 0.9899069370680598
Eager Latency: 1.2397464999999999 ms
TorchDynamo Eager latency: 1.256022 ms
speedup: 0.9870420263339336
Eager Latency: 1.230794 ms
TorchDynamo Eager latency: 1.2430765 ms
speedup: 0.9901192726272278
Eager Latency: 1.2288995 ms
TorchDynamo Eager latency: 1.2410505 ms
speedup: 0.990209101080093
Eager Latency: 1.231789 ms
TorchDynamo Eager latency: 1.2418645 ms
speedup: 0.9918867960232377
The current results look reasonable to me. As Xu mentioned, it is not caused by the use of `torchdynamo.run()` vs. `torchdynamo.optimize()`, since we do not recompile the model. @xuzhao9 just wondering why the initial result has such a large variance? What is the reason there?
@frank-wei Both @anijain2305 and I can reproduce this result on A100 (showing that eager under `python runx.py` is faster than eager under `python runx.py --torchdynamo`). One reason for this discrepancy could be that, since we don't have sudo access to the A100 instance, we can't pin the GPU frequency etc. to stabilize the latency numbers.
Interesting that I can't reproduce it locally. A few ideas of things to try:
1) Can you reproduce your results on a CPU backend? Or is this GPU-only?
2) What if you replace dynamo with something else that uses the GPU? For example do `model2 = copy.deepcopy(model)`, then interleave runs of `model` and `model2`. You could also try interleaving it with something that does CPU work (for example `model3 = model2.to("cpu")`).
Here are my test results on A100. Observations: 1) in interleaving mode, the difference is less than about 2%; 2) without interleaving, the difference is large. It feels like the GPU frequency needs to be locked; I will test again after that.
(mypy38) [wwei6@devgpu005.ftw6 /data/users/wwei6/Work/torchbench] for X in `seq 10`; do python runx.py; done
Eager Latency: 1.4733584999999998 ms
Eager Latency: 1.6835135 ms
Eager Latency: 1.4823665 ms
Eager Latency: 1.5355865 ms
Eager Latency: 1.510335 ms
Eager Latency: 1.49497 ms
Eager Latency: 1.4443485 ms
Eager Latency: 1.5397224999999999 ms
Eager Latency: 1.4612414999999999 ms
Eager Latency: 1.7127085 ms
for X in `seq 10`; do python runx.py --torchdynamo; done
Eager Latency: 1.7416095 ms
TorchDynamo Eager latency: 1.760823 ms
speedup: 0.9890883410768715
Eager Latency: 1.7456610000000001 ms
TorchDynamo Eager latency: 1.7690625 ms
speedup: 0.9867718071012189
Eager Latency: 1.7898844999999999 ms
TorchDynamo Eager latency: 1.804828 ms
speedup: 0.9917202636483918
Eager Latency: 1.7428985 ms
TorchDynamo Eager latency: 1.763342 ms
speedup: 0.9884063896850412
Eager Latency: 1.7013310000000001 ms
TorchDynamo Eager latency: 1.7321525000000002 ms
speedup: 0.9822062433879234
Eager Latency: 1.7888145 ms
TorchDynamo Eager latency: 1.816144 ms
speedup: 0.984951909099719
Eager Latency: 2.013072 ms
TorchDynamo Eager latency: 2.0064089999999997 ms
speedup: 1.0033208583095474
Eager Latency: 1.9477324999999999 ms
TorchDynamo Eager latency: 1.9671699999999999 ms
speedup: 0.9901190542759396
Eager Latency: 1.68583 ms
TorchDynamo Eager latency: 1.718121 ms
speedup: 0.9812056310352996
Eager Latency: 1.9998985 ms
TorchDynamo Eager latency: 2.036448 ms
speedup: 0.9820523283678247
Locked the GPU frequency to the highest with `sudo nvidia-smi -ac 1215,1410` on A100. The results are similar to the above; the big variance in test 1 still exists.
> Interesting that I can't reproduce it locally. A few ideas of things to try:
>
> - Can you reproduce your results on a CPU backend? Or is this GPU-only?
> - What if you replace dynamo with something else that uses the GPU? For example do `model2 = copy.deepcopy(model)`, then interleave runs of `model` and `model2`. You could also try interleaving it with something that does CPU work (for example `model3 = model2.to("cpu")`).
I did two follow-up experiments.
Test on CPU:
$ python runx.py --device cpu
Eager Latency: 20.2229595 ms
Eager latency r2: 20.275337 ms
speedup: 0.9974166890542929
$ python runx.py --device cpu --torchdynamo
Eager Latency: 19.4026635 ms
TorchDynamo Eager latency: 19.417619000000002 ms
speedup: 0.9992297974329395
So the problem doesn't exist on CPU.
Use `model2 = copy.deepcopy(model)` for GPU test round 2 (interleave two identical GPU eager tests):
$ python runx.py
Eager Latency: 1.307467 ms
Eager latency r2: 1.312327 ms
speedup: 0.9962966547209651
$ python runx.py --torchdynamo
Eager Latency: 1.6921075 ms
TorchDynamo Eager latency: 1.714546 ms
speedup: 0.9869128620637767
When running with dynamo, the Eager Latency is longer than running without dynamo (1.69 ms vs. 1.30 ms). So the problem is reproducible.
The script I am using is:
import torch
import time
import gc
import copy
import numpy as np
from torchbenchmark import load_model_by_name
import argparse

DEVICE = "cuda"

def synchronize():
    if DEVICE == "cuda":
        torch.cuda.synchronize()

def timed(model, example_inputs, times=1, dynamo=False):
    synchronize()
    gc.collect()
    torch.manual_seed(1337)
    t0 = time.time_ns()
    # Dont collect outputs to correctly measure timing
    if dynamo:
        with torchdynamo.run():
            result = model(*example_inputs)
    else:
        result = model(*example_inputs)
    synchronize()
    t1 = time.time_ns()
    return (t1 - t0) / 1_000_000

def speedup_experiment(model, example_inputs, dynamo=False):
    repeat = 100
    timings = np.zeros((repeat, 2), np.float64)
    for rep in range(repeat):
        # interleave the runs to handle frequency scaling and load changes
        timings[rep, 0] = timed(model, example_inputs)
        if dynamo:
            timings[rep, 1] = timed(model, example_inputs, dynamo=True)
        else:
            model2 = copy.deepcopy(model)
            timings[rep, 1] = timed(model2, example_inputs, dynamo=False)
    median = np.median(timings, axis=0)
    print(f"Eager Latency: {median[0]} ms")
    if dynamo:
        print(f"TorchDynamo Eager latency: {median[1]} ms")
        print(f"speedup: {median[0]/median[1]} ")
    else:
        print(f"Eager latency r2: {median[1]} ms")
        print(f"speedup: {median[0]/median[1]} ")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--torchdynamo", action="store_true", help="load torchdynamo library")
    parser.add_argument("--device", choices=["cpu", "cuda"], default="cuda", help="specify device")
    args = parser.parse_args()
    DEVICE = args.device
    if args.torchdynamo:
        import torchdynamo
        optimize_ctx = torchdynamo.optimize("eager")
        with optimize_ctx:
            pass
    Model = load_model_by_name("alexnet")
    m = Model(device=DEVICE, test="eval", jit=False)
    model, example_inputs = m.get_module()
    speedup_experiment(model, example_inputs, dynamo=args.torchdynamo)
For the second experiment I was suggesting not using TorchDynamo at all (just interleaving unrelated things both in eager mode). My hypothesis is this is a function of the measurement harness. Performance within the measured region is correlated to what happens outside the measured region.
> For the second experiment I was suggesting not using TorchDynamo at all (just interleaving unrelated things both in eager mode). My hypothesis is this is a function of the measurement harness. Performance within the measured region is correlated to what happens outside the measured region.
Yes that was what I did in the second experiment (see the updated code for details).
Ah perfect, thanks! Just wanted to confirm.
I'm drawing a blank on what TorchDynamo could be doing to affect GPU only (and only A100, not RTX 3090). TorchDynamo doesn't do any GPU work; it just runs the captured graph eagerly, which should be literally identical kernels.
Perhaps there is some way to clear GPU caches...
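In case it helps with bisecting, a minimal sketch of the only cache clearing I can think of from the Python side, assuming the CUDA caching allocator is what's meant here:

```python
import torch

# Release cached blocks held by PyTorch's CUDA caching allocator back to the
# driver, so the next run starts from a cold allocator state.
torch.cuda.synchronize()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()  # optional: reset bookkeeping as well
```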
@jansel I updated the benchmarking script and am able to reproduce the slowdown within the same process. The change is to import `torchdynamo` on demand (i.e., in the PyTorch eager run, we do not `import torchdynamo`). Result on V100 (pinned to the highest GPU and GPU memory frequency); I ran it a few times and it looks pretty stable:
$ python runx.py --torchdynamo
Eager Latency: 1.2967605 ms
TorchDynamo Eager latency: 1.598749 ms
speedup: 0.8111094987393268
So the TorchDynamo eager backend is only 81% as fast as PyTorch eager mode. It looks like the single line `import torchdynamo` slows down both PyTorch eager and TorchDynamo eager. Did we do anything special in `import torchdynamo` such that it slows down the entire GPU?
Updated script:
import torch
import time
import gc
import copy
import sys
import numpy as np
from torchbenchmark import load_model_by_name
import argparse

DEVICE = "cuda"

def synchronize():
    if DEVICE == "cuda":
        torch.cuda.synchronize()

def timed(model, example_inputs, times=1, dynamo=False):
    torch.manual_seed(1337)
    if dynamo:
        import torchdynamo
    gc.collect()
    synchronize()
    t0 = time.time_ns()
    # Dont collect outputs to correctly measure timing
    if dynamo:
        # with torchdynamo.run():
        #     result = model(*example_inputs)
        result = model(*example_inputs)
    else:
        result = model(*example_inputs)
    synchronize()
    t1 = time.time_ns()
    return (t1 - t0) / 1_000_000

def speedup_experiment(model, example_inputs, model2, example_inputs2, dynamo=False):
    repeat = 100
    timings = np.zeros((repeat, 2), np.float64)
    for rep in range(repeat):
        # interleave the runs to handle frequency scaling and load changes
        timings[rep, 0] = timed(model, example_inputs)
        if dynamo:
            timings[rep, 1] = timed(model2, example_inputs2, dynamo=True)
        else:
            timings[rep, 1] = timed(model2, example_inputs2, dynamo=False)
    median = np.median(timings, axis=0)
    print(f"Eager Latency: {median[0]} ms")
    if dynamo:
        print(f"TorchDynamo Eager latency: {median[1]} ms")
        print(f"speedup: {median[0]/median[1]} ")
    else:
        print(f"Eager latency r2: {median[1]} ms")
        print(f"speedup: {median[0]/median[1]} ")

def speedup_experiment2(model, example_inputs, model2, example_inputs2, dynamo=False):
    repeat = 100
    timings = np.zeros((repeat, 2), np.float64)
    for rep in range(repeat):
        timings[rep, 0] = timed(model, example_inputs)
    for rep in range(repeat):
        if dynamo:
            timings[rep, 1] = timed(model2, example_inputs2, dynamo=True)
        else:
            timings[rep, 1] = timed(model2, example_inputs2, dynamo=False)
    median = np.median(timings, axis=0)
    print(f"Eager Latency: {median[0]} ms")
    if dynamo:
        print(f"TorchDynamo Eager latency: {median[1]} ms")
        print(f"speedup: {median[0]/median[1]} ")
    else:
        print(f"Eager latency r2: {median[1]} ms")
        print(f"speedup: {median[0]/median[1]} ")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--torchdynamo", action="store_true", help="load torchdynamo library")
    parser.add_argument("--device", choices=["cpu", "cuda"], default="cuda", help="specify device")
    args = parser.parse_args()
    DEVICE = args.device
    Model = load_model_by_name("alexnet")
    m = Model(device=DEVICE, test="eval", jit=False)
    m2 = Model(device=DEVICE, test="eval", jit=False)
    model, example_inputs = m.get_module()
    model2, example_inputs2 = m2.get_module()
    # speedup_experiment(model, example_inputs, model2, example_inputs2, dynamo=args.torchdynamo)
    speedup_experiment2(model, example_inputs, model2, example_inputs2, dynamo=args.torchdynamo)
@xuzhao9 I ran your latest script on an RTX 3090, and I can't reproduce your results:
$ python test2.py
Eager Latency: 1.2188915 ms
Eager latency r2: 1.2218585 ms
speedup: 0.9975717319149476
$ python test2.py --torchdynamo
Eager Latency: 1.2186050000000002 ms
TorchDynamo Eager latency: 1.246474 ms
speedup: 0.9776417317970532
So somehow it is specific to your machine. Perhaps we should try on a V100, with the latest version of CUDA, or with newer drivers.
Are you sure it is just the `import torchdynamo` line? Does it still slow things down if you remove the `torchdynamo.optimize("eager")` and `torchdynamo.run()`?
> @xuzhao9 I ran your latest script on an RTX 3090, and I can't reproduce your results:
>
> $ python test2.py
> Eager Latency: 1.2188915 ms
> Eager latency r2: 1.2218585 ms
> speedup: 0.9975717319149476
> $ python test2.py --torchdynamo
> Eager Latency: 1.2186050000000002 ms
> TorchDynamo Eager latency: 1.246474 ms
> speedup: 0.9776417317970532
>
> So somehow it is specific to your machine. Perhaps we should try on a V100, with the latest version of CUDA, or with newer drivers.
Yes, it looks like there is a problem within my dev environment. For A100, I am using NVIDIA driver 450.119.03, CUDA 11.1, PyTorch 1.12.0.dev20220408+cu113 (I can't update the NVIDIA/CUDA version or pin the frequency due to lack of sudo permission). For V100, I am using NVIDIA driver 470.82.01, CUDA 11.3, PyTorch 1.12.0.dev20220331+cu113.
> Are you sure it is just the `import torchdynamo` line? Does it still slow things down if you remove the `torchdynamo.optimize("eager")` and `torchdynamo.run()`?
Yes, I confirm that if I remove `torchdynamo.optimize("eager")` and `torchdynamo.run()`, the slowdown still exists.
I am using:
CUDA Version: 11.6
Driver Version: 510.54
PyTorch bc512253d5fe718a029324c45f175a81b088facd
Just `import torchdynamo` shouldn't have side effects. It doesn't even install the eval_frame handler. The walk of the `torch.*` module hierarchy in allowed_functions happens lazily on first use, not at import time.
Could you try commenting out import lines in `__init__.py` to narrow down which import exactly is causing the issue?
That is super odd. Perhaps something to do with memory allocation alignment... dunno, drawing a blank.
Tried on A100. NVIDIA-SMI 470.57.02 Driver Version: 470.57.02 CUDA Version: 11.4
for X in `seq 10`; do python runx3.py --torchdynamo; done
Eager Latency: 1.701516 ms
TorchDynamo Eager latency: 1.837139 ms
speedup: 0.9261770611804551
Eager Latency: 1.5440925 ms
TorchDynamo Eager latency: 1.8713935 ms
speedup: 0.8251030582290684
Eager Latency: 1.7131965 ms
TorchDynamo Eager latency: 1.9872285 ms
speedup: 0.8621034269587015
Eager Latency: 1.542824 ms
TorchDynamo Eager latency: 1.9405899999999998 ms
speedup: 0.7950283161306614
Eager Latency: 1.5197295 ms
TorchDynamo Eager latency: 1.858571 ms
speedup: 0.8176870832483666
Eager Latency: 1.5287305 ms
TorchDynamo Eager latency: 1.8407814999999998 ms
speedup: 0.8304790655490617
Eager Latency: 1.800862 ms
TorchDynamo Eager latency: 1.9163865 ms
speedup: 0.9397175361024511
Eager Latency: 1.491285 ms
TorchDynamo Eager latency: 1.8125605 ms
speedup: 0.8227504681912686
Eager Latency: 1.534363 ms
TorchDynamo Eager latency: 1.8803305 ms
speedup: 0.8160070796064841
Eager Latency: 1.8389465 ms
TorchDynamo Eager latency: 1.9739385 ms
speedup: 0.9316128643318928
After commenting out some imports in `__init__.py` as follows:

from . import convert_frame
from . import resume_execution
# from .eval_frame import disable
from .eval_frame import optimize
# from .eval_frame import optimize_assert
# from .eval_frame import reset_code
# from .eval_frame import run

and all the imports in `fx2trt`:

def fx2trt(subgraph, **kwargs):
    if subgraph.will_tensorrt_barf():
        # TensorRT fails violently with an abort() on this
        return None
    # import fx2trt_oss.tracer.acc_tracer.acc_tracer as acc_tracer
    # from fx2trt_oss.fx.fx2trt import InputTensorSpec
    # from fx2trt_oss.fx.fx2trt import TRTInterpreter
    # from fx2trt_oss.fx.tools.trt_splitter import TRTSplitter
    # from fx2trt_oss.fx.tools.trt_splitter import TRTSplitterSetting
    # from fx2trt_oss.fx.trt_module import TRTModule
    # from fx2trt_oss.fx.utils import LowerPrecision

the slowdown still exists:
for X in `seq 3`; do python runx3.py --torchdynamo; done
Eager Latency: 1.5356960000000002 ms
TorchDynamo Eager latency: 2.0057774999999998 ms
speedup: 0.7656362682301503
Eager Latency: 1.738222 ms
TorchDynamo Eager latency: 1.8921785 ms
speedup: 0.9186353190251342
Eager Latency: 1.5379595 ms
TorchDynamo Eager latency: 1.8584245 ms
speedup: 0.8275609259348443
So what is the import line causing the issue?
> So what is the import line causing the issue?

No conclusion yet.
Thanks to @jansel, I can confirm this line is the root cause: https://github.com/jansel/torchdynamo/blob/bf90b8cdbacf35944fa8c12185b1823dc5cb90bb/torchdynamo/skipfiles.py#L123
Specifically, adding five packages slows down the test: "networkx", "omegaconf", "onnx", "pandas", and "sklearn". Of these, `sklearn` alone accounts for a ~10% slowdown, and the other four for another 8-9% combined, which explains the total ~20% slowdown.
Interesting, that could explain why I can't reproduce the issue. I might have different versions of those 6 packages.
It is a bit alarming that importing random dependencies slows down PyTorch. Many of those packages don't even use the GPU.
A few theories from an internal chat thread: the import may be changing `os.environ` (or other global state).
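One quick way to test the `os.environ` theory (a sketch, not from the chat thread; `sklearn` is just the suspect package being discussed):

```python
import os

# Snapshot the environment, import the suspect package, then diff to see
# whether the import mutates os.environ (e.g. OpenMP/MKL-related variables).
before = dict(os.environ)
import sklearn  # noqa: F401
after = dict(os.environ)

added = {k: after[k] for k in after.keys() - before.keys()}
changed = {k: (before[k], after[k]) for k in before.keys() & after.keys() if before[k] != after[k]}
print("added by import:", added)
print("changed by import:", changed)
```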
(or other global state)Another finding: the slowdown of import sklearn
is related to gc.collect()
at https://github.com/facebookresearch/torchdynamo/blob/main/torchbench.py#L165. After removing this line, the regression from import sklearn
is gone.
Can you check if you'd still see the slowdown if you leave the `gc.collect()` line in place, but add a couple of untimed dry runs after it, before starting to collect actual timings?
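Something like this (a hypothetical helper built on the `timed()` function from the scripts above, not code from the repo):

```python
def timed_with_warmup(model, example_inputs, warmup=3, dynamo=False):
    # A few untimed dry runs so allocator caches and any lazy initialization
    # happen outside the measured region; only the final call is reported.
    for _ in range(warmup):
        timed(model, example_inputs, dynamo=dynamo)
    return timed(model, example_inputs, dynamo=dynamo)
```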
At least part of the regression is due to the fact that PyTorch is built using Intel's OpenMP, while sklearn relies on the GNU OpenMP runtime and passes some environment variables to tell MKL to interoperate with it. Replacing `import sklearn` with `import _openmp_helpers` (see https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/utils/_openmp_helpers.pyx) still results in a 9% slowdown, but that module does nothing but import `libgomp.so`.
> At least part of the regression is due to the fact that PyTorch is built using Intel's OpenMP, while sklearn relies on the GNU OpenMP runtime and passes some environment variables to tell MKL to interoperate with it. Replacing `import sklearn` with `import _openmp_helpers` (see https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/utils/_openmp_helpers.pyx) still results in a 9% slowdown, but that module does nothing but import `libgomp.so`.
This is a bit strange because I can only reproduce the issue on GPU, not CPU. How would importing `libgomp.so` impact CUDA performance?
Hi @malfet & @xuzhao9,
The public PyTorch release for linux is not built with Intel OpenMP.
> The public PyTorch release for linux is not built with Intel OpenMP.
Hmm, that's a good point, it's only linked with iomp5 on MacOS, but not on Linux:
$ python -c "import torch;print(open('/proc/self/maps').read())"|grep omp
7f31e851d000-7f31e8527000 r--p 00000000 917:b543c 144115217301836474 /fsx/users/nshulga/conda/envs/py_3.9-torch-1.11-cuda-11.3/lib/libgomp.so.1.0.0
If one increases the batch size, the discrepancy between eager mode with and without the sklearn import becomes negligible, using the following script derived from the one @xuzhao9 posted:
import torch
import time
import gc
import sys
import os
import numpy as np
import torchvision.models as models
import argparse

DEVICE = "cuda"

def synchronize():
    if DEVICE == "cuda":
        torch.cuda.synchronize()

def timed(model, example_inputs, times=1, dynamo=False):
    torch.manual_seed(1337)
    if dynamo:
        import sklearn
    gc.collect()
    synchronize()
    t0 = time.time_ns()
    # Dont collect outputs to correctly measure timing
    result = model(*example_inputs)
    synchronize()
    t1 = time.time_ns()
    return (t1 - t0) / 1_000_000

def speedup_experiment(model, example_inputs, model2, example_inputs2):
    repeat = 100
    timings = np.zeros((repeat, 2), np.float64)
    for rep in range(repeat):
        timings[rep, 0] = timed(model, example_inputs)
    for rep in range(repeat):
        timings[rep, 1] = timed(model2, example_inputs2, dynamo=True)
    median = np.median(timings, axis=0)
    print(f"Eager Latency: {median[0]} ms")
    print(f"sklearn Eager latency: {median[1]} ms")
    print(f"speedup: {median[0]/median[1]} ")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", choices=["cpu", "cuda"], default="cuda", help="specify device")
    parser.add_argument("--dtype", choices=["float16", "float32"], default="float16", help="specify dtype")
    parser.add_argument("--sklearn", action="store_true", help="import sklearn before timing")
    parser.add_argument("--batch-size", type=int, default=8, help="specify batch size")
    args = parser.parse_args()
    DEVICE = args.device
    DTYPE = {"float16": torch.float16, "float32": torch.float32}[args.dtype]
    batch_size = args.batch_size
    model = models.alexnet(pretrained=True).to(device=DEVICE, dtype=DTYPE)
    example_inputs = (torch.randn(batch_size, 3, 224, 224).to(device=DEVICE, dtype=DTYPE), )
    model2 = models.alexnet(pretrained=True).to(device=DEVICE, dtype=DTYPE)
    example_inputs2 = (torch.randn(batch_size, 3, 224, 224).to(device=DEVICE, dtype=DTYPE), )
    model.eval()
    model2.eval()
    if args.sklearn:
        import sklearn
    speedup_experiment(model, example_inputs, model2, example_inputs2)
(py_3.9-torch-1.11-cuda-11.3) nshulga@dev-st-p3dn24xlarge-1:~$ python runx.py --batch-size=8
Eager Latency: 1.3937244999999998 ms
sklearn Eager latency: 1.569018 ms
speedup: 0.8882782096827441
(py_3.9-torch-1.11-cuda-11.3) nshulga@dev-st-p3dn24xlarge-1:~$ python runx.py --batch-size=16
Eager Latency: 1.3827145 ms
sklearn Eager latency: 1.571698 ms
speedup: 0.8797583886980832
(py_3.9-torch-1.11-cuda-11.3) nshulga@dev-st-p3dn24xlarge-1:~$ python runx.py --batch-size=32
Eager Latency: 2.0108475 ms
sklearn Eager latency: 2.0631779999999997 ms
speedup: 0.9746359742106596
(py_3.9-torch-1.11-cuda-11.3) nshulga@dev-st-p3dn24xlarge-1:~$ python runx.py --batch-size=64
Eager Latency: 3.54023 ms
sklearn Eager latency: 3.5917225000000004 ms
speedup: 0.9856635639306767
Hmm, another strange fact: why does `cudaStreamIsCapturing` dominate the CUDA API time (even though we don't do any graph capture):
$ nvprof python runx.py
==3411== NVPROF is profiling process 3411, command: python runx.py
Eager Latency: 1.8098640000000001 ms
sklearn Eager latency: 2.1530690000000003 ms
speedup: 0.8405973055206312
==3411== Profiling application: python runx.py
==3411== Profiling result:
Type Time(%) Time Calls Avg Min Max Name
GPU activities: 24.31% 51.150ms 34 1.5044ms 1.5990us 16.041ms [CUDA memcpy HtoD]
21.30% 44.810ms 400 112.03us 56.480us 175.17us void cutlass::Kernel<cutlass_70_wmma_tensorop_f16_s161616gemm_f16_16x16_64x2_tn_align8>(cutlass_70_wmma_tensorop_f16_s161616gemm_f16_16x16_64x2_tn_align8Params)
12.56% 26.437ms 200 132.19us 131.01us 133.38us volta_fp16_scudnn_fp16_128x64_relu_xregs_large_nn_v1
10.39% 21.866ms 400 54.664us 49.727us 59.999us void xmma_cudnn::gemm::kernel<xmma_cudnn::implicit_gemm::fprop_indexed::Kernel_traits<xmma_cudnn::Volta_hmma_fp32_traits, xmma_cudnn::Cta_tile<xmma_cudnn::Volta<int=0>, int=64, int=128, int=32, int=2, int=2, int=1, int=1>, xmma_cudnn::implicit_gemm::fprop_indexed::Gmem_tile_a_t<xmma_cudnn::Volta_hmma_fp32_traits, xmma_cudnn::Cta_tile<xmma_cudnn::Volta<int=0>, int=64, int=128, int=32, int=2, int=2, int=1, int=1>, xmma_cudnn::implicit_gemm::Input_related<int=0, int=0, int=0, bool=0>, int=16, bool=0, xmma_cudnn::implicit_gemm::fprop_indexed::Gmem_tile_base_a<xmma_cudnn::Volta_hmma_fp32_traits, xmma_cudnn::Cta_tile<xmma_cudnn::Volta<int=0>, int=64, int=128, int=32, int=2, int=2, int=1, int=1>, xmma_cudnn::implicit_gemm::Input_related<int=0, int=0, int=0, bool=0>, int=16, xmma_cudnn::Row, int=32, int=64>>, xmma_cudnn::implicit_gemm::fprop_indexed::Gmem_tile_c_t<xmma_cudnn::Volta_hmma_fp32_traits, xmma_cudnn::Cta_tile<xmma_cudnn::Volta<int=0>, int=64, int=128, int=32, int=2, int=2, int=1, int=1>, int=16, xmma_cudnn::Fragment_c<xmma_cudnn::Volta_hmma_fp32_traits, xmma_cudnn::Cta_tile<xmma_cudnn::Volta<int=0>, int=64, int=128, int=32, int=2, int=2, int=1, int=1>, bool=0>>, xmma_cudnn::implicit_gemm::Input_related<int=0, int=0, int=0, bool=0>, int=1>>(xmma_cudnn::Volta_hmma_fp32_traitsParams)
6.50% 13.684ms 200 68.418us 67.647us 69.376us volta_fp16_s884cudnn_fp16_256x64_ldg8_relu_f2f_exp_small_nhwc2nchw_tn_v1
5.23% 11.009ms 200 55.042us 49.920us 57.279us sm70_xmma_fprop_implicit_gemm_f16f16_f16f32_f32_nhwckrsc_nhwc_tilesize64x64x64_stage1_warpsize2x2x1_g1_tensor8x8x4_kernel
4.78% 10.057ms 1600 6.2850us 2.5280us 13.376us void cudnn::ops::nchwToNhwcKernel<__half, __half, float, bool=0, bool=1, cudnnKernelDataType_t=0>(cudnn::ops::nchw2nhwc_params_t<float>, __half const *, __half*)
3.20% 6.7238ms 1000 6.7230us 4.2880us 12.064us _ZN2at6native18elementwise_kernelILi128ELi4EZNS0_15gpu_kernel_implINS0_13BinaryFunctorIN3c104HalfES5_S5_NS0_10AddFunctorIfEEEEEEvRNS_18TensorIteratorBaseERKT_EUliE_EEviT1_
2.52% 5.3069ms 600 8.8440us 4.1280us 13.760us void at::native::_GLOBAL__N__63_tmpxft_0000631e_00000000_21_DilatedMaxPool2d_compute_86_cpp1_ii_6258b574::max_pool_forward_nchw<c10::Half, c10::Half>(int, c10::Half const *, int, int, int, int, int, int, int, int, int, int, int, int, int, at::native::_GLOBAL__N__63_tmpxft_0000631e_00000000_21_DilatedMaxPool2d_compute_86_cpp1_ii_6258b574::max_pool_forward_nchw<c10::Half, c10::Half>*, long*)
2.31% 4.8502ms 1400 3.4640us 1.6640us 8.8000us _ZN2at6native29vectorized_elementwise_kernelILi4EZZZNS0_84_GLOBAL__N__60_tmpxft_00007377_00000000_21_TensorCompare_compute_86_cpp1_ii_d0af11f719launch_clamp_scalarERNS_18TensorIteratorBaseEN3c106ScalarES6_NS0_6detail11ClampLimitsEENKUlvE_clEvENKUlvE14_clEvEUlNS5_4HalfEE_NS_6detail5ArrayIPcLi2EEEEEviT0_T1_
2.03% 4.2620ms 200 21.309us 20.544us 22.496us void cutlass::Kernel<cutlass_70_wmma_tensorop_s161616gemm_f16_32x32_64x2_tn_align8>(cutlass_70_wmma_tensorop_s161616gemm_f16_32x32_64x2_tn_align8Params)
1.54% 3.2403ms 200 16.201us 15.904us 16.576us void at::native::_GLOBAL__N__69_tmpxft_00005dd7_00000000_21_AdaptiveAveragePooling_compute_86_cpp1_ii_ef17039c::adaptive_average_pool<c10::Half>(c10::Half*, c10::Half, int, int, int, int, long, long, long)
1.11% 2.3282ms 600 3.8800us 3.0710us 5.5360us _ZN2at6native18elementwise_kernelILi128ELi4EZNS0_15gpu_kernel_implIZZZNS0_75_GLOBAL__N__51_tmpxft_000061c0_00000000_21_Copy_compute_86_cpp1_ii_ddc55cb923direct_copy_kernel_cudaERNS_18TensorIteratorBaseEENKUlvE0_clEvENKUlvE18_clEvEUlN3c104HalfEE_EEvS5_RKT_EUliE_EEviT1_
1.06% 2.2252ms 600 3.7080us 3.0400us 5.1840us void cudnn::ops::nhwcToNchwKernel<__half, __half, float, bool=1, bool=0, cudnnKernelDataType_t=0>(cudnn::ops::nhwc2nchw_params_t<float>, __half const *, __half*)
0.47% 990.46us 604 1.6390us 1.5680us 2.3360us [CUDA memset]
0.36% 753.02us 400 1.8820us 1.6950us 2.5920us void cask_cudnn::computeOffsetsKernel<bool=0, bool=0>(cask_cudnn::ComputeOffsetsParams)
0.35% 726.78us 200 3.6330us 3.1680us 4.4800us void splitKreduce_kernel<float, __half, float, __half, bool=1, bool=0>(cublasSplitKParams<float>, float const *, __half const *, __half*, float const *, float const *, __half const *, void*, long, float*, int*)
API calls: 80.27% 5.01770s 210 23.894ms 1.0040us 5.01737s cudaStreamIsCapturing
16.75% 1.04700s 4 261.75ms 1.6460us 653.52ms cudaFree
1.14% 71.470ms 8200 8.7150us 5.0050us 493.74us cudaLaunchKernel
0.86% 53.921ms 34 1.5859ms 7.1320us 16.178ms cudaMemcpyAsync
0.26% 16.008ms 39525 405ns 268ns 499.95us cudaGetDevice
0.21% 13.376ms 400 33.440us 5.1250us 92.660us cudaDeviceSynchronize
0.10% 6.2973ms 20 314.87us 3.9500us 762.77us cudaMalloc
0.09% 5.5990ms 604 9.2690us 6.0650us 87.593us cudaMemsetAsync
0.05% 3.1450ms 1200 2.6200us 1.2400us 239.89us cudaEventRecord
0.04% 2.6237ms 297 8.8340us 137ns 421.69us cuDeviceGetAttribute
0.04% 2.4903ms 3 830.10us 798.66us 877.06us cudaGetDeviceProperties
0.03% 2.0540ms 3 684.65us 664.22us 698.85us cuDeviceTotalMem
0.03% 2.0162ms 1200 1.6800us 432ns 22.153us cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags
0.03% 1.7191ms 8010 214ns 120ns 16.468us cudaGetLastError
0.02% 1.4726ms 34 43.310us 3.5840us 100.27us cudaStreamSynchronize
0.02% 1.1446ms 1 1.1446ms 1.1446ms 1.1446ms cudaHostAlloc
0.02% 1.0711ms 200 5.3550us 4.5650us 17.565us cudaEventQuery
0.01% 613.70us 800 767ns 308ns 3.3600us cudaStreamGetCaptureInfo
0.01% 379.60us 4 94.898us 1.9380us 371.47us cudaStreamCreateWithPriority
0.01% 315.70us 3 105.23us 101.60us 111.98us cuDeviceGetName
0.00% 117.93us 203 580ns 174ns 1.1070us cuDevicePrimaryCtxGetState
0.00% 52.471us 8 6.5580us 1.9580us 35.341us cudaStreamCreateWithFlags
0.00% 49.092us 48 1.0220us 440ns 12.521us cudaEventCreateWithFlags
0.00% 46.053us 48 959ns 603ns 3.6960us cudaFuncSetAttribute
0.00% 17.310us 32 540ns 277ns 3.3390us cudaDeviceGetAttribute
0.00% 5.4580us 1 5.4580us 5.4580us 5.4580us cuDeviceGetPCIBusId
0.00% 5.3510us 2 2.6750us 2.3960us 2.9550us cuInit
0.00% 3.2920us 5 658ns 172ns 2.1900us cuDeviceGetCount
0.00% 3.0950us 1 3.0950us 3.0950us 3.0950us cudaGetSymbolAddress
0.00% 3.0090us 4 752ns 214ns 2.2100us cudaGetDeviceCount
0.00% 2.0980us 1 2.0980us 2.0980us 2.0980us cudaDriverGetVersion
0.00% 2.0660us 1 2.0660us 2.0660us 2.0660us cudaHostGetDevicePointer
0.00% 1.4250us 1 1.4250us 1.4250us 1.4250us cudaDeviceGetStreamPriorityRange
0.00% 1.2240us 4 306ns 140ns 615ns cuDeviceGet
0.00% 832ns 3 277ns 243ns 310ns cuDeviceGetUuid
0.00% 715ns 2 357ns 282ns 433ns cuDriverGetVersion
Interestingly enough, this benchmark is much faster using CUDA 10.2 (and the discrepancy is much smaller, because `cudaStreamIsCapturing` is a CUDA 11.x+ only API call):
$ nvprof python runx.py
==5631== NVPROF is profiling process 5631, command: python runx.py
Eager Latency: 1.2893634999999999 ms
sklearn Eager latency: 1.362261 ms
speedup: 0.9464878609899278
==5631== Profiling application: python runx.py
==5631== Profiling result:
Type Time(%) Time Calls Avg Min Max Name
GPU activities: 30.20% 72.399ms 600 120.66us 18.368us 274.33us volta_fp16_s884gemm_fp16_128x64_ldg8_f2f_tn
21.22% 50.865ms 36 1.4129ms 1.5360us 15.958ms [CUDA memcpy HtoD]
10.96% 26.274ms 200 131.37us 129.98us 132.64us volta_fp16_scudnn_fp16_128x64_relu_xregs_large_nn_v1
8.38% 20.080ms 400 50.199us 48.543us 52.703us Volta_hmma_implicit_gemm_fprop_fp32_nhwc_64x32x64x1_1x3x3x0x1
6.52% 15.640ms 200 78.198us 77.280us 79.264us Volta_hmma_implicit_gemm_fprop_fp32_nhwc_128x128x64x1_1x3x3x0x1
5.99% 14.355ms 200 71.776us 70.687us 73.023us Volta_hmma_implicit_gemm_fprop_fp32_nhwc_256x64x64x1_1x5x5x0x1
4.38% 10.512ms 1600 6.5690us 2.6880us 13.824us void nchwToNhwcKernel<__half, __half, float, bool=1, bool=0>(int, int, int, int, __half const *, __half*, float, float)
2.87% 6.8697ms 1000 6.8690us 4.3520us 12.064us _ZN2at6native18elementwise_kernelILi128ELi4EZNS0_15gpu_kernel_implINS0_13BinaryFunctorIN3c104HalfES5_S5_NS0_10AddFunctorIfEEEEEEvRNS_18TensorIteratorBaseERKT_EUliE_EEviT1_
2.28% 5.4710ms 600 9.1180us 3.9680us 14.240us void at::native::_GLOBAL__N__63_tmpxft_000043ff_00000000_16_DilatedMaxPool2d_compute_75_cpp1_ii_6258b574::max_pool_forward_nchw<c10::Half, c10::Half>(int, c10::Half const *, int, int, int, int, int, int, int, int, int, int, int, int, int, at::native::_GLOBAL__N__63_tmpxft_000043ff_00000000_16_DilatedMaxPool2d_compute_75_cpp1_ii_6258b574::max_pool_forward_nchw<c10::Half, c10::Half>*, long*)
2.08% 4.9898ms 1400 3.5640us 1.7600us 11.616us _ZN2at6native29vectorized_elementwise_kernelILi4EZZZNS0_84_GLOBAL__N__60_tmpxft_0000507a_00000000_16_TensorCompare_compute_75_cpp1_ii_d0af11f719launch_clamp_scalarERNS_18TensorIteratorBaseEN3c106ScalarES6_NS0_6detail11ClampLimitsEENKUlvE_clEvENKUlvE14_clEvEUlNS5_4HalfEE_NS_6detail5ArrayIPcLi2EEEEEviT0_T1_
1.47% 3.5253ms 800 4.4060us 3.0400us 9.1840us void nhwcToNchwKernel<__half, __half, float, bool=1, bool=0>(int, int, int, int, __half const *, __half*, float, float)
1.40% 3.3613ms 200 16.806us 16.479us 17.472us void at::native::_GLOBAL__N__69_tmpxft_00003fb1_00000000_16_AdaptiveAveragePooling_compute_75_cpp1_ii_ef17039c::adaptive_average_pool<c10::Half>(c10::Half*, c10::Half, int, int, int, int, long, long, long)
1.00% 2.4039ms 600 4.0060us 2.7830us 12.800us void splitKreduce_kernel<__half, __half, float>(cublasSplitKParams<float>, __half const *, __half const *, __half*, float const *, float const *)
0.96% 2.2979ms 600 3.8290us 3.1030us 8.6080us _ZN2at6native18elementwise_kernelILi128ELi4EZNS0_15gpu_kernel_implIZZZNS0_75_GLOBAL__N__51_tmpxft_00004315_00000000_16_Copy_compute_75_cpp1_ii_ddc55cb923direct_copy_kernel_cudaERNS_18TensorIteratorBaseEENKUlvE0_clEvENKUlvE18_clEvEUlN3c104HalfEE_EEvS5_RKT_EUliE_EEviT1_
0.15% 361.76us 200 1.8080us 1.4720us 2.1120us cudnn::gemm::computeOffsetsKernel(cudnn::gemm::ComputeOffsetsParams)
0.14% 327.17us 204 1.6030us 1.5680us 2.2720us [CUDA memset]
API calls: 87.89% 4.59159s 22 208.71ms 4.3320us 4.58564s cudaMalloc
8.27% 432.16ms 7 61.737ms 827ns 262.46ms cudaFree
1.29% 67.310ms 8600 7.8260us 4.4680us 497.92us cudaLaunchKernel
1.03% 53.634ms 34 1.5775ms 7.8300us 16.102ms cudaMemcpyAsync
0.76% 39.844ms 400 99.610us 7.2250us 217.62us cudaDeviceSynchronize
0.27% 14.231ms 39109 363ns 269ns 16.881us cudaGetDevice
0.10% 5.3753ms 285 18.860us 135ns 2.1885ms cuDeviceGetAttribute
0.07% 3.4358ms 1600 2.1470us 1.0630us 82.579us cudaEventRecord
0.04% 2.3343ms 1 2.3343ms 2.3343ms 2.3343ms cudaHostAlloc
0.04% 2.1773ms 204 10.672us 7.9260us 49.996us cudaMemsetAsync
0.04% 2.1516ms 3 717.20us 667.03us 812.66us cuDeviceTotalMem
0.04% 1.9088ms 9410 202ns 116ns 12.480us cudaGetLastError
0.03% 1.7035ms 4 425.88us 4.8480us 1.6820ms cudaStreamCreateWithPriority
0.03% 1.6660ms 600 2.7760us 1.2370us 18.910us cudaEventQuery
0.03% 1.6257ms 2 812.87us 807.85us 817.89us cudaGetDeviceProperties
0.03% 1.4117ms 34 41.520us 8.0820us 101.20us cudaStreamSynchronize
0.02% 940.24us 600 1.5670us 836ns 83.960us cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags
0.01% 305.12us 3 101.71us 93.610us 116.02us cuDeviceGetName
0.00% 92.978us 180 516ns 355ns 2.2030us cudaFuncSetAttribute
0.00% 60.980us 8 7.6220us 1.6930us 44.213us cudaStreamCreateWithFlags
0.00% 33.055us 48 688ns 425ns 2.1170us cudaEventCreateWithFlags
0.00% 30.779us 2 15.389us 14.331us 16.448us cudaMemcpy
0.00% 20.600us 40 515ns 304ns 3.3830us cudaDeviceGetAttribute
0.00% 5.2590us 1 5.2590us 5.2590us 5.2590us cudaHostGetDevicePointer
0.00% 5.2450us 2 2.6220us 2.2550us 2.9900us cuInit
0.00% 3.5710us 1 3.5710us 3.5710us 3.5710us cuDeviceGetPCIBusId
0.00% 2.9290us 4 732ns 177ns 2.2950us cudaGetDeviceCount
0.00% 2.2850us 5 457ns 171ns 1.1220us cuDeviceGetCount
0.00% 1.4220us 3 474ns 297ns 804ns cuDeviceGetUuid
0.00% 1.3440us 4 336ns 152ns 564ns cuDeviceGet
0.00% 925ns 2 462ns 429ns 496ns cuDriverGetVersion
`cudaStreamIsCapturing` is probably called during memory allocations that don't hit in the cache. We shouldn't be benchmarking runs that have a lot of uncached mallocs; that's probably why removing `gc.collect()` helps. In the steady state, we expect memory requests to be mostly served from the cached memory, not to call cudaMalloc.
> `cudaStreamIsCapturing` is probably called during memory allocations that don't hit in the cache. We shouldn't be benchmarking runs that have a lot of uncached mallocs; that's probably why removing `gc.collect()` helps. In the steady state, we expect memory requests to be mostly served from the cached memory, not to call cudaMalloc.
That would have been true if the number of `cudaMalloc()` calls matched the number of `cudaStreamIsCapturing` calls, but they differ by 10x: `cudaMalloc` is called 20 times, while `cudaStreamIsCapturing` is called 200+ times (i.e., at least once for every `model(*inputs)` call).
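A quick way to watch the allocator from Python while iterating (a sketch, not from this thread; it assumes `torch.cuda.memory_stats()` and its `segment.all.allocated` / `num_alloc_retries` counters are available, as in recent PyTorch):

```python
import torch

model = torch.nn.Linear(512, 512).cuda().eval()
x = torch.randn(64, 512, device="cuda")

with torch.no_grad():
    for step in range(10):
        model(x)
        torch.cuda.synchronize()
        stats = torch.cuda.memory_stats()
        # segment.all.allocated counts segments obtained via cudaMalloc;
        # it should stop growing once the allocator cache is warm.
        print(step, stats["segment.all.allocated"], stats["num_alloc_retries"])
```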
The previous bug has been fixed by https://github.com/facebookresearch/torchdynamo/pull/153. Yet another problem shows up; here is the reproducing script:
import torch
import time
import gc
import copy
import sys
import numpy as np
from torchbenchmark import load_model_by_name
import argparse

DEVICE = "cuda"

def synchronize():
    if DEVICE == "cuda":
        torch.cuda.synchronize()

def timed(model, example_inputs, times=1, dynamo=None):
    torch.manual_seed(1337)
    if True:
        import torchdynamo
        optimize_ctx = torchdynamo.optimize("eager")
        with optimize_ctx:
            pass
    # print(sys.modules.keys())
    synchronize()
    t0 = time.time_ns()
    # Dont collect outputs to correctly measure timing
    if dynamo == "d1":
        with torchdynamo.optimize("eager"):
            result = model(*example_inputs)
    else:
        with torchdynamo.run():
            result = model(*example_inputs)
    synchronize()
    t1 = time.time_ns()
    return (t1 - t0) / 1_000_000

def speedup_experiment(model, example_inputs, model2, example_inputs2):
    repeat = 100
    timings = np.zeros((repeat, 2), np.float64)
    for rep in range(repeat):
        timings[rep, 0] = timed(model, example_inputs)
    for rep in range(repeat):
        timings[rep, 1] = timed(model2, example_inputs2, dynamo="d1")
    median = np.median(timings, axis=0)
    print(f"Eager Latency: {median[0]} ms")
    print(f"TorchDynamo Eager latency: {median[1]} ms")
    print(f"speedup: {median[0]/median[1]} ")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", choices=["cpu", "cuda"], default="cuda", help="specify device")
    args = parser.parse_args()
    DEVICE = args.device
    Model = load_model_by_name("alexnet")
    m = Model(device=DEVICE, test="eval", jit=False)
    m2 = Model(device=DEVICE, test="eval", jit=False)
    model, example_inputs = m.get_module()
    model2, example_inputs2 = m2.get_module()
    speedup_experiment(model, example_inputs, model2, example_inputs2)
Running `python runx.py` returns:
Eager Latency: 1.0757485 ms
TorchDynamo Eager latency: 1.183398 ms
speedup: 0.9090335626729131
So torchdynamo.optimize("eager")
is 10% slower than torchdynamo.run()
.
Update: the slowdown seems to be caused by https://github.com/facebookresearch/torchdynamo/blob/1e131a6e81cf2253b3672caa1f7a72629a472989/torchdynamo/mutation_guard.py#L88
Closing as stale.
The CI is ready; working on understanding the results, which are quite different from what we get from `torchbench.py`.