pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

ExecuteTime metric not accurate #4306

Open shunting314 opened 1 year ago

shunting314 commented 1 year ago

🐛 Bug

In https://github.com/pytorch/pytorch/pull/88449, we updated the timed method as follows to report XLA metrics:

def timed(model, model_iter_fn, example_inputs, times=1, return_result=False, prompt=""):
    use_xla = tensor_is_on_xla(example_inputs)
    synchronize()

    reset_rng_state()
    if use_xla:
        import torch_xla.core.xla_model as xm

        xm.mark_step()
        xm.wait_device_ops()

    import torch_xla.debug.metrics as met; met.clear_metrics(); met.clear_counters()
    t0 = time.perf_counter()
    # Don't collect outputs, to correctly measure timing
    for _ in range(times):
        if use_xla:
            # force a const seed for xla
            xm.set_rng_state(23, str(xm.xla_device()))
        result = model_iter_fn(model, example_inputs, collect_outputs=return_result)

        # instead of calling sync on result_list, we should call mark_step.
        # In training case, result_list may be empty, but we want to sync
        # all the pending computations.
        if use_xla:
            # If the model is on XLA device, it's possible that after running
            # the model, the computation is accumulated but not performed yet.
            # Flush all the accumulated computations to make the time measurement
            # accurate.
            import torch_xla.core.xla_model as xm
            xm.mark_step()

    if use_xla:
        import torch_xla.core.xla_model as xm
        xm.wait_device_ops()
    synchronize()
    t1 = time.perf_counter()
    print(f"{prompt} {t1 - t0} sec, metric report:\n{met.short_metrics_report()}")
    return (t1 - t0, result) if return_result else t1 - t0

We then run inference for the resnet18 model as follows:

USE_FAKE_TENSOR=0 GPU_NUM_DEVICES=1 python benchmarks/dynamo/torchbench.py --randomize-input --performance --trace-on-xla --backend=torchxla_trace_once --only resnet18 -n 5

The output is a bit counterintuitive:

rep 4 test 0.003318594768643379 sec, metric report:
Metric: ExecuteTime
  TotalSamples: 1
  Accumulator: 003ms699.568us
  Percentiles: 1%=003ms699.568us; 5%=003ms699.568us; 10%=003ms699.568us; 20%=003ms699.568us; 50%=003ms699.568us; 80%=003ms699.568us; 90%=003ms699.568us; 95%=003ms699.568us; 99%=003ms699.568us
Counter: MarkStep
  Value: 2

The ExecuteTime metric is larger than the wall time. As long as there is no parallel execution on the device, this should not happen.
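
To make the inconsistency concrete:

# Sanity arithmetic on the report above, using the logged values. With a single
# synchronous execution per step, the device time should never exceed the wall time.
wall_time_s = 0.003318594768643379   # "rep 4 test ... sec"
execute_time_s = 3.699568e-3         # Accumulator: 003ms699.568us
print(execute_time_s - wall_time_s)  # ~0.00038 s of device time beyond the wall time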

cc @JackCaoG @wconstab

JackCaoG commented 1 year ago

Let me try to repro on TPU and do a profile.

JackCaoG commented 1 year ago

With resnet50 I am able to see that the dynamo wall time is slightly smaller than the actual execute time, which is a bit weird.

# lazy
Counter: CachedCompile
  Value: 1
Metric: ExecuteTime
  TotalSamples: 1
  Accumulator: 002ms225.529us
  Percentiles: 1%=002ms225.529us; 5%=002ms225.529us; 10%=002ms225.529us; 20%=002ms225.529us; 50%=002ms225.529us; 80%=002ms225.529us; 90%=002ms225.529us; 95%=002ms225.529us; 99%=002ms225.529us
Counter: MarkStep
  Value: 2

# dynamo
Metric: ExecuteTime
  TotalSamples: 1
  Accumulator: 002ms725.570us
  Percentiles: 1%=002ms725.570us; 5%=002ms725.570us; 10%=002ms725.570us; 20%=002ms725.570us; 50%=002ms725.570us; 80%=002ms725.570us; 90%=002ms725.570us; 95%=002ms725.570us; 99%=002ms725.570us
Counter: MarkStep
  Value: 4

[[0.00780536 0.00940852]
 [0.00711755 0.00283046]
 [0.00748225 0.0025849 ]
 [0.00743424 0.00274822]
 [0.00720488 0.00266087]]
JackCaoG commented 1 year ago

Going back to resnet18, it is a bit weird since lazy seems to execute 2 graphs while dynamo executes one.

Counter: CachedCompile
  Value: 1
Metric: ExecuteTime
  TotalSamples: 2
  Accumulator: 001ms214.719us
  ValueRate: 527ms157.953us / second
  Rate: 867.95 / second
  Percentiles: 1%=464.770us; 5%=464.770us; 10%=464.770us; 20%=464.770us; 50%=749.949us; 80%=749.949us; 90%=749.949us; 95%=749.949us; 99%=749.949us
Counter: MarkStep
  Value: 2

Metric: ExecuteTime
  TotalSamples: 1
  Accumulator: 497.271us
  Percentiles: 1%=497.271us; 5%=497.271us; 10%=497.271us; 20%=497.271us; 50%=497.271us; 80%=497.271us; 90%=497.271us; 95%=497.271us; 99%=497.271us
Counter: MarkStep
  Value: 4

[[0.00359686 0.00144872]
 [0.00296736 0.00089427]
 [0.00231279 0.00081857]
 [0.00264634 0.00083998]
 [0.00223804 0.00089804]]

I think the issue is that maybe_mark_step does not wait for the device ops. After

@@ -404,12 +407,13 @@ def maybe_mark_step(args):
     if args.trace_on_xla:
         import torch_xla.core.xla_model as xm
         xm.mark_step()
+        xm.wait_device_ops()

I started to see

Counter: CachedCompile
  Value: 1
Metric: ExecuteTime
  TotalSamples: 1
  Accumulator: 002ms373.590us
  Percentiles: 1%=002ms373.590us; 5%=002ms373.590us; 10%=002ms373.590us; 20%=002ms373.590us; 50%=002ms373.590us; 80%=002ms373.590us; 90%=002ms373.590us; 95%=002ms373.590us; 99%=002ms373.590us
Counter: MarkStep
  Value: 2

Metric: ExecuteTime
  TotalSamples: 1
  Accumulator: 002ms723.050us
  Percentiles: 1%=002ms723.050us; 5%=002ms723.050us; 10%=002ms723.050us; 20%=002ms723.050us; 50%=002ms723.050us; 80%=002ms723.050us; 90%=002ms723.050us; 95%=002ms723.050us; 99%=002ms723.050us
Counter: MarkStep
  Value: 4

[[0.04986051 0.00305879]
 [0.04829975 0.00251851]
 [0.04834096 0.002427  ]
 [0.04822584 0.00263545]
 [0.04905208 0.00265107]]

where ExecuteTime matches between lazy and dynamo.
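
For reference, the patched helper roughly looks like this (a sketch reconstructed from the diff above; the real function in the benchmark script may have more branches):

import torch_xla.core.xla_model as xm

def maybe_mark_step(args):
    if args.trace_on_xla:
        xm.mark_step()        # flush the pending lazy graph
        xm.wait_device_ops()  # the added call: block until the device actually finishes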

JackCaoG commented 1 year ago

However, this change made the lazy time significantly longer... looking into this.

Edit: I had enabled the IR dump, which is why lazy was much slower. Disabling that, I saw

Counter: CachedCompile
  Value: 1
Metric: ExecuteTime
  TotalSamples: 1
  Accumulator: 001ms389.420us
  Percentiles: 1%=001ms389.420us; 5%=001ms389.420us; 10%=001ms389.420us; 20%=001ms389.420us; 50%=001ms389.420us; 80%=001ms389.420us; 90%=001ms389.420us; 95%=001ms389.420us; 99%=001ms389.420us
Counter: MarkStep
  Value: 2

Metric: ExecuteTime
  TotalSamples: 1
  Accumulator: 642.530us
  Percentiles: 1%=642.530us; 5%=642.530us; 10%=642.530us; 20%=642.530us; 50%=642.530us; 80%=642.530us; 90%=642.530us; 95%=642.530us; 99%=642.530us
Counter: MarkStep
  Value: 4

[[0.00393327 0.00176672]
 [0.00304804 0.00114577]
 [0.00317163 0.00097577]
 [0.0030343  0.00101768]
 [0.00325616 0.00114755]]
2.768x p=0.00

which is expected.

shunting314 commented 1 year ago

@JackCaoG we already have a device sync inside the timed function:

def timed(model, model_iter_fn, example_inputs, times=1, return_result=False):
    use_xla = tensor_is_on_xla(example_inputs)
    synchronize()

    reset_rng_state()
    if use_xla:
        import torch_xla.core.xla_model as xm

        xm.mark_step()
        xm.wait_device_ops()
JackCaoG commented 1 year ago

"if use_xla"

Ah OK, I guess this has to do with the way I print the metrics:

             # need call mark_step to perform the computation
             # on randomize_input. Otherwise the first call using the
             # inputs will incur high penalty then the next one.
             maybe_mark_step(args)

             # interleave the runs to handle frequency scaling and load changes
+            met.clear_all()
             with maybe_mark_profile(p=p, mark="expected"):
                 timings[rep, 0], expected_output = timed(
                     model, model_iter_fn, inputs, return_result=True, times=times,
                 )
-
+            print(met.short_metrics_report())
+            met.clear_all()
             # call mark_step between the 2 calls to make the comparison fair.
             maybe_mark_step(args)

@@ -478,9 +486,13 @@ def speedup_experiment(args, model_iter_fn, model, example_inputs, **kwargs):
                 timings[rep, 1], actual_output = timed(
                     model, frozen_model_iter_fn, inputs, return_result=True, times=times,
                 )
-
+            print(met.short_metrics_report())

I print it outside of timed, hence I saw that confusing execution time.

JackCaoG commented 1 year ago

OK, the next question I need to look into is that dynamo and lazy seem to execute different graphs. For resnet I saw

# lazy
Counter: CachedCompile
  Value: 1
Metric: ExecuteTime
  TotalSamples: 1
  Accumulator: 001ms358.530us
  Percentiles: 1%=001ms358.530us; 5%=001ms358.530us; 10%=001ms358.530us; 20%=001ms358.530us; 50%=001ms358.530us; 80%=001ms358.530us; 90%=001ms358.530us; 95%=001ms358.530us; 99%=001ms358.530us
Counter: MarkStep
  Value: 2

# dynamo
Metric: ExecuteTime
  TotalSamples: 1
  Accumulator: 490.160us
  Percentiles: 1%=490.160us; 5%=490.160us; 10%=490.160us; 20%=490.160us; 50%=490.160us; 80%=490.160us; 90%=490.160us; 95%=490.160us; 99%=490.160us
Counter: MarkStep
  Value: 4
shunting314 commented 1 year ago

@JackCaoG that's weird. For resnet18 inference on GPU, I see ExecuteTime for baseline vs. test (trace_once) to be 003ms663.747us vs. 003ms669.788us, which makes sense.

JackCaoG commented 1 year ago

I was able to dump the profile (profiler screenshot attached).

Here are my findings:

  1. Lazy tracing does take much longer than the actual device execution.
  2. The actual device execution times of lazy and dynamo are very similar.
  3. The host-side ExecuteTime metrics are very different.

@will-cromar I thought PjRtComputationClient::ExecuteComputation would not return until the actual device execution had happened, but that doesn't seem to be the case...

JackCaoG commented 1 year ago

I tried to make lazy run 3 times and then dynamo run 3 times (profiler screenshot attached).

It does seem like PjRtComputationClient::ExecuteComputation returns before the actual device execution has finished. This is somewhat confusing because ExecuteComputation is supposed to return the buffers representing the result of the computation. My only theory is that PJRT has some optimization that makes the device execution async as well.

However, this makes timing the actual execution pretty difficult. wait_device_ops is only a pytorch/xla host-level trick that tries to grab the device lock. Once ExecuteComputation finishes, it releases the device lock and unblocks wait_device_ops. This makes our benchmark inaccurate.
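
A rough mental model of why wait_device_ops can return early under PJRT (a Python sketch with stand-in names, not the actual torch_xla implementation):

import threading
import time

device_lock = threading.Lock()

def _device_execution():
    time.sleep(0.01)                # stand-in for the real on-device work

def execute_computation():
    with device_lock:               # lock held only while the work is being launched
        worker = threading.Thread(target=_device_execution)
        worker.start()              # PJRT-style: the launch returns immediately
    return worker                   # the lock is already released here

def wait_device_ops():
    with device_lock:               # unblocks as soon as the launch returned,
        pass                        # not when the device work actually finished

worker = execute_computation()
wait_device_ops()                   # returns right away
print(worker.is_alive())            # likely True: the "execution" is still running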

JackCaoG commented 1 year ago

OK, I think this is some kind of runtime optimization, but I can force a wait on the device op by moving the result tensor from the device to the host.

     if use_xla:
         import torch_xla.core.xla_model as xm
-        xm.wait_device_ops()
+        with xp.Trace("lazy_wait_device_ops", step_num=ttt):
+            xm.wait_device_ops()
+            cpu_result = result.cpu()

(profiler screenshot attached)

This is not ideal since we introduce an additional device-to-host transfer that cannot be optimized away, but at least we can get some correct results.

resnet50
[[0.01062124 0.00613502]
 [0.01002307 0.00627029]
 [0.01055949 0.00611338]
 [0.01015586 0.00596699]
 [0.00977452 0.00627681]
 [0.01043    0.00635236]
 [0.01001907 0.00590376]
 [0.01018825 0.00629277]
 [0.01028329 0.00624821]
 [0.01001173 0.0062656 ]]
1.626x p=0.00
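
For reference, the headline speedup can be recovered from the two timing columns as a ratio of medians (an assumption about what the harness computes; the p-value comes from a separate significance test):

import numpy as np

# Columns are [lazy, dynamo] wall times per rep, copied from the table above.
timings = np.array([
    [0.01062124, 0.00613502],
    [0.01002307, 0.00627029],
    [0.01055949, 0.00611338],
    [0.01015586, 0.00596699],
    [0.00977452, 0.00627681],
    [0.01043,    0.00635236],
    [0.01001907, 0.00590376],
    [0.01018825, 0.00629277],
    [0.01028329, 0.00624821],
    [0.01001173, 0.0062656],
])
print(np.median(timings[:, 0]) / np.median(timings[:, 1]))  # ~1.626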
JackCaoG commented 1 year ago

My full diff (a bit messy) is in https://gist.github.com/JackCaoG/49c0e0407f7a149763e11ad71285c3dc

You can view my profile with:

pip3 uninstall tb-nightly tensorboard tensorflow-estimator tf-estimator-nightly tf-nightly
pip3 install tf-nightly==2.12.0.dev20221029 tbp-nightly==2.9.0a20221029
mkdir profile
cd profile
gsutil cp -r gs://tpu-pytorch/tmp/dynamo_profile/dynamo_tracing/try12/* .
tensorboard --logdir . --port 6016
JackCaoG commented 1 year ago

@will-cromar is working on fixing the ExecuteTime metric. I will do a manual patch to rerun the benchmark on dynamo + inference in the meantime.

shunting314 commented 1 year ago

@JackCaoG does the inaccuracy of the ExecuteTime metric affect our overall wall time measurement for baseline and test? Basically, I mean whether the device sync works as expected.

JackCaoG commented 1 year ago

The issue here is that wait_device_ops does not really wait for the actual device execution to finish. However, I think this is only an issue for PJRT, but you are using XRT (since you used TPU_NUM_DEVICE=1). Maybe I can benchmark XRT tomorrow and see if the above constraint also applies.

To be sure that we are measuring the correct wall time, @will-cromar can you share your patch that makes ExecuteTime blocking for PJRT? @shunting314 you can then benchmark using PJRT:GPU (no code change required, only a config change) by setting PJRT_DEVICE=GPU.

shunting314 commented 1 year ago

"since you used TPU_NUM_DEVICE=1"

To be clear, I'm using GPU_NUM_DEVICES=1.

"then you can benchmark using the PJRT:GPU"

Does this guarantee accurate wall time measurement? Can I do that now, or should I wait for will-cromar's patch?

JackCaoG commented 1 year ago

The safest ways for now are:

  1. use PJRT + Will's patch
  2. move one of the result tensors back to the CPU device, which will force the code to wait for the actual device execution

The second option will introduce some noise, since moving to CPU is not something dynamo can optimize away; a minimal sketch of it follows.
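
A minimal sketch of the second option (sync_for_timing is a hypothetical helper name; it mirrors the result.cpu() workaround shown earlier in the thread):

import torch_xla.core.xla_model as xm

def sync_for_timing(result):
    # Flush pending lazy work, then force a real wait by pulling one result
    # tensor back to the host; the copy cannot finish before the execution does.
    xm.mark_step()
    xm.wait_device_ops()
    return result.cpu()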

JackCaoG commented 1 year ago

With

diff --git a/third_party/xla_client/pjrt_computation_client.cc b/third_party/xla_client/pjrt_computation_client.cc
index 207c8874..fa847c0d 100755
--- a/third_party/xla_client/pjrt_computation_client.cc
+++ b/third_party/xla_client/pjrt_computation_client.cc
@@ -308,6 +308,7 @@ PjRtComputationClient::ExecuteComputation(
   std::vector<DataPtr> datas;
   datas.reserve(results.size());
   for (auto& result : results) {
+    auto status = result->GetReadyFuture().Await();
     std::unique_ptr<xla::PjRtBuffer> buffer = std::move(result);

     std::shared_ptr<PjRtData> data = std::make_shared<PjRtData>(

I am able to make ExecuteComputation blocking and observe

(profiler screenshot attached)

ExecuteTime is also shorter than the overall wall time, which is expected:

Counter: CachedCompile
  Value: 1
Metric: ExecuteTime
  TotalSamples: 1
  Accumulator: 006ms704.420us
  Percentiles: 1%=006ms704.420us; 5%=006ms704.420us; 10%=006ms704.420us; 20%=006ms704.420us; 50%=006ms704.420us; 80%=006ms704.420us; 90%=006ms704.420us; 95%=006ms704.420us; 99%=006ms704.420us
Counter: MarkStep
  Value: 2

Metric: ExecuteTime
  TotalSamples: 1
  Accumulator: 005ms359.789us
  Percentiles: 1%=005ms359.789us; 5%=005ms359.789us; 10%=005ms359.789us; 20%=005ms359.789us; 50%=005ms359.789us; 80%=005ms359.789us; 90%=005ms359.789us; 95%=005ms359.789us; 99%=005ms359.789us
Counter: MarkStep
  Value: 4

[[0.01109302 0.00659533]
 [0.01055866 0.00632675]
 [0.01007744 0.00603052]
 [0.01111902 0.00623013]
 [0.01043393 0.00610085]
 [0.0101979  0.00594579]
 [0.01035866 0.00624219]
 [0.01081218 0.00677581]
 [0.01064773 0.00653071]
 [0.01023987 0.00623919]]
1.682x p=0.00

I will leave this issue open for Will to fix it properly, but the workaround is enough for me to finish my inference benchmark.

JackCaoG commented 1 year ago

OK, weird: sometimes I still see that ExecuteTime is longer than the wall time. For example:

Counter: CachedCompile
  Value: 1
Metric: ExecuteTime
  TotalSamples: 1
  Accumulator: 003ms757.569us
  Percentiles: 1%=003ms757.569us; 5%=003ms757.569us; 10%=003ms757.569us; 20%=003ms757.569us; 50%=003ms757.569us; 80%=003ms757.569us; 90%=003ms757.569us; 95%=003ms757.569us; 99%=003ms757.569us
Counter: MarkStep
  Value: 2

Metric: ExecuteTime
  TotalSamples: 1
  Accumulator: 003ms600.400us
  Percentiles: 1%=003ms600.400us; 5%=003ms600.400us; 10%=003ms600.400us; 20%=003ms600.400us; 50%=003ms600.400us; 80%=003ms600.400us; 90%=003ms600.400us; 95%=003ms600.400us; 99%=003ms600.400us
Counter: MarkStep
  Value: 4

[0.00356732 0.00285901]

However, if I check the profile, wait_lazy_ops (which contributes to the wall time) is longer than ExecuteComputation.

(profiler screenshot attached)

My theory is that maybe:

  1. The C++ timer and the Python timer have some subtle differences.
  2. The ExecuteTime timer is recorded when the destructor of the local stack object is called. That might happen after ExecuteComputation returns, which unblocks the Python-side timer.

Either way, I think the Python wall time should be comparable between lazy and dynamo, based on the profiles we captured.

JackCaoG commented 1 year ago

I think this issue is partially fixed by https://github.com/pytorch/xla/commit/afe7bb6b6562edcd368055899e17c691824575ac, meaning ExecuteTime should be accurate when using PJRT.

However, wait_device_ops is still not accurate on PJRT. wait_device_ops is supposed to block the caller until all async operations are finished, which is useful when we want to measure how long an execution takes. wait_device_ops works by trying to acquire the device lock. In PJRT, the device lock is released earlier than the actual device execution finishes. I have 2 high-level ideas to fix this for PJRT (a rough sketch of the first follows this list):

  1. Introduce a new lock that is passed to the

    returned_future->OnReady([timed](Status unused) mutable { timed.reset(); });

    callback. wait_device_ops can then try to grab both the regular device lock and the new PJRT device lock.

  2. Somehow make wait_device_ops find a buffer that is not ready and wait until it gets ready. Given that there is at most 1 pending execution at any given time, finding any buffer that is not ready should be enough.
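
A rough, runnable Python model of idea 1 (all names here are stand-ins; the real change would live in pjrt_computation_client.cc):

import threading
import time
from concurrent.futures import ThreadPoolExecutor

device_lock = threading.Lock()      # existing lock, released once the launch returns
pjrt_done = threading.Event()       # new signal, set only when the device work finishes
pjrt_done.set()
_stream = ThreadPoolExecutor(max_workers=1)   # stands in for the async device stream

def execute_computation():
    pjrt_done.clear()
    with device_lock:
        future = _stream.submit(time.sleep, 0.01)  # stand-in for the device execution
    future.add_done_callback(lambda _: pjrt_done.set())  # OnReady-style completion hook

def wait_device_ops():
    with device_lock:               # existing behavior: only waits for the launch
        pass
    pjrt_done.wait()                # new: also wait for the device-side completion signal

execute_computation()
wait_device_ops()                   # returns only after the simulated execution is done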

@will-cromar any thoughts on this problem?

JackCaoG commented 1 year ago

Currently I am using

diff --git a/third_party/xla_client/pjrt_computation_client.cc b/third_party/xla_client/pjrt_computation_client.cc
index a51fbed9..4a242f11 100755
--- a/third_party/xla_client/pjrt_computation_client.cc
+++ b/third_party/xla_client/pjrt_computation_client.cc
@@ -329,7 +329,12 @@ PjRtComputationClient::ExecuteComputation(

   std::vector<DataPtr> datas;
   datas.reserve(results.size());
+  bool waited = false;
   for (auto& result : results) {
+    if (!waited) {
+      auto status = result->GetReadyFuture().Await();
+      waited = true;
+    }
     std::unique_ptr<xla::PjRtBuffer> buffer = std::move(result);

     std::shared_ptr<PjRtData> data = std::make_shared<PjRtData>(

to force the execution to sync the buffer for benchmarking purposes. Another option I can think of is:

  1. Let the PJRT client save one result buffer in a global state.
  2. Add an API to ComputationClient that waits until the saved buffer is ready.

This will make sure wait_device_ops waits until the buffer is actually ready.
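
A conceptual Python sketch of this option (all names are hypothetical; the real change would live in the C++ ComputationClient):

from concurrent.futures import Future
from typing import Optional

_saved_result: Optional[Future] = None   # global state: one buffer handle from the latest execution

def remember_result(future: Future) -> None:
    # Called by the client right after launching an execution.
    global _saved_result
    _saved_result = future

def wait_for_saved_buffer() -> None:
    # The proposed ComputationClient-style API: block until the saved buffer is ready.
    if _saved_result is not None:
        _saved_result.result()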