pytorch / torchdynamo

A Python-level JIT compiler designed to make unmodified PyTorch programs faster.
BSD 3-Clause "New" or "Revised" License

Get TorchBench result on Dynamo inference TRT #107

Closed. xuzhao9 closed this issue 1 year ago.

xuzhao9 commented 2 years ago

The CI is ready. I am working on understanding the results, which are quite different from what we get from torchbench.py.

xuzhao9 commented 2 years ago

Update: the result is available in an internal spreadsheet. However, it is different from what we get from torchbench.py. It could be because we compile contexts every time using with torchdynamo.optimize(<backend>), while torchbench.py uses torchdynamo.run(), which caches compiled code and avoids re-compilation overhead. We are still looking into the possible root causes.
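
(For reference, a minimal sketch of the two usage patterns being compared here; model and example_inputs stand in for a torchbench model and its inputs, and both context-manager usages appear in the scripts later in this thread:)

import torchdynamo

# Usage 1: wrap the measured call in optimize(), so frame compilation can
# happen inside the measured region on the first pass.
with torchdynamo.optimize("eager"):
    result = model(*example_inputs)

# Usage 2 (what torchbench.py does): compile up front under optimize(), then
# reuse whatever has already been compiled via run() inside the measured region.
with torchdynamo.optimize("eager"):
    model(*example_inputs)  # warm-up / compile
with torchdynamo.run():
    result = model(*example_inputs)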

jansel commented 2 years ago

@xuzhao9 are you sure .run() is the difference? I would expect optimize() and run() to do the exact same thing unless there are recompiles (which there shouldn't be in most benchmarks).

xuzhao9 commented 2 years ago

@jansel You're right. Looks like the difference is not caused by optimize() vs. run(). Please ignore my previous update. I managed to create a script that reproduces the problem on an A100; please help verify:

import torch
import time
import gc
import numpy as np
from torchbenchmark import load_model_by_name
import argparse

def synchronize():
    torch.cuda.synchronize()

def timed(model, example_inputs, times=1, dynamo=False):
    synchronize()
    gc.collect()
    torch.manual_seed(1337)
    t0 = time.time_ns()
    # Don't collect outputs to correctly measure timing
    if dynamo:
        with torchdynamo.run():
            result = model(*example_inputs)
    else:
        result = model(*example_inputs)
    synchronize()
    t1 = time.time_ns()
    return (t1 - t0) / 1_000_000

def speedup_experiment(model, example_inputs, dynamo=False):
    repeat = 100
    timings = np.zeros((repeat, 2), np.float64)
    for rep in range(repeat):
        # interleave the runs to handle frequency scaling and load changes
        timings[rep, 0] = timed(model, example_inputs)
        if dynamo:
            timings[rep, 1] = timed(model, example_inputs, dynamo=True)
    median = np.median(timings, axis=0)
    print(f"Eager Latency: {median[0]} ms")
    if dynamo:
        print(f"TorchDynamo Eager latency: {median[1]} ms")
        print(f"speedup: {median[0]/median[1]} ")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--torchdynamo", action="store_true", help="load torchdynamo library")
    args = parser.parse_args()
    if args.torchdynamo:
        import torchdynamo
        optimize_ctx = torchdynamo.optimize("eager")
        with optimize_ctx:
            pass
    Model = load_model_by_name("alexnet")
    m = Model(device="cuda", test="eval", jit=False)
    model, example_inputs = m.get_module()
    speedup_experiment(model, example_inputs, dynamo=args.torchdynamo)

Usage: place the script in the torchbench directory as runx.py, and run two commands:

  1. python runx.py, my output: Eager Latency: 0.934742 ms
  2. python runx.py --torchdynamo, my output:
    Eager Latency: 1.1478039999999998 ms
    TorchDynamo Eager latency: 1.1662355 ms
    speedup: 0.9841957306221598

This shows that when we run the script with torchdynamo, the latency of eager mode also increases (from 0.9347 ms to 1.1478 ms).

jansel commented 2 years ago

Is the GPU frequency scaling due to load? Perhaps it is more lightly loaded when you only run 1 experiment compared to 2. That is why I interleave the runs in my test, so if the GPU frequency scales it will affect both sides of the experiment the same.

Running your script on an RTX 3090 I get (with the fan set to max):

(torchdynamo) empire:~/torchbenchmark$ for X in `seq 10`; do python runx.py; done
Eager Latency: 1.228479 ms
Eager Latency: 1.2281875 ms
Eager Latency: 1.2263385 ms
Eager Latency: 1.2311455 ms
Eager Latency: 1.228942 ms
Eager Latency: 1.2223555 ms
Eager Latency: 1.22241 ms
Eager Latency: 1.226303 ms
Eager Latency: 1.2278235 ms
Eager Latency: 1.2260115 ms
(torchdynamo) empire:~/torchbenchmark$ for X in `seq 10`; do python runx.py --torchdynamo; done
Eager Latency: 1.2290904999999999 ms
TorchDynamo Eager latency: 1.2378339999999999 ms
speedup: 0.992936451899043 
Eager Latency: 1.2249225 ms
TorchDynamo Eager latency: 1.240448 ms
speedup: 0.987483957408936 
Eager Latency: 1.2259004999999998 ms
TorchDynamo Eager latency: 1.2362005 ms
speedup: 0.9916680182543203 
Eager Latency: 1.2307625 ms
TorchDynamo Eager latency: 1.23834 ms
speedup: 0.9938809212332639 
Eager Latency: 1.2269335 ms
TorchDynamo Eager latency: 1.240305 ms
speedup: 0.9892191839910344 
Eager Latency: 1.2323985 ms
TorchDynamo Eager latency: 1.244964 ms
speedup: 0.9899069370680598 
Eager Latency: 1.2397464999999999 ms
TorchDynamo Eager latency: 1.256022 ms
speedup: 0.9870420263339336 
Eager Latency: 1.230794 ms
TorchDynamo Eager latency: 1.2430765 ms
speedup: 0.9901192726272278 
Eager Latency: 1.2288995 ms
TorchDynamo Eager latency: 1.2410505 ms
speedup: 0.990209101080093 
Eager Latency: 1.231789 ms
TorchDynamo Eager latency: 1.2418645 ms
speedup: 0.9918867960232377 

frank-wei commented 2 years ago

The current results look reasonable to me. As Xu mentioned, it is not caused by the use of torchdynamo.run() vs. torchdynamo.optimize(), since we do not recompile the model. @xuzhao9, just wondering why the initial result has such a large variance? What is the reason there?

xuzhao9 commented 2 years ago

@frank-wei Both @anijain2305 and I can reproduce this result on A100 (showing that python runx.py Eager is faster than python runx.py --torchdynamo Eager). One reason for this discrepancy could be that, since we don't have sudo access to the A100 instance, we can't pin the GPU frequency, etc., to stabilize the latency numbers.

jansel commented 2 years ago

Interesting that I can't reproduce it locally. A few ideas of things to try:

1) Can you reproduce your results on a CPU backend? Or is this GPU-only?

2) What if you replace dynamo with something else that uses the GPU? For example, do:

model2 = copy.deepcopy(model)

then interleave runs of model and model2. You could also try interleaving it with something that does CPU work (for example model3 = model2.to("cpu")).

frank-wei commented 2 years ago

Here are my test results on A100. Observations: 1) in interleaving mode, the difference is less than about 2%; 2) without interleaving, the difference is big. It feels like the GPU frequency needs to be locked. Will test again after that.

Test1 for alexnet without running interleaving

(mypy38) [wwei6@devgpu005.ftw6 /data/users/wwei6/Work/torchbench] for X in `seq 10`; do python runx.py; done
Eager Latency: 1.4733584999999998 ms
Eager Latency: 1.6835135 ms
Eager Latency: 1.4823665 ms
Eager Latency: 1.5355865 ms
Eager Latency: 1.510335 ms
Eager Latency: 1.49497 ms
Eager Latency: 1.4443485 ms
Eager Latency: 1.5397224999999999 ms
Eager Latency: 1.4612414999999999 ms
Eager Latency: 1.7127085 ms

Test2 for alexnet

for X in `seq 10`; do python runx.py --torchdynamo; done
Eager Latency: 1.7416095 ms
TorchDynamo Eager latency: 1.760823 ms
speedup: 0.9890883410768715
Eager Latency: 1.7456610000000001 ms
TorchDynamo Eager latency: 1.7690625 ms
speedup: 0.9867718071012189
Eager Latency: 1.7898844999999999 ms
TorchDynamo Eager latency: 1.804828 ms
speedup: 0.9917202636483918
Eager Latency: 1.7428985 ms
TorchDynamo Eager latency: 1.763342 ms
speedup: 0.9884063896850412
Eager Latency: 1.7013310000000001 ms
TorchDynamo Eager latency: 1.7321525000000002 ms
speedup: 0.9822062433879234
Eager Latency: 1.7888145 ms
TorchDynamo Eager latency: 1.816144 ms
speedup: 0.984951909099719
Eager Latency: 2.013072 ms
TorchDynamo Eager latency: 2.0064089999999997 ms
speedup: 1.0033208583095474
Eager Latency: 1.9477324999999999 ms
TorchDynamo Eager latency: 1.9671699999999999 ms
speedup: 0.9901190542759396
Eager Latency: 1.68583 ms
TorchDynamo Eager latency: 1.718121 ms
speedup: 0.9812056310352996
Eager Latency: 1.9998985 ms
TorchDynamo Eager latency: 2.036448 ms
speedup: 0.9820523283678247

frank-wei commented 2 years ago

Locked the GPU frequency to the highest with sudo nvidia-smi -ac 1215,1410 on A100. The results are similar to the above. The big variance in Test1 still exists.

xuzhao9 commented 2 years ago

Interesting that I can't reproduce it locally. A few ideas of things to try:

  1. Can you reproduce your results on a CPU backend? Or is this GPU-only?
  2. What if you replace dynamo with something else that uses the GPU? For example, do:
model2 = copy.deepcopy(model)

then interleave runs of model and model2. You could also try interleaving it with something that does CPU work (for example model3 = model2.to("cpu")).

I did two follow-up experiments.

  1. Test on CPU:

    $ python runx.py --device cpu
    Eager Latency: 20.2229595 ms
    Eager latency r2: 20.275337 ms
    speedup: 0.9974166890542929
    $ python runx.py --device cpu --torchdynamo
    Eager Latency: 19.4026635 ms
    TorchDynamo Eager latency: 19.417619000000002 ms
    speedup: 0.9992297974329395

    So the problem doesn't exist on CPU.

  2. Use model2 = copy.deepcopy(model) for GPU test round 2 (interleave two same GPU eager tests)

    $ python runx.py
    Eager Latency: 1.307467 ms
    Eager latency r2: 1.312327 ms
    speedup: 0.9962966547209651
    $ python runx.py --torchdynamo
    Eager Latency: 1.6921075 ms
    TorchDynamo Eager latency: 1.714546 ms
    speedup: 0.9869128620637767

When running with dynamo, the Eager Latency is longer than when running without dynamo (1.69 ms vs. 1.30 ms). So the problem is reproducible.

The script I am using is:

import torch
import time
import gc
import copy
import numpy as np
from torchbenchmark import load_model_by_name
import argparse

DEVICE = "cuda"

def synchronize():
    if DEVICE == "cuda":
        torch.cuda.synchronize()

def timed(model, example_inputs, times=1, dynamo=False):
    synchronize()
    gc.collect()
    torch.manual_seed(1337)
    t0 = time.time_ns()
    # Don't collect outputs to correctly measure timing
    if dynamo:
        with torchdynamo.run():
            result = model(*example_inputs)
    else:
        result = model(*example_inputs)
    synchronize()
    t1 = time.time_ns()
    return (t1 - t0) / 1_000_000

def speedup_experiment(model, example_inputs, dynamo=False):
    repeat = 100
    timings = np.zeros((repeat, 2), np.float64)
    for rep in range(repeat):
        # interleave the runs to handle frequency scaling and load changes
        timings[rep, 0] = timed(model, example_inputs)
        if dynamo:
            timings[rep, 1] = timed(model, example_inputs, dynamo=True)
        else:
            model2 = copy.deepcopy(model)
            timings[rep, 1] = timed(model2, example_inputs, dynamo=False)
    median = np.median(timings, axis=0)
    print(f"Eager Latency: {median[0]} ms")
    if dynamo:
        print(f"TorchDynamo Eager latency: {median[1]} ms")
        print(f"speedup: {median[0]/median[1]} ")
    else:
        print(f"Eager latency r2: {median[1]} ms")
        print(f"speedup: {median[0]/median[1]} ")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--torchdynamo", action="store_true", help="load torchdynamo library")
    parser.add_argument("--device", choices=["cpu", "cuda"], default="cuda", help="specify device")
    args = parser.parse_args()
    DEVICE = args.device
    if args.torchdynamo:
        import torchdynamo
        optimize_ctx = torchdynamo.optimize("eager")
        with optimize_ctx:
            pass
    Model = load_model_by_name("alexnet")
    m = Model(device=DEVICE, test="eval", jit=False)
    model, example_inputs = m.get_module()
    speedup_experiment(model, example_inputs, dynamo=args.torchdynamo)

jansel commented 2 years ago

For the second experiment I was suggesting not using TorchDynamo at all (just interleaving unrelated things both in eager mode). My hypothesis is this is a function of the measurement harness. Performance within the measured region is correlated to what happens outside the measured region.

xuzhao9 commented 2 years ago

For the second experiment I was suggesting not using TorchDynamo at all (just interleaving unrelated things both in eager mode). My hypothesis is this is a function of the measurement harness. Performance within the measured region is correlated to what happens outside the measured region.

Yes that was what I did in the second experiment (see the updated code for details).

jansel commented 2 years ago

Ah perfect, thanks! Just wanted to confirm.

I'm drawing a blank on what TorchDynamo could be doing to affect GPU only (and only A100, not RTX 3090). TorchDynamo doesn't do any GPU work; it just runs the captured graph eagerly -- which should be literally identical kernels.

Perhaps there is some way to clear GPU caches...
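
(For what it's worth, a minimal sketch of what "clearing GPU caches" between measurements could look like, assuming the CUDA caching allocator is the cache in question; torch.cuda.empty_cache() releases the allocator's cached blocks back to the driver:)

import gc
import torch

def reset_gpu_state():
    # Drop Python-side garbage first so freed tensors actually go back to the
    # caching allocator, then hand the cached blocks back to the driver.
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.synchronize()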

xuzhao9 commented 2 years ago

@jansel I updated the benchmarking script and am able to reproduce the slowdown within the same process:

  1. do not interleave runs (run PyTorch eager first, then run TorchDynamo eager)
  2. only import torchdynamo on demand (i.e., in the PyTorch eager run, do not import torchdynamo)

Result on V100 (pinned to the highest GPU and GPU memory frequency); I ran it a few times and the numbers look pretty stable:

$ python runx.py --torchdynamo
Eager Latency: 1.2967605 ms
TorchDynamo Eager latency: 1.598749 ms
speedup: 0.8111094987393268

So the TorchDynamo eager backend is only 81% as fast as PyTorch eager mode. It looks like a single import torchdynamo line slows down both PyTorch eager and TorchDynamo eager. Did we do anything special in import torchdynamo such that it slows down the entire GPU?
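
(A stripped-down sketch of the effect described above: time the same eager model in one process before and after import torchdynamo. torchvision's alexnet is used here as a stand-in for the torchbench model, so the exact numbers will differ:)

import time
import torch
import torchvision.models as models

def median_latency_ms(model, x, repeat=100):
    timings = []
    for _ in range(repeat):
        torch.cuda.synchronize()
        t0 = time.time_ns()
        with torch.no_grad():
            model(x)
        torch.cuda.synchronize()
        timings.append((time.time_ns() - t0) / 1e6)
    return sorted(timings)[len(timings) // 2]

model = models.alexnet().cuda().eval()
x = torch.randn(8, 3, 224, 224, device="cuda")
print("eager before import:", median_latency_ms(model, x), "ms")
import torchdynamo  # the only change between the two measurements
print("eager after import: ", median_latency_ms(model, x), "ms")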

Updated script:

import torch
import time
import gc
import copy
import sys
import numpy as np
from torchbenchmark import load_model_by_name
import argparse

DEVICE = "cuda"

def synchronize():
    if DEVICE == "cuda":
        torch.cuda.synchronize()

def timed(model, example_inputs, times=1, dynamo=False):
    torch.manual_seed(1337)
    if dynamo:
        import torchdynamo
    gc.collect()
    synchronize()
    t0 = time.time_ns()
    # Don't collect outputs to correctly measure timing
    if dynamo:
        # with torchdynamo.run():
        #     result = model(*example_inputs)
        result = model(*example_inputs)
    else:
        result = model(*example_inputs)
    synchronize()
    t1 = time.time_ns()
    return (t1 - t0) / 1_000_000

def speedup_experiment(model, example_inputs, model2, example_inputs2, dynamo=False):
    repeat = 100
    timings = np.zeros((repeat, 2), np.float64)
    for rep in range(repeat):
        # interleave the runs to handle frequency scaling and load changes
        timings[rep, 0] = timed(model, example_inputs)
        if dynamo:
            timings[rep, 1] = timed(model2, example_inputs2, dynamo=True)
        else:
            timings[rep, 1] = timed(model2, example_inputs2, dynamo=False)
    median = np.median(timings, axis=0)
    print(f"Eager Latency: {median[0]} ms")
    if dynamo:
        print(f"TorchDynamo Eager latency: {median[1]} ms")
        print(f"speedup: {median[0]/median[1]} ")
    else:
        print(f"Eager latency r2: {median[1]} ms")
        print(f"speedup: {median[0]/median[1]} ")

def speedup_experiment2(model, example_inputs, model2, example_inputs2, dynamo=False):
    repeat = 100
    timings = np.zeros((repeat, 2), np.float64)
    for rep in range(repeat):
        timings[rep, 0] = timed(model, example_inputs)
    for rep in range(repeat):
        if dynamo:
            timings[rep, 1] = timed(model2, example_inputs2, dynamo=True)
        else:
            timings[rep, 1] = timed(model2, example_inputs2, dynamo=False)
    median = np.median(timings, axis=0)
    print(f"Eager Latency: {median[0]} ms")
    if dynamo:
        print(f"TorchDynamo Eager latency: {median[1]} ms")
        print(f"speedup: {median[0]/median[1]} ")
    else:
        print(f"Eager latency r2: {median[1]} ms")
        print(f"speedup: {median[0]/median[1]} ")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--torchdynamo", action="store_true", help="load torchdynamo library")
    parser.add_argument("--device", choices=["cpu", "cuda"], default="cuda", help="specify device")
    args = parser.parse_args()
    DEVICE = args.device
    Model = load_model_by_name("alexnet")
    m = Model(device=DEVICE, test="eval", jit=False)
    m2 = Model(device=DEVICE, test="eval", jit=False)
    model, example_inputs = m.get_module()
    model2, example_inputs2 = m2.get_module()
    # speedup_experiment(model, example_inputs, model2, example_inputs2, dynamo=args.torchdynamo)
    speedup_experiment2(model, example_inputs, model2, example_inputs2, dynamo=args.torchdynamo)

jansel commented 2 years ago

@xuzhao9 I ran your latest script on an RTX 3090, and I can't reproduce your results:

$ python test2.py 
Eager Latency: 1.2188915 ms
Eager latency r2: 1.2218585 ms
speedup: 0.9975717319149476 
$ python test2.py --torchdynamo
Eager Latency: 1.2186050000000002 ms
TorchDynamo Eager latency: 1.246474 ms
speedup: 0.9776417317970532 

So somehow it is specific to your machine. Perhaps we should try on a V100, with the latest version of CUDA, or with newer drivers.

Are you sure it is just the import torchdynamo line? Does it still slow things down if you remove the torchdynamo.optimize("eager") and torchdynamo.run()?

xuzhao9 commented 2 years ago

@xuzhao9 I ran your latest script on an RTX 3090, and I can't reproduce your results:

$ python test2.py 
Eager Latency: 1.2188915 ms
Eager latency r2: 1.2218585 ms
speedup: 0.9975717319149476 
$ python test2.py --torchdynamo
Eager Latency: 1.2186050000000002 ms
TorchDynamo Eager latency: 1.246474 ms
speedup: 0.9776417317970532 

So somehow it is specific to your machine. Perhaps we should try on a V100, with the latest version of CUDA, or with newer drivers.

Yes, it looks like there is a problem within my dev environment. For A100, I am using Nvidia driver 450.119.03, CUDA 11.1, PyTorch 1.12.0.dev20220408+cu113 (I can't update the Nvidia/CUDA version or pin the frequency due to lack of sudo permission). For V100, I am using Nvidia driver 470.82.01, CUDA 11.3, PyTorch 1.12.0.dev20220331+cu113.

Are you sure it is just the import torchdynamo line? Does it still slow things down if you remove the torchdynamo.optimize("eager") and torchdynamo.run()?

Yes, I confirm that if I remove torchdynamo.optimize("eager") and torchdynamo.run(), the slowdown still exists.

jansel commented 2 years ago

I am using:

Just import torchdynamo shouldn't have side effects. It doesn't even install the eval_frame handler. The walk of the torch.* module hierarchy in allowed_functions happens lazily on first use, not at import time.
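
(A generic sketch of the lazy-initialization pattern being described, not torchdynamo's actual code: the expensive walk over torch.* only happens the first time the allow-list is queried, never at import time:)

import functools
import types
import torch

@functools.lru_cache(None)
def allowed_torch_function_ids():
    # Built once, on first call; merely importing this module never runs it.
    allowed = set()
    for name in dir(torch):
        obj = getattr(torch, name, None)
        if isinstance(obj, (types.FunctionType, types.BuiltinFunctionType)):
            allowed.add(id(obj))
    return allowed

def is_allowed(fn):
    return id(fn) in allowed_torch_function_ids()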

Could you try commenting out the import lines in __init__.py to narrow down which import exactly is causing the issue?

That is super odd. Perhaps something to do with memory allocation alignment... dunno, drawing a blank.

frank-wei commented 2 years ago

Tried on A100: NVIDIA-SMI 470.57.02, Driver Version 470.57.02, CUDA Version 11.4.

for X in `seq 10`; do python runx3.py --torchdynamo; done
Eager Latency: 1.701516 ms
TorchDynamo Eager latency: 1.837139 ms
speedup: 0.9261770611804551 
Eager Latency: 1.5440925 ms
TorchDynamo Eager latency: 1.8713935 ms
speedup: 0.8251030582290684 
Eager Latency: 1.7131965 ms
TorchDynamo Eager latency: 1.9872285 ms
speedup: 0.8621034269587015 
Eager Latency: 1.542824 ms
TorchDynamo Eager latency: 1.9405899999999998 ms
speedup: 0.7950283161306614 
Eager Latency: 1.5197295 ms
TorchDynamo Eager latency: 1.858571 ms
speedup: 0.8176870832483666 
Eager Latency: 1.5287305 ms
TorchDynamo Eager latency: 1.8407814999999998 ms
speedup: 0.8304790655490617 
Eager Latency: 1.800862 ms
TorchDynamo Eager latency: 1.9163865 ms
speedup: 0.9397175361024511 
Eager Latency: 1.491285 ms
TorchDynamo Eager latency: 1.8125605 ms
speedup: 0.8227504681912686 
Eager Latency: 1.534363 ms
TorchDynamo Eager latency: 1.8803305 ms
speedup: 0.8160070796064841 
Eager Latency: 1.8389465 ms
TorchDynamo Eager latency: 1.9739385 ms
speedup: 0.9316128643318928 

After commenting out some imports in __init__.py as follows:

from . import convert_frame
from . import resume_execution
# from .eval_frame import disable
from .eval_frame import optimize
# from .eval_frame import optimize_assert
# from .eval_frame import reset_code
# from .eval_frame import run

and all the imports in fx2trt:

def fx2trt(subgraph, **kwargs):
    if subgraph.will_tensorrt_barf():
        # TensorRT fails violently with an abort() on this
        return None

    # import fx2trt_oss.tracer.acc_tracer.acc_tracer as acc_tracer
    # from fx2trt_oss.fx.fx2trt import InputTensorSpec
    # from fx2trt_oss.fx.fx2trt import TRTInterpreter
    # from fx2trt_oss.fx.tools.trt_splitter import TRTSplitter
    # from fx2trt_oss.fx.tools.trt_splitter import TRTSplitterSetting
    # from fx2trt_oss.fx.trt_module import TRTModule
    # from fx2trt_oss.fx.utils import LowerPrecision

The slowdown is still there.

 for X in `seq 3`; do python runx3.py --torchdynamo; done
Eager Latency: 1.5356960000000002 ms
TorchDynamo Eager latency: 2.0057774999999998 ms
speedup: 0.7656362682301503 
Eager Latency: 1.738222 ms
TorchDynamo Eager latency: 1.8921785 ms
speedup: 0.9186353190251342 
Eager Latency: 1.5379595 ms
TorchDynamo Eager latency: 1.8584245 ms
speedup: 0.8275609259348443

jansel commented 2 years ago

So which import line is causing the issue?

frank-wei commented 2 years ago

So which import line is causing the issue?

No conclusion yet.

xuzhao9 commented 2 years ago

Thanks to @jansel, I can confirm this line is the root cause. https://github.com/jansel/torchdynamo/blob/bf90b8cdbacf35944fa8c12185b1823dc5cb90bb/torchdynamo/skipfiles.py#L123

Specifically, adding five packages slows down the test: "networkx", "omegaconf", "onnx", "pandas", and "sklearn". Of these, sklearn alone causes a ~10% slowdown, and the other four cause another 8-9% combined. This explains the total ~20% slowdown.

jansel commented 2 years ago

Interesting, that could explain why I can't reproduce the issue. I might have different versions of those packages.

It is a bit alarming that importing random dependencies slows down PyTorch. Many of those packages don't even use the GPU.

jansel commented 2 years ago

A few theories from an internal chat thread:

xuzhao9 commented 2 years ago

Another finding: the slowdown from import sklearn is related to the gc.collect() at https://github.com/facebookresearch/torchdynamo/blob/main/torchbench.py#L165. After removing this line, the regression from import sklearn is gone.

ngimel commented 2 years ago

Can you check whether you'd still see the slowdown if you leave the gc.collect() line in place, but add a couple of untimed dry runs after it, before starting to collect actual timings?
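
(A sketch of that suggestion applied to the timed()/speedup_experiment() harness from the scripts above; the gc.collect() call inside timed() stays, but a few warm-up iterations are discarded before timings are recorded. WARMUP and the function name are made up for illustration:)

import numpy as np

WARMUP = 3

def speedup_experiment_with_warmup(model, example_inputs, model2, example_inputs2, repeat=100):
    # Untimed dry runs: let the caching allocator, cuDNN heuristics, etc.
    # reach a steady state before anything is recorded.
    for _ in range(WARMUP):
        timed(model, example_inputs)
        timed(model2, example_inputs2, dynamo=True)
    timings = np.zeros((repeat, 2), np.float64)
    for rep in range(repeat):
        timings[rep, 0] = timed(model, example_inputs)
        timings[rep, 1] = timed(model2, example_inputs2, dynamo=True)
    median = np.median(timings, axis=0)
    print(f"Eager Latency: {median[0]} ms")
    print(f"TorchDynamo Eager latency: {median[1]} ms")
    print(f"speedup: {median[0] / median[1]}")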

malfet commented 2 years ago

At least part of the regression is due to the fact that PyTorch is built using Intel's OpenMP, while sklearn relies on the GNU OpenMP runtime and sets some environment variables to tell MKL to interoperate with it. Replacing import sklearn with import _openmp_helpers (see https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/utils/_openmp_helpers.pyx) results in a 9% slowdown, even though the module does nothing but import libgomp.so.

xuzhao9 commented 2 years ago

At least part of the regression is due to the fact that PyTorch is built using Intel's OpenMP, while sklearn relies on the GNU OpenMP runtime and sets some environment variables to tell MKL to interoperate with it. Replacing import sklearn with import _openmp_helpers (see https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/utils/_openmp_helpers.pyx) results in a 9% slowdown, even though the module does nothing but import libgomp.so.

This is a bit strange because I can only reproduce the issue on GPU, not CPU. How does importing libgomp.so impact CUDA performance?

sanchitintel commented 2 years ago

Hi @malfet & @xuzhao9,

The public PyTorch release for Linux is not built with Intel OpenMP.

malfet commented 2 years ago

The public PyTorch release for Linux is not built with Intel OpenMP.

Hmm, that's a good point, it's only linked with iomp5 on macOS, but not on Linux:

$ python -c "import torch;print(open('/proc/self/maps').read())"|grep omp
7f31e851d000-7f31e8527000 r--p 00000000 917:b543c 144115217301836474     /fsx/users/nshulga/conda/envs/py_3.9-torch-1.11-cuda-11.3/lib/libgomp.so.1.0.0

malfet commented 2 years ago

If one increases the batch size, the discrepancy between eager mode with and without the sklearn import becomes negligible, using the following script derived from the one @xuzhao9 posted:

import torch
import time
import gc
import sys
import os
import numpy as np
import torchvision.models as models
import argparse

DEVICE = "cuda"

def synchronize():
    if DEVICE == "cuda":
        torch.cuda.synchronize()

def timed(model, example_inputs, times=1, dynamo=False):
    torch.manual_seed(1337)
    if dynamo:
        import sklearn
    gc.collect()
    synchronize()
    t0 = time.time_ns()
    # Don't collect outputs to correctly measure timing
    result = model(*example_inputs)
    synchronize()
    t1 = time.time_ns()
    return (t1 - t0) / 1_000_000

def speedup_experiment(model, example_inputs, model2, example_inputs2):
    repeat = 100
    timings = np.zeros((repeat, 2), np.float64)
    for rep in range(repeat):
        timings[rep, 0] = timed(model, example_inputs)
    for rep in range(repeat):
        timings[rep, 1] = timed(model2, example_inputs2, dynamo=True)
    median = np.median(timings, axis=0)
    print(f"Eager Latency: {median[0]} ms")
    print(f"sklearn Eager latency: {median[1]} ms")
    print(f"speedup: {median[0]/median[1]} ")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", choices=["cpu", "cuda"], default="cuda", help="specify device")
    parser.add_argument("--dtype", choices=["float16", "float32"], default="float16", help="specify dtype")
    parser.add_argument("--sklearn", action="store_true", help="load torchdynamo library")
    parser.add_argument("--batch-size", type=int,  default=8, help="specify device")
    args = parser.parse_args()
    DEVICE = args.device
    DTYPE = {"float16": torch.float16, "float32": torch.float32}[args.dtype]
    batch_size = args.batch_size
    model = models.alexnet(pretrained=True).to(device=DEVICE, dtype=DTYPE)
    example_inputs = (torch.randn(batch_size, 3, 224, 224).to(device=DEVICE, dtype=DTYPE), )
    model2 = models.alexnet(pretrained=True).to(device=DEVICE, dtype=DTYPE)
    example_inputs2 = (torch.randn(batch_size, 3, 224, 224).to(device=DEVICE, dtype=DTYPE), )
    model.eval()
    model2.eval()
    if args.sklearn:
        import sklearn
    speedup_experiment(model, example_inputs, model2, example_inputs2)

(py_3.9-torch-1.11-cuda-11.3) nshulga@dev-st-p3dn24xlarge-1:~$ python runx.py  --batch-size=8
Eager Latency: 1.3937244999999998 ms
sklearn Eager latency: 1.569018 ms
speedup: 0.8882782096827441 
(py_3.9-torch-1.11-cuda-11.3) nshulga@dev-st-p3dn24xlarge-1:~$ python runx.py  --batch-size=16
Eager Latency: 1.3827145 ms
sklearn Eager latency: 1.571698 ms
speedup: 0.8797583886980832 
(py_3.9-torch-1.11-cuda-11.3) nshulga@dev-st-p3dn24xlarge-1:~$ python runx.py  --batch-size=32
Eager Latency: 2.0108475 ms
sklearn Eager latency: 2.0631779999999997 ms
speedup: 0.9746359742106596 
(py_3.9-torch-1.11-cuda-11.3) nshulga@dev-st-p3dn24xlarge-1:~$ python runx.py  --batch-size=64
Eager Latency: 3.54023 ms
sklearn Eager latency: 3.5917225000000004 ms
speedup: 0.9856635639306767 

malfet commented 2 years ago

Hmm, another strange fact: why does cudaStreamIsCapturing dominate the CUDA API time (even though we don't do any graph tracing)?

$ nvprof python runx.py 
==3411== NVPROF is profiling process 3411, command: python runx.py
Eager Latency: 1.8098640000000001 ms
sklearn Eager latency: 2.1530690000000003 ms
speedup: 0.8405973055206312 
==3411== Profiling application: python runx.py
==3411== Profiling result:
            Type  Time(%)      Time     Calls       Avg       Min       Max  Name
 GPU activities:   24.31%  51.150ms        34  1.5044ms  1.5990us  16.041ms  [CUDA memcpy HtoD]
                   21.30%  44.810ms       400  112.03us  56.480us  175.17us  void cutlass::Kernel<cutlass_70_wmma_tensorop_f16_s161616gemm_f16_16x16_64x2_tn_align8>(cutlass_70_wmma_tensorop_f16_s161616gemm_f16_16x16_64x2_tn_align8Params)
                   12.56%  26.437ms       200  132.19us  131.01us  133.38us  volta_fp16_scudnn_fp16_128x64_relu_xregs_large_nn_v1
                   10.39%  21.866ms       400  54.664us  49.727us  59.999us  void xmma_cudnn::gemm::kernel<xmma_cudnn::implicit_gemm::fprop_indexed::Kernel_traits<xmma_cudnn::Volta_hmma_fp32_traits, xmma_cudnn::Cta_tile<xmma_cudnn::Volta<int=0>, int=64, int=128, int=32, int=2, int=2, int=1, int=1>, xmma_cudnn::implicit_gemm::fprop_indexed::Gmem_tile_a_t<xmma_cudnn::Volta_hmma_fp32_traits, xmma_cudnn::Cta_tile<xmma_cudnn::Volta<int=0>, int=64, int=128, int=32, int=2, int=2, int=1, int=1>, xmma_cudnn::implicit_gemm::Input_related<int=0, int=0, int=0, bool=0>, int=16, bool=0, xmma_cudnn::implicit_gemm::fprop_indexed::Gmem_tile_base_a<xmma_cudnn::Volta_hmma_fp32_traits, xmma_cudnn::Cta_tile<xmma_cudnn::Volta<int=0>, int=64, int=128, int=32, int=2, int=2, int=1, int=1>, xmma_cudnn::implicit_gemm::Input_related<int=0, int=0, int=0, bool=0>, int=16, xmma_cudnn::Row, int=32, int=64>>, xmma_cudnn::implicit_gemm::fprop_indexed::Gmem_tile_c_t<xmma_cudnn::Volta_hmma_fp32_traits, xmma_cudnn::Cta_tile<xmma_cudnn::Volta<int=0>, int=64, int=128, int=32, int=2, int=2, int=1, int=1>, int=16, xmma_cudnn::Fragment_c<xmma_cudnn::Volta_hmma_fp32_traits, xmma_cudnn::Cta_tile<xmma_cudnn::Volta<int=0>, int=64, int=128, int=32, int=2, int=2, int=1, int=1>, bool=0>>, xmma_cudnn::implicit_gemm::Input_related<int=0, int=0, int=0, bool=0>, int=1>>(xmma_cudnn::Volta_hmma_fp32_traitsParams)
                    6.50%  13.684ms       200  68.418us  67.647us  69.376us  volta_fp16_s884cudnn_fp16_256x64_ldg8_relu_f2f_exp_small_nhwc2nchw_tn_v1
                    5.23%  11.009ms       200  55.042us  49.920us  57.279us  sm70_xmma_fprop_implicit_gemm_f16f16_f16f32_f32_nhwckrsc_nhwc_tilesize64x64x64_stage1_warpsize2x2x1_g1_tensor8x8x4_kernel
                    4.78%  10.057ms      1600  6.2850us  2.5280us  13.376us  void cudnn::ops::nchwToNhwcKernel<__half, __half, float, bool=0, bool=1, cudnnKernelDataType_t=0>(cudnn::ops::nchw2nhwc_params_t<float>, __half const *, __half*)
                    3.20%  6.7238ms      1000  6.7230us  4.2880us  12.064us  _ZN2at6native18elementwise_kernelILi128ELi4EZNS0_15gpu_kernel_implINS0_13BinaryFunctorIN3c104HalfES5_S5_NS0_10AddFunctorIfEEEEEEvRNS_18TensorIteratorBaseERKT_EUliE_EEviT1_
                    2.52%  5.3069ms       600  8.8440us  4.1280us  13.760us  void at::native::_GLOBAL__N__63_tmpxft_0000631e_00000000_21_DilatedMaxPool2d_compute_86_cpp1_ii_6258b574::max_pool_forward_nchw<c10::Half, c10::Half>(int, c10::Half const *, int, int, int, int, int, int, int, int, int, int, int, int, int, at::native::_GLOBAL__N__63_tmpxft_0000631e_00000000_21_DilatedMaxPool2d_compute_86_cpp1_ii_6258b574::max_pool_forward_nchw<c10::Half, c10::Half>*, long*)
                    2.31%  4.8502ms      1400  3.4640us  1.6640us  8.8000us  _ZN2at6native29vectorized_elementwise_kernelILi4EZZZNS0_84_GLOBAL__N__60_tmpxft_00007377_00000000_21_TensorCompare_compute_86_cpp1_ii_d0af11f719launch_clamp_scalarERNS_18TensorIteratorBaseEN3c106ScalarES6_NS0_6detail11ClampLimitsEENKUlvE_clEvENKUlvE14_clEvEUlNS5_4HalfEE_NS_6detail5ArrayIPcLi2EEEEEviT0_T1_
                    2.03%  4.2620ms       200  21.309us  20.544us  22.496us  void cutlass::Kernel<cutlass_70_wmma_tensorop_s161616gemm_f16_32x32_64x2_tn_align8>(cutlass_70_wmma_tensorop_s161616gemm_f16_32x32_64x2_tn_align8Params)
                    1.54%  3.2403ms       200  16.201us  15.904us  16.576us  void at::native::_GLOBAL__N__69_tmpxft_00005dd7_00000000_21_AdaptiveAveragePooling_compute_86_cpp1_ii_ef17039c::adaptive_average_pool<c10::Half>(c10::Half*, c10::Half, int, int, int, int, long, long, long)
                    1.11%  2.3282ms       600  3.8800us  3.0710us  5.5360us  _ZN2at6native18elementwise_kernelILi128ELi4EZNS0_15gpu_kernel_implIZZZNS0_75_GLOBAL__N__51_tmpxft_000061c0_00000000_21_Copy_compute_86_cpp1_ii_ddc55cb923direct_copy_kernel_cudaERNS_18TensorIteratorBaseEENKUlvE0_clEvENKUlvE18_clEvEUlN3c104HalfEE_EEvS5_RKT_EUliE_EEviT1_
                    1.06%  2.2252ms       600  3.7080us  3.0400us  5.1840us  void cudnn::ops::nhwcToNchwKernel<__half, __half, float, bool=1, bool=0, cudnnKernelDataType_t=0>(cudnn::ops::nhwc2nchw_params_t<float>, __half const *, __half*)
                    0.47%  990.46us       604  1.6390us  1.5680us  2.3360us  [CUDA memset]
                    0.36%  753.02us       400  1.8820us  1.6950us  2.5920us  void cask_cudnn::computeOffsetsKernel<bool=0, bool=0>(cask_cudnn::ComputeOffsetsParams)
                    0.35%  726.78us       200  3.6330us  3.1680us  4.4800us  void splitKreduce_kernel<float, __half, float, __half, bool=1, bool=0>(cublasSplitKParams<float>, float const *, __half const *, __half*, float const *, float const *, __half const *, void*, long, float*, int*)
      API calls:   80.27%  5.01770s       210  23.894ms  1.0040us  5.01737s  cudaStreamIsCapturing
                   16.75%  1.04700s         4  261.75ms  1.6460us  653.52ms  cudaFree
                    1.14%  71.470ms      8200  8.7150us  5.0050us  493.74us  cudaLaunchKernel
                    0.86%  53.921ms        34  1.5859ms  7.1320us  16.178ms  cudaMemcpyAsync
                    0.26%  16.008ms     39525     405ns     268ns  499.95us  cudaGetDevice
                    0.21%  13.376ms       400  33.440us  5.1250us  92.660us  cudaDeviceSynchronize
                    0.10%  6.2973ms        20  314.87us  3.9500us  762.77us  cudaMalloc
                    0.09%  5.5990ms       604  9.2690us  6.0650us  87.593us  cudaMemsetAsync
                    0.05%  3.1450ms      1200  2.6200us  1.2400us  239.89us  cudaEventRecord
                    0.04%  2.6237ms       297  8.8340us     137ns  421.69us  cuDeviceGetAttribute
                    0.04%  2.4903ms         3  830.10us  798.66us  877.06us  cudaGetDeviceProperties
                    0.03%  2.0540ms         3  684.65us  664.22us  698.85us  cuDeviceTotalMem
                    0.03%  2.0162ms      1200  1.6800us     432ns  22.153us  cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags
                    0.03%  1.7191ms      8010     214ns     120ns  16.468us  cudaGetLastError
                    0.02%  1.4726ms        34  43.310us  3.5840us  100.27us  cudaStreamSynchronize
                    0.02%  1.1446ms         1  1.1446ms  1.1446ms  1.1446ms  cudaHostAlloc
                    0.02%  1.0711ms       200  5.3550us  4.5650us  17.565us  cudaEventQuery
                    0.01%  613.70us       800     767ns     308ns  3.3600us  cudaStreamGetCaptureInfo
                    0.01%  379.60us         4  94.898us  1.9380us  371.47us  cudaStreamCreateWithPriority
                    0.01%  315.70us         3  105.23us  101.60us  111.98us  cuDeviceGetName
                    0.00%  117.93us       203     580ns     174ns  1.1070us  cuDevicePrimaryCtxGetState
                    0.00%  52.471us         8  6.5580us  1.9580us  35.341us  cudaStreamCreateWithFlags
                    0.00%  49.092us        48  1.0220us     440ns  12.521us  cudaEventCreateWithFlags
                    0.00%  46.053us        48     959ns     603ns  3.6960us  cudaFuncSetAttribute
                    0.00%  17.310us        32     540ns     277ns  3.3390us  cudaDeviceGetAttribute
                    0.00%  5.4580us         1  5.4580us  5.4580us  5.4580us  cuDeviceGetPCIBusId
                    0.00%  5.3510us         2  2.6750us  2.3960us  2.9550us  cuInit
                    0.00%  3.2920us         5     658ns     172ns  2.1900us  cuDeviceGetCount
                    0.00%  3.0950us         1  3.0950us  3.0950us  3.0950us  cudaGetSymbolAddress
                    0.00%  3.0090us         4     752ns     214ns  2.2100us  cudaGetDeviceCount
                    0.00%  2.0980us         1  2.0980us  2.0980us  2.0980us  cudaDriverGetVersion
                    0.00%  2.0660us         1  2.0660us  2.0660us  2.0660us  cudaHostGetDevicePointer
                    0.00%  1.4250us         1  1.4250us  1.4250us  1.4250us  cudaDeviceGetStreamPriorityRange
                    0.00%  1.2240us         4     306ns     140ns     615ns  cuDeviceGet
                    0.00%     832ns         3     277ns     243ns     310ns  cuDeviceGetUuid
                    0.00%     715ns         2     357ns     282ns     433ns  cuDriverGetVersion

malfet commented 2 years ago

Interestingly enough, this benchmark is much faster using CUDA-10.2 (and the discrepancy is much smaller, because cudaStreamIsCapturing is a CUDA-11.x+ only API call):

$ nvprof python runx.py 
==5631== NVPROF is profiling process 5631, command: python runx.py
Eager Latency: 1.2893634999999999 ms
sklearn Eager latency: 1.362261 ms
speedup: 0.9464878609899278 
==5631== Profiling application: python runx.py
==5631== Profiling result:
            Type  Time(%)      Time     Calls       Avg       Min       Max  Name
 GPU activities:   30.20%  72.399ms       600  120.66us  18.368us  274.33us  volta_fp16_s884gemm_fp16_128x64_ldg8_f2f_tn
                   21.22%  50.865ms        36  1.4129ms  1.5360us  15.958ms  [CUDA memcpy HtoD]
                   10.96%  26.274ms       200  131.37us  129.98us  132.64us  volta_fp16_scudnn_fp16_128x64_relu_xregs_large_nn_v1
                    8.38%  20.080ms       400  50.199us  48.543us  52.703us  Volta_hmma_implicit_gemm_fprop_fp32_nhwc_64x32x64x1_1x3x3x0x1
                    6.52%  15.640ms       200  78.198us  77.280us  79.264us  Volta_hmma_implicit_gemm_fprop_fp32_nhwc_128x128x64x1_1x3x3x0x1
                    5.99%  14.355ms       200  71.776us  70.687us  73.023us  Volta_hmma_implicit_gemm_fprop_fp32_nhwc_256x64x64x1_1x5x5x0x1
                    4.38%  10.512ms      1600  6.5690us  2.6880us  13.824us  void nchwToNhwcKernel<__half, __half, float, bool=1, bool=0>(int, int, int, int, __half const *, __half*, float, float)
                    2.87%  6.8697ms      1000  6.8690us  4.3520us  12.064us  _ZN2at6native18elementwise_kernelILi128ELi4EZNS0_15gpu_kernel_implINS0_13BinaryFunctorIN3c104HalfES5_S5_NS0_10AddFunctorIfEEEEEEvRNS_18TensorIteratorBaseERKT_EUliE_EEviT1_
                    2.28%  5.4710ms       600  9.1180us  3.9680us  14.240us  void at::native::_GLOBAL__N__63_tmpxft_000043ff_00000000_16_DilatedMaxPool2d_compute_75_cpp1_ii_6258b574::max_pool_forward_nchw<c10::Half, c10::Half>(int, c10::Half const *, int, int, int, int, int, int, int, int, int, int, int, int, int, at::native::_GLOBAL__N__63_tmpxft_000043ff_00000000_16_DilatedMaxPool2d_compute_75_cpp1_ii_6258b574::max_pool_forward_nchw<c10::Half, c10::Half>*, long*)
                    2.08%  4.9898ms      1400  3.5640us  1.7600us  11.616us  _ZN2at6native29vectorized_elementwise_kernelILi4EZZZNS0_84_GLOBAL__N__60_tmpxft_0000507a_00000000_16_TensorCompare_compute_75_cpp1_ii_d0af11f719launch_clamp_scalarERNS_18TensorIteratorBaseEN3c106ScalarES6_NS0_6detail11ClampLimitsEENKUlvE_clEvENKUlvE14_clEvEUlNS5_4HalfEE_NS_6detail5ArrayIPcLi2EEEEEviT0_T1_
                    1.47%  3.5253ms       800  4.4060us  3.0400us  9.1840us  void nhwcToNchwKernel<__half, __half, float, bool=1, bool=0>(int, int, int, int, __half const *, __half*, float, float)
                    1.40%  3.3613ms       200  16.806us  16.479us  17.472us  void at::native::_GLOBAL__N__69_tmpxft_00003fb1_00000000_16_AdaptiveAveragePooling_compute_75_cpp1_ii_ef17039c::adaptive_average_pool<c10::Half>(c10::Half*, c10::Half, int, int, int, int, long, long, long)
                    1.00%  2.4039ms       600  4.0060us  2.7830us  12.800us  void splitKreduce_kernel<__half, __half, float>(cublasSplitKParams<float>, __half const *, __half const *, __half*, float const *, float const *)
                    0.96%  2.2979ms       600  3.8290us  3.1030us  8.6080us  _ZN2at6native18elementwise_kernelILi128ELi4EZNS0_15gpu_kernel_implIZZZNS0_75_GLOBAL__N__51_tmpxft_00004315_00000000_16_Copy_compute_75_cpp1_ii_ddc55cb923direct_copy_kernel_cudaERNS_18TensorIteratorBaseEENKUlvE0_clEvENKUlvE18_clEvEUlN3c104HalfEE_EEvS5_RKT_EUliE_EEviT1_
                    0.15%  361.76us       200  1.8080us  1.4720us  2.1120us  cudnn::gemm::computeOffsetsKernel(cudnn::gemm::ComputeOffsetsParams)
                    0.14%  327.17us       204  1.6030us  1.5680us  2.2720us  [CUDA memset]
      API calls:   87.89%  4.59159s        22  208.71ms  4.3320us  4.58564s  cudaMalloc
                    8.27%  432.16ms         7  61.737ms     827ns  262.46ms  cudaFree
                    1.29%  67.310ms      8600  7.8260us  4.4680us  497.92us  cudaLaunchKernel
                    1.03%  53.634ms        34  1.5775ms  7.8300us  16.102ms  cudaMemcpyAsync
                    0.76%  39.844ms       400  99.610us  7.2250us  217.62us  cudaDeviceSynchronize
                    0.27%  14.231ms     39109     363ns     269ns  16.881us  cudaGetDevice
                    0.10%  5.3753ms       285  18.860us     135ns  2.1885ms  cuDeviceGetAttribute
                    0.07%  3.4358ms      1600  2.1470us  1.0630us  82.579us  cudaEventRecord
                    0.04%  2.3343ms         1  2.3343ms  2.3343ms  2.3343ms  cudaHostAlloc
                    0.04%  2.1773ms       204  10.672us  7.9260us  49.996us  cudaMemsetAsync
                    0.04%  2.1516ms         3  717.20us  667.03us  812.66us  cuDeviceTotalMem
                    0.04%  1.9088ms      9410     202ns     116ns  12.480us  cudaGetLastError
                    0.03%  1.7035ms         4  425.88us  4.8480us  1.6820ms  cudaStreamCreateWithPriority
                    0.03%  1.6660ms       600  2.7760us  1.2370us  18.910us  cudaEventQuery
                    0.03%  1.6257ms         2  812.87us  807.85us  817.89us  cudaGetDeviceProperties
                    0.03%  1.4117ms        34  41.520us  8.0820us  101.20us  cudaStreamSynchronize
                    0.02%  940.24us       600  1.5670us     836ns  83.960us  cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags
                    0.01%  305.12us         3  101.71us  93.610us  116.02us  cuDeviceGetName
                    0.00%  92.978us       180     516ns     355ns  2.2030us  cudaFuncSetAttribute
                    0.00%  60.980us         8  7.6220us  1.6930us  44.213us  cudaStreamCreateWithFlags
                    0.00%  33.055us        48     688ns     425ns  2.1170us  cudaEventCreateWithFlags
                    0.00%  30.779us         2  15.389us  14.331us  16.448us  cudaMemcpy
                    0.00%  20.600us        40     515ns     304ns  3.3830us  cudaDeviceGetAttribute
                    0.00%  5.2590us         1  5.2590us  5.2590us  5.2590us  cudaHostGetDevicePointer
                    0.00%  5.2450us         2  2.6220us  2.2550us  2.9900us  cuInit
                    0.00%  3.5710us         1  3.5710us  3.5710us  3.5710us  cuDeviceGetPCIBusId
                    0.00%  2.9290us         4     732ns     177ns  2.2950us  cudaGetDeviceCount
                    0.00%  2.2850us         5     457ns     171ns  1.1220us  cuDeviceGetCount
                    0.00%  1.4220us         3     474ns     297ns     804ns  cuDeviceGetUuid
                    0.00%  1.3440us         4     336ns     152ns     564ns  cuDeviceGet
                    0.00%     925ns         2     462ns     429ns     496ns  cuDriverGetVersion

ngimel commented 2 years ago

cudaStreamIsCapturing is probably called during memory allocations that don't hit in the cache. We shouldn't be benchmarking runs that have a lot of uncached mallocs -- that's probably why removing gc.collect() helps. In the steady state, we expect memory requests to be mostly served from cached memory rather than calling cudaMalloc.
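
(One way to sanity-check this, as a sketch: after a few warm-up iterations, torch.cuda.memory_reserved() should stop growing, meaning further allocations are served from the caching allocator rather than fresh cudaMalloc calls; torchvision's alexnet is used as a stand-in model:)

import torch
import torchvision.models as models

model = models.alexnet().cuda().eval()
x = torch.randn(8, 3, 224, 224, device="cuda")

with torch.no_grad():
    for i in range(10):
        model(x)
        torch.cuda.synchronize()
        # memory_reserved() is the total held by the caching allocator; once it
        # plateaus, steady-state iterations no longer trigger cudaMalloc.
        print(i, torch.cuda.memory_allocated(), torch.cuda.memory_reserved())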

malfet commented 2 years ago

cudaStreamIsCapturing is probably called during memory allocations that don't hit in the cache. We shouldn't be benchmarking runs that have a lot of uncached mallocs -- that's probably why removing gc.collect() helps. In the steady state, we expect memory requests to be mostly served from cached memory rather than calling cudaMalloc.

That would be true if the number of cudaMalloc() calls matched the number of cudaStreamIsCapturing calls, but they differ by 10x: cudaMalloc is called 20 times, while cudaStreamIsCapturing is called 200+ times (i.e., at least once for every model(*inputs) call).

xuzhao9 commented 2 years ago

The previous bug has been fixed by https://github.com/facebookresearch/torchdynamo/pull/153. Yet another problem shows up; here is the reproducing script:

import torch
import time
import gc
import copy
import sys
import numpy as np
from torchbenchmark import load_model_by_name
import argparse

DEVICE = "cuda"

def synchronize():
    if DEVICE == "cuda":
        torch.cuda.synchronize()

def timed(model, example_inputs, times=1, dynamo=None):
    torch.manual_seed(1337)
    if True:
        import torchdynamo
        optimize_ctx = torchdynamo.optimize("eager")
        with optimize_ctx:
            pass
        # print(sys.modules.keys())
    synchronize()
    t0 = time.time_ns()
    # Don't collect outputs to correctly measure timing
    if dynamo == "d1":
        with torchdynamo.optimize("eager"):
            result = model(*example_inputs)
    else:
        with torchdynamo.run():
            result = model(*example_inputs)
    synchronize()
    t1 = time.time_ns()
    return (t1 - t0) / 1_000_000

def speedup_experiment(model, example_inputs, model2, example_inputs2):
    repeat = 100
    timings = np.zeros((repeat, 2), np.float64)
    for rep in range(repeat):
        timings[rep, 0] = timed(model, example_inputs)
    for rep in range(repeat):
        timings[rep, 1] = timed(model2, example_inputs2, dynamo="d1")
    median = np.median(timings, axis=0)
    print(f"Eager Latency: {median[0]} ms")
    print(f"TorchDynamo Eager latency: {median[1]} ms")
    print(f"speedup: {median[0]/median[1]} ")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", choices=["cpu", "cuda"], default="cuda", help="specify device")
    args = parser.parse_args()
    DEVICE = args.device
    Model = load_model_by_name("alexnet")
    m = Model(device=DEVICE, test="eval", jit=False)
    m2 = Model(device=DEVICE, test="eval", jit=False)
    model, example_inputs = m.get_module()
    model2, example_inputs2 = m2.get_module()
    speedup_experiment(model, example_inputs, model2, example_inputs2)

Running python runx.py returns:

Eager Latency: 1.0757485 ms
TorchDynamo Eager latency: 1.183398 ms
speedup: 0.9090335626729131

So torchdynamo.optimize("eager") is 10% slower than torchdynamo.run(). Update: the slowdown seems to be caused by https://github.com/facebookresearch/torchdynamo/blob/1e131a6e81cf2253b3672caa1f7a72629a472989/torchdynamo/mutation_guard.py#L88

anijain2305 commented 1 year ago

Closing as stale.