Different torch.compile modes may result in different performance (e.g. torch.compile(model, mode="max-autotune")). Also, torch.compile will generally take longer on the first pass since it needs to compile, but subsequent passes are expected to be faster than the baseline.
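For example, a minimal sketch of the built-in modes (the placeholder model below is an assumption, not from this report; actual results depend on the model and GPU):

import torch
import torch.nn as nn

model = nn.Linear(128, 128).cuda()  # placeholder module; any nn.Module works

# The same model compiled with different modes.
compiled_default = torch.compile(model)                          # default mode
compiled_low_overhead = torch.compile(model, mode="reduce-overhead")  # targets per-call overhead (CUDA graphs where possible)
compiled_autotuned = torch.compile(model, mode="max-autotune")   # longest compile time, searches for faster kernels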
I tried running it twice but it was still slower. Is there any suggestion for debugging this (e.g., passing a particular option to compile, or enabling trace/verbose output)?
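(For reference, the 2.0-era knobs for getting verbose/trace output from the compile stack look roughly like the sketch below; the exact flags are an assumption here and may have changed in later releases.)

import torch._dynamo as dynamo

# Print graph breaks, guard failures, and recompilation reasons (early-2.0 API;
# newer releases route this through torch._logging instead).
dynamo.config.verbose = True

# Alternatively, running the script with TORCH_COMPILE_DEBUG=1 set in the
# environment dumps per-stage debug artifacts from TorchInductor.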
@stephen-youn Thanks for trying out torch.compile. PyTorch 2.0 compilers are JIT compilers, i.e., they compile the model on the first iteration. In your script, you are measuring first-iteration latency, and hence you are observing the high latency. I modified your script and am observing better numbers on an A100 GPU (the numbers are not stable, probably because we are measuring just one iteration, but the speedup is evident).
Script
import torch

# Load a pretrained ResNet-18 and move it to the GPU so the CUDA-event timing below is meaningful.
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True).cuda()
opt_model = torch.compile(model, backend="inductor")

# warmup (eager)
for _ in range(3):
    model(torch.randn(1, 3, 64, 64, device="cuda"))

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
model(torch.randn(1, 3, 64, 64, device="cuda"))
end_event.record()
torch.cuda.synchronize()
estimate_ms = start_event.elapsed_time(end_event)
print(f"estimated_ms={estimate_ms}")

# warmup (compiled; these first iterations include compilation time)
for _ in range(3):
    opt_model(torch.randn(1, 3, 64, 64, device="cuda"))

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
opt_model(torch.randn(1, 3, 64, 64, device="cuda"))
end_event.record()
torch.cuda.synchronize()
estimate_ms = start_event.elapsed_time(end_event)
print(f"estimated_ms={estimate_ms}")
Output
estimated_ms=1222.14990234375
estimated_ms=326.3006286621094
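(These single-iteration numbers are noisy; a sketch that averages several timed iterations after warmup, assuming the same model, opt_model, and CUDA device as in the script above, gives more stable estimates.)

def time_ms(fn, inp, iters=30, warmup=5):
    # Warmup: for the compiled model, the first iterations include compilation time.
    for _ in range(warmup):
        fn(inp)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(inp)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # average milliseconds per iteration

inp = torch.randn(1, 3, 64, 64, device="cuda")
print(f"eager:    {time_ms(model, inp):.2f} ms/iter")
print(f"compiled: {time_ms(opt_model, inp):.2f} ms/iter")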
Please let me know if you have any other questions. Please feel free to close the bug if your question is answered.
Yes, I also modified the code similarly and got a perf gain on a V100 too. One follow-up question: what is the difference between torch.compile(model, passes={"triton-autotune": True}) and torch.compile(model, backend="inductor")? Does one use Triton for matmul and the other doesn't? What is the default matmul kernel in Inductor, isn't it Triton? It seems the default mm backend is set to "aten", not "triton" (link). How can I make sure I use Triton for matmuls?
@stephen-youn backend="inductor" uses the TorchInductor backend. This is also the default backend, so torch.compile(model, passes={"triton-autotune": True}) is equivalent to torch.compile(model, backend="inductor", passes={"triton-autotune": True}). The passes argument can be used to set TorchInductor flags. The triton-autotune flag is already set to True by default, and it is not used for tuning matmul operations; it is used for tuning the fused kernels (pointwise, reduction, scatter, etc.). So all of these are exactly the same:
torch.compile(model)
torch.compile(model, backend="inductor")
torch.compile(model, passes={"triton-autotune": True})
Reading between the lines, it seems you are interested in mm operators. For those, TorchInductor currently uses the aten implementation for mm/bmm ops; we do not use Triton to generate code for these matmul ops. If you want to try Triton matmuls, you can set passes={'triton-mm': True, 'triton-bmm': True}. This part is not super heavily tested, so please be gentle. Do open issues if you run into problems.

I tried opt_model = torch.compile(model, passes={'triton-mm': "triton", 'triton-bmm': True}) but it crashed, so I opened an issue here (link).
🐛 Describe the bug
Hi, I tried the bert and resnet examples in the tutorial https://pytorch.org/blog/Accelerating-Hugging-Face-and-TIMM-models/ but they ran slower with torch.compile on a V100 under the Ubuntu environment I have (i.e., Linux GCRHYP3C148 4.15.0-193-generic #204-Ubuntu SMP). Isn't it supposed to be faster? Thanks
Error logs
No response
Minified repro
""" resnet """
This runs like the following, and the compiled model runs 74x slower, as shown below.
It is similar for the following bert example from the tutorial: it is 14.7x slower with the extra line model = torch.compile(model).
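For reference, here is a minimal sketch of the measurement pattern described in this report: timing a single forward pass immediately after torch.compile, which folds the one-time compilation cost into the measured latency. The input size and timing method below are assumptions, not the original repro script.

import time
import torch

model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True).cuda()
inp = torch.randn(1, 3, 224, 224, device="cuda")

# Eager baseline: time one forward pass.
torch.cuda.synchronize()
t0 = time.time()
model(inp)
torch.cuda.synchronize()
print(f"eager first call: {time.time() - t0:.3f} s")

# Compiled model: the first call triggers JIT compilation, so timing it this way
# makes torch.compile look dramatically slower, as reported above.
opt_model = torch.compile(model)
torch.cuda.synchronize()
t0 = time.time()
opt_model(inp)
torch.cuda.synchronize()
print(f"compiled first call (includes compile time): {time.time() - t0:.3f} s")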