intel / intel-xpu-backend-for-triton

OpenAI Triton backend for Intel® GPUs
MIT License

No perf advantage for torch.compile on examples from pytorch tutorial #1721

Open dvrogozh opened 1 month ago

dvrogozh commented 1 month ago

I am trying the PyTorch tutorial for torch.compile(): https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html#demonstrating-speedups, adapting it for the xpu backend by replacing cuda with xpu (s/cuda/xpu). I am using https://github.com/pytorch/pytorch/commit/f063027d5424c6b90588ef0e84e9c21be4ce68ae. The tutorial has performance examples that demonstrate the advantage of torch.compile over eager mode on Nvidia GPUs. Unfortunately, I don't observe similar benefits on XPU: torch.compile runs at about the same speed as eager mode. Are any optimizations currently missing for XPU that affect these tutorials? This occurs for both examples in the tutorial, inference and training.

Results (inference):

eager eval time 0: 1.468490231
eager eval time 1: 0.016250838
eager eval time 2: 0.015404673
eager eval time 3: 0.01476964
eager eval time 4: 0.014657789
eager eval time 5: 0.014552059
eager eval time 6: 0.014473312
eager eval time 7: 0.014476375
eager eval time 8: 0.014540959
eager eval time 9: 0.014519486
~~~~~~~~~~
compile eval time 0: 30.085278137
compile eval time 1: 0.016572904
compile eval time 2: 0.015478853
compile eval time 3: 0.015368476
compile eval time 4: 0.015215709
compile eval time 5: 0.015356365
compile eval time 6: 0.015324649
compile eval time 7: 0.015410529
compile eval time 8: 0.015309956
compile eval time 9: 0.015434349
~~~~~~~~~~
Traceback (most recent call last):
  File "/home/dvrogozh/examples/torch/tutorials/ex5.py", line 63, in <module>
    assert(speedup > 1)
AssertionError
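For reference, the assertion fails because the median-based speedup is below 1. With 10 runs, np.median averages the 5th and 6th smallest times, so the long first (compilation) run does not skew it. Worked out from the timings above:

$$
\begin{aligned}
\tilde{t}_{\text{eager}} &= \frac{0.014552059 + 0.014657789}{2} = 0.014604924\ \text{s} \\
\tilde{t}_{\text{compile}} &= \frac{0.015368476 + 0.015410529}{2} = 0.0153895025\ \text{s} \\
\text{speedup} &= \frac{\tilde{t}_{\text{eager}}}{\tilde{t}_{\text{compile}}} \approx 0.949 < 1
\end{aligned}
$$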

Script (inference):

import time
import torch

def timed(fn):
    start = time.time_ns()
    result = fn()
    torch.xpu.synchronize()
    return result, (time.time_ns() - start) / 1000000000

# Generates random input and targets data for the model, where `b` is
# batch size.
def generate_data(b):
    return (
        torch.randn(b, 3, 128, 128).to(torch.float32).xpu(),
        torch.randint(1000, (b,)).xpu(),
    )

N_ITERS = 10

from torchvision.models import densenet121
def init_model():
    return densenet121().to(torch.float32).xpu()

model = init_model()

# Reset since we are using a different mode.
import torch._dynamo
torch._dynamo.reset()

model_opt = torch.compile(model, mode="reduce-overhead")

eager_times = []
for i in range(N_ITERS):
    inp = generate_data(16)[0]
    with torch.no_grad():
        _, eager_time = timed(lambda: model(inp))
    eager_times.append(eager_time)
    print(f"eager eval time {i}: {eager_time}")

print("~" * 10)

compile_times = []
for i in range(N_ITERS):
    inp = generate_data(16)[0]
    with torch.no_grad():
        _, compile_time = timed(lambda: model_opt(inp))
    compile_times.append(compile_time)
    print(f"compile eval time {i}: {compile_time}")
print("~" * 10)

import numpy as np
eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
assert(speedup > 1)
print(f"(eval) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x")
print("~" * 10)

Note that I changed the tutorial's timed() implementation to measure end-to-end time (wall clock around the call, followed by torch.xpu.synchronize()) because of https://github.com/pytorch/pytorch/issues/131840. I also tried applying https://github.com/pytorch/pytorch/pull/126456; it did not change the performance results for the XPU backend.
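A possible variant of the end-to-end timing helper, offered only as a sketch: it also synchronizes before starting the clock (so queued work from earlier steps is not billed to the measured call), uses time.perf_counter(), a monotonic clock intended for interval timing, and computes the median speedup after dropping warm-up runs, where compilation happens. The names default_sync, timed and median_speedup, and the default of skipping one warm-up run, are assumptions for illustration; they are not from the tutorial or the original script. torch is imported lazily so the helpers also run without PyTorch installed.

```python
import statistics
import time
from typing import Callable, Sequence, Tuple, TypeVar

T = TypeVar("T")


def default_sync() -> None:
    """Wait for queued XPU work if torch.xpu is available; otherwise do nothing."""
    try:
        import torch  # lazy import: the helpers stay usable without PyTorch
    except ImportError:
        return
    xpu = getattr(torch, "xpu", None)
    if xpu is not None and xpu.is_available():
        xpu.synchronize()


def timed(fn: Callable[[], T], sync: Callable[[], None] = default_sync) -> Tuple[T, float]:
    """Return fn()'s result and its end-to-end wall-clock time in seconds."""
    sync()  # drain previously queued device work so it is not charged to fn
    start = time.perf_counter()
    result = fn()
    sync()  # wait for the kernels launched by fn to finish
    return result, time.perf_counter() - start


def median_speedup(eager: Sequence[float], compiled: Sequence[float], warmup: int = 1) -> float:
    """Median eager time divided by median compiled time, skipping warm-up runs."""
    if warmup >= len(eager) or warmup >= len(compiled):
        raise ValueError("need at least one timed run after the warm-up runs")
    return statistics.median(eager[warmup:]) / statistics.median(compiled[warmup:])
```

With PyTorch available, `_, t = timed(lambda: model_opt(inp))` would time one compiled call, and `median_speedup(eager_times, compile_times)` would replace the NumPy-based calculation. Because the median is already robust to a single slow run, skipping the warm-up run mainly makes the intent explicit; with the timings posted above it still gives a speedup below 1.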

alexbaden commented 1 month ago

I am not getting the same results (latest llvm-target branch, LTS driver, and https://github.com/pytorch/pytorch/commit/75f64e12030dfa6f621f1ec2b207892cf8660cdd):

» python ex5.py
eager eval time 0: 1.645833594
eager eval time 1: 0.133093097
eager eval time 2: 0.133950921
eager eval time 3: 0.144729233
eager eval time 4: 0.129245809
eager eval time 5: 0.129086332
eager eval time 6: 0.123269756
eager eval time 7: 0.134237029
eager eval time 8: 0.12613883
eager eval time 9: 0.132319863
~~~~~~~~~~
(I): Detected 2048 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 2048 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 4096 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 4096 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 4096 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 2048 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 1024 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 1024 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 2048 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 2048 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 4096 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 8192 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 8192 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 1024 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 2048 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 2048 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 4096 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 8192 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 8192 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 1024 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 2048 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 2048 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 4096 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 8192 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 8192 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
compile eval time 0: 100.98287199
compile eval time 1: 0.0182211
compile eval time 2: 0.015616999
compile eval time 3: 0.01512506
compile eval time 4: 0.015049746
compile eval time 5: 0.015074743
compile eval time 6: 0.015032466
compile eval time 7: 0.015071969
compile eval time 8: 0.015028161
compile eval time 9: 0.014950526
~~~~~~~~~~
(eval) eager median: 0.13270648000000002, compile median: 0.015073356, speedup: 8.804043372955567x

Perhaps there is some logging we can enable to find the difference? Can you try running with TORCH_LOGS="+dynamo" TORCHDYNAMO_VERBOSE=1?
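If it is more convenient to launch the reproduction from Python than from a shell, here is a minimal sketch that sets the two variables suggested above for a child process; it is equivalent to prefixing the command with TORCH_LOGS="+dynamo" TORCHDYNAMO_VERBOSE=1. The helper names dynamo_debug_env and run_with_dynamo_logs are hypothetical; only the two environment variables come from the comment.

```python
import os
import subprocess
import sys
from typing import Dict, Mapping, Optional


def dynamo_debug_env(base: Optional[Mapping[str, str]] = None) -> Dict[str, str]:
    """Return a copy of the environment with the suggested dynamo logging enabled."""
    env = dict(os.environ if base is None else base)  # never mutate the caller's mapping
    env["TORCH_LOGS"] = "+dynamo"  # the "+" prefix selects DEBUG-level dynamo logs
    env["TORCHDYNAMO_VERBOSE"] = "1"  # more verbose dynamo error reporting
    return env


def run_with_dynamo_logs(script: str) -> int:
    """Run `script` with the current interpreter and return its exit code."""
    completed = subprocess.run([sys.executable, script], env=dynamo_debug_env(), check=False)
    return completed.returncode
```

For example, `run_with_dynamo_logs("ex5.py")` would run the script from this issue with the extra logging; collecting that output on both setups and diffing it is one way to look for the difference.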

I haven't had time to rebuild PyTorch yet, but I can also try the PyTorch commit you used; at first glance, mine is much older.

dvrogozh commented 1 month ago

@alexbaden: you did not reproduce my eager-mode results, but your torch.compile results are similar to mine. Your PyTorch version is very old, and I think eager mode simply falls back to the CPU for some ATen ops (silently, because you are also missing https://github.com/intel/torch-xpu-ops/pull/318). You are missing at least the following torch-xpu-ops updates, which implemented many ATen ops:

$ git log --oneline 75f64e12030dfa6f621f1ec2b207892cf8660cdd..remotes/origin/main -- third_party/xpu.txt
dfba85c26bf Update torch-xpu-ops pin (ATen XPU implementation) (#131643)
b556d315868 Update torch-xpu-ops pin (ATen XPU implementation) (#131015)
cf090e222ea Update torch-xpu-ops pin (ATen XPU implementation) (#130333)
e98587c58d3 Update torch-xpu-ops pin (ATen XPU implementation) (#129353)

Update, FYI: I tried https://github.com/pytorch/pytorch/commit/75f64e12030dfa6f621f1ec2b207892cf8660cdd with PR318 applied. The following eager ATen ops fall back to the CPU: aten::native_batch_norm, aten::max_pool2d_with_indices.out, aten::avg_pool2d.out, aten::_adaptive_avg_pool2d.

alexbaden commented 1 month ago

Got it, that makes sense. Let me update PyTorch to the latest main and try again.

dvrogozh commented 1 month ago

See #1770 for a potential fix.