sayakpaul opened this issue 1 year ago
Hello - for the backend="tensorrt" issue, please ensure the line import torch_tensorrt is in the script (even though it may not be used directly). This is important because the import itself registers the backend, as here:
https://github.com/pytorch/TensorRT/blob/0112bb4996d464f46ae8027f7d7146eb045d6364/py/torch_tensorrt/dynamo/backend/backends.py#L27
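As a quick illustration (placeholder model, not from the original script), the import-for-side-effect pattern looks like this:

import torch
import torch_tensorrt  # not referenced directly below, but importing it registers the "tensorrt" backend

# Placeholder module just to show the call pattern
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3)).eval().cuda()
compiled = torch.compile(model, backend="tensorrt")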
If this does not address the issue, could you please also share the version of Torch-TensorRT being used?
Regarding the AoT approach, @peri044 may be able to help further, but based on the error it sounds like the model may not be natively trace-able. Which specific example from the presentation are you referencing for the AoT sample?
This one:
Hello @sayakpaul, I would suggest trying our latest main branch; 23.09 is old now. Let us know if you see any issues with the main branch. Alternatively, if you don't want to build from source, you can try our nightly container: https://github.com/pytorch/TensorRT/pkgs/container/tensorrt%2Ftorch_tensorrt
With the latest container (https://github.com/pytorch/TensorRT/pkgs/container/tensorrt%2Ftorch_tensorrt), the JIT workflow works. I have a few questions, so let me use this thread to ask them; the community might find them useful.
Here is the code snippet I am using to benchmark the JIT workflow:
import torch
import torch_tensorrt
import torch.utils.benchmark as benchmark


class MyModel(torch.nn.Module):
    # Toy model with a data-dependent branch, which forces a graph break under torch.compile.
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(64, 128, 3)

    def forward(self, a):
        if a.sum() > 3:
            b = torch.nn.functional.elu(a)
        else:
            b = a + 1
        return self.conv(b)


# Taken from
# https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html
def benchmark_fn(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6


model = MyModel().eval().cuda()
inputs = torch.randn(32, 64, 64, 64, device="cuda")

print("Using regular model:")
print(benchmark_fn(model, inputs))

optimized_model = torch.compile(
    model,
    backend="tensorrt",
    dynamic=False,
    options={
        "debug": True,
        "enabled_precisions": {torch.half},
        "min_block_size": 1,  # convert even single-op blocks to TRT
    },
)
_ = optimized_model(inputs)

print("Using optimized model:")
print(benchmark_fn(optimized_model, inputs))
I am getting:
620.6526302793078 microseconds (for the non-compiled model)
2095.269737765193 microseconds (for the compiled model)
I am on an 80 GB A100. I can understand the compute load might not be enough for me to see evident speedups here, but is this expected? Wanted to know your opinion.
The AOT workflow still fails:
import torch
import torch_tensorrt
import torch.utils.benchmark as benchmark


class MyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(64, 128, 3)

    def forward(self, a):
        if a.sum() > 3:
            b = torch.nn.functional.elu(a)
        else:
            b = a + 1
        return self.conv(b)


# Taken from
# https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html
def benchmark_fn(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6


inputs = torch.randn((256, 64, 32, 32)).cuda()
model = MyModel().eval().cuda()

print("Using regular model:")
print(benchmark_fn(model, inputs))

exp_program = torch_tensorrt.dynamo.trace(model, [inputs])
trt_model = torch_tensorrt.dynamo.compile(exp_program, inputs=[inputs])
_ = trt_model(inputs)

print("Using optimized model:")
print(benchmark_fn(trt_model, inputs))

print("Serializing...")
trt_ser_model = torch_tensorrt.dynamo.serialize(trt_model, *inputs)
torch.save(trt_ser_model, "trt_model.pt")
Trace:
Traceback (most recent call last):
File "/opt/torch_tensorrt/trt_aot.py", line 33, in <module>
exp_program = torch_tensorrt.dynamo.trace(model, [inputs])
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch_tensorrt/dynamo/aten_tracer.py", line 51, in trace
if input.shape_mode == Input._ShapeMode.DYNAMIC:
AttributeError: 'Tensor' object has no attribute 'shape_mode'
When I swapped the inputs with the following:
inputs = [
    torch_tensorrt.Input(
        min_shape=(1, 512, 1, 1),
        opt_shape=(4, 512, 1, 1),
        max_shape=(8, 512, 1, 1),
    )
]
it led to the following:
Traceback (most recent call last):
File "/opt/torch_tensorrt/trt_aot.py", line 37, in <module>
exp_program = torch_tensorrt.dynamo.trace(model, inputs)
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch_tensorrt/dynamo/aten_tracer.py", line 86, in trace
exp_program = export(model, tuple(trace_inputs), constraints=constraints)
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_export/__init__.py", line 556, in export
return _export(
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_export/__init__.py", line 596, in _export
gm_torch_level = _export_to_torch_ir(
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_export/__init__.py", line 517, in _export_to_torch_ir
gm_torch_level, _ = torch._dynamo.export(
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1226, in inner
result_traced = opt_f(*args, **kwargs)
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1519, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1528, in _call_impl
return forward_call(*args, **kwargs)
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 410, in _fn
return fn(*args, **kwargs)
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1519, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1528, in _call_impl
return forward_call(*args, **kwargs)
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 558, in catch_errors
return callback(frame, cache_entry, hooks, frame_state)
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 148, in _fn
return fn(*args, **kwargs)
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 402, in _convert_frame_assert
return _compile(
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 610, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 221, in time_wrapper
r = func(*args, **kwargs)
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 527, in compile_inner
out_code = transform_code_object(code, transform)
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
transformations(instructions, code_options)
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 497, in transform
tracer.run()
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2127, in run
super().run()
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 751, in run
and self.step()
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 714, in step
getattr(self, inst.opname)(inst)
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 382, in inner
raise exc.UserError(
torch._dynamo.exc.UserError: Dynamic control flow is not supported at the moment. Please use functorch.experimental.control_flow.cond to explicitly capture the control flow. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#cond-operands
from user code:
File "/opt/torch_tensorrt/trt_aot.py", line 10, in forward
if a.sum() > 3:
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
I changed the code accordingly as well:
exp_program = torch_tensorrt.dynamo.trace(model, inputs)
trt_model = torch_tensorrt.dynamo.compile(exp_program, inputs=inputs)
Hi @sayakpaul - thank you for the comments. Regarding the JIT workflow, this specific example was intended to illustrate how graph breaks are handled in torch.compile with the Torch-TRT backend. As you suggested, the model itself does not have sufficient operations to make the overhead of using a TRT engine worthwhile here. Additionally, this graph has an intentionally-added control-flow break, which means that a minimum of two TRT engines will be generated, further increasing the overhead for this small model. Since this model only has one computational operator per TRT block, the default options would not have converted these operators to TRT at all ("min_block_size": 1 overrides this default). As a result, I believe the metrics are roughly as expected here.
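For illustration (a hedged sketch, not something run in this thread), leaving min_block_size at its default would let the backend keep these single-op blocks in Torch rather than building TRT engines for them:

# Same toy model/inputs as in the snippet above, but without overriding min_block_size.
# With the default threshold, the single-op candidate blocks fall below it,
# so they are left to run in Torch instead of being converted to TRT.
optimized_default = torch.compile(model, backend="tensorrt", dynamic=False)
_ = optimized_default(inputs)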
The AOT workflow is not expected to work for this specific model (the one with the conditional). It would only work for the model without the conditional, as here. This is because the Torch ATen tracer cannot handle this sort of Python control flow. In order to use the AOT workflow on a model like this, the code would need to be modified to either remove the conditional or use tracer-allowed conditionals, like torch.cond.
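As an illustration (an untested sketch, not from the original thread), the toy model's branch could be expressed with the cond operator that the error message points to; MyCondModel is just a placeholder name, and the exact import path may differ across versions:

import torch
from functorch.experimental.control_flow import cond


class MyCondModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(64, 128, 3)

    def forward(self, a):
        # Both branches take the same operands and return tensors of the same
        # shape and dtype, so the export tracer can capture the branch symbolically.
        def true_fn(x):
            return torch.nn.functional.elu(x)

        def false_fn(x):
            return x + 1

        b = cond(a.sum() > 3, true_fn, false_fn, (a,))
        return self.conv(b)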
Will try and get back. Thanks for your inputs!
Could you maybe provide an example for JIT where the improvements are evident? FWIW, we did try something with SD but the benefits were not there: https://github.com/huggingface/diffusers/issues/5564.
Perhaps we are missing something?
Thanks for the follow-up! I will look into that sample you provided a bit more - it looks like the TRT conversion errored out for some reason, which would cause compilation to fall back to Torch eager. One other sample to try might be the Stable Diffusion version highlighted here.
Also, regarding the error cited in https://github.com/huggingface/diffusers/issues/5564, I retried your sample on the latest nightly version of the repository, with the following small code modifications, and it is working:
elif run_compile and with_tensorrt:
    print("Run torch compile with TensorRT backend")
    pipe.unet = torch.compile(
        pipe.unet,
        fullgraph=True,
        backend="tensorrt",
        dynamic=False,
        options={
            "min_block_size": 1,
            "truncate_long_and_double": True,
            "enabled_precisions": {torch.half},
        },
    )
Please let me know if this resolves the issue for you as well.
I retried your sample on the latest nightly version of the repository
Don't want to build from source. Will https://github.com/pytorch/TensorRT/pkgs/container/tensorrt%2Ftorch_tensorrt work here? (I will prune the previous container before mounting)
Yes, I believe the Docker container should also work. You could alternatively also use the following for installing the latest nightly distribution of Torch-TRT:
pip install --pre torch torch-tensorrt --extra-index-url https://download.pytorch.org/whl/nightly/cu121
Hello @gs-olive! Apologies for the delay on my end.
I did run the SDXL sample script on the latest nightly container (ensuring that the previous one was pruned from the system beforehand), but I am still not seeing the expected speedups (for a batch size of 4):
With compilation: False, and TensorRT: False in 23988496.810 microseconds
With compilation: True, and TensorRT: False in 20895331.856 microseconds
With compilation: True, and TensorRT: True in 32650460.551 microseconds
Here is the trace:
Traceback (most recent call last):
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch_tensorrt/dynamo/backend/backends.py", line 95, in _pretraced_backend
trt_compiled = compile_module(
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch_tensorrt/dynamo/_compiler.py", line 304, in compile_module
trt_module = convert_module(
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 34, in convert_module
module_outputs = module(*torch_inputs)
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/fx/graph_module.py", line 736, in call_wrapped
return self._wrapped_call(self, *args, **kwargs)
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/fx/graph_module.py", line 315, in __call__
raise e
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/fx/graph_module.py", line 302, in __call__
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1519, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1528, in _call_impl
return forward_call(*args, **kwargs)
File "<eval_with_key>.14", line 6, in forward
view_14 = torch.ops.aten.view.default(permute_15, [8, -1, 640]); permute_15 = None
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_ops.py", line 513, in __call__
return self._op(*args, **kwargs or {})
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/utils/_stats.py", line 20, in wrapper
return fn(*args, **kwargs)
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1378, in __torch_dispatch__
return self.dispatch(func, types, args, kwargs)
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1676, in dispatch
r = func(*args, **kwargs)
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_ops.py", line 513, in __call__
return self._op(*args, **kwargs or {})
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
Are you seeing anything different?
I am using an 80 GB A100 FYI. I also modified the code per your suggestion; below is the full code snippet:
import argparse

import torch
import torch_tensorrt
import torch.utils.benchmark as benchmark
from diffusers import DiffusionPipeline

CKPT = "stabilityai/stable-diffusion-xl-base-1.0"
NUM_INFERENCE_STEPS = 50
PROMPT = "ghibli style, a fantasy landscape with castles"


def load_pipeline(run_compile=False, with_tensorrt=False):
    pipe = DiffusionPipeline.from_pretrained(
        CKPT, torch_dtype=torch.float16, use_safetensors=True
    )
    pipe = pipe.to("cuda")
    pipe.unet.to(memory_format=torch.channels_last)

    if run_compile and not with_tensorrt:
        print("Run torch compile")
        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
    elif run_compile and with_tensorrt:
        print("Run torch compile with TensorRT backend")
        pipe.unet = torch.compile(
            pipe.unet,
            fullgraph=True,
            backend="tensorrt",
            dynamic=False,
            options={
                "min_block_size": 1,
                "truncate_long_and_double": True,
                "enabled_precisions": {torch.half},
            },
        )

    pipe.set_progress_bar_config(disable=True)
    return pipe


def run_inference(pipe, batch_size=1):
    _ = pipe(
        prompt=PROMPT,
        num_inference_steps=NUM_INFERENCE_STEPS,
        num_images_per_prompt=batch_size,
    )


# Taken from
# https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html
def benchmark_fn(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--run_compile", action="store_true")
    parser.add_argument("--with_tensorrt", action="store_true")
    args = parser.parse_args()

    pipeline = load_pipeline(
        run_compile=args.run_compile, with_tensorrt=args.with_tensorrt
    )
    print(
        f"With compilation: {args.run_compile}, and TensorRT: {args.with_tensorrt} in {benchmark_fn(run_inference, pipeline, args.batch_size):.3f} microseconds"
    )
Hi @sayakpaul - thanks for the details. I am able to reproduce the issue you are facing, but it only appears to happen with the latest nightly torch distribution. Specifically, I do not see the error when using torch==2.1.0. We will be pushing a Docker container with that build later today, and I will update this issue when that is ready. It seems that something changed in the Torch tracer which resulted in this error.
Additionally, we do not have full converter support yet for this specific model on the latest Torch nightly (but we do on 2.1.0). The missing operator we need to implement is aten._scaled_dot_product_flash_attention, and I have filed an issue for this (#2427) so we can prioritize it.
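As a hedged aside (not something suggested above): when a converter for a specific ATen op is missing, the torch_executed_ops compilation setting can in principle keep just that op in eager Torch while the rest of the graph is converted. The option name and op-string format below are assumptions and should be checked against the installed Torch-TensorRT version:

# Hypothetical sketch: run the unsupported op in Torch, convert the rest to TRT.
pipe.unet = torch.compile(
    pipe.unet,
    backend="tensorrt",
    options={
        "enabled_precisions": {torch.half},
        "truncate_long_and_double": True,
        # op string assumed; verify against the installed version's documentation
        "torch_executed_ops": {"torch.ops.aten._scaled_dot_product_flash_attention.default"},
    },
)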
Cool, let me know :-) Looking forward to it. I will try to pin the SHA of the Docker container so that it's reproducible for the community.
@gs-olive a gentle ping if the container is available :)
Thanks for the message and apologies for the delay - I expect it to be ready at some point today or tomorrow and I will update this issue once it is ready.
Hi @sayakpaul - the container built against PyTorch 2.1.0 can be found here:
ghcr.io/pytorch/tensorrt/torch_tensorrt:release_2.1
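For anyone following along, one way to launch that container with GPU access uses standard Docker flags (this command is illustrative, not quoted from the thread):
docker run --rm -it --gpus all ghcr.io/pytorch/tensorrt/torch_tensorrt:release_2.1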
This is what I get now:
Traceback (most recent call last):
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch_tensorrt/dynamo/backend/backends.py", line 95, in _pretraced_backend
trt_compiled = compile_module(
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch_tensorrt/dynamo/_compiler.py", line 312, in compile_module
trt_module = convert_module(
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 33, in convert_module
module_outputs = module(*torch_inputs)
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/fx/graph_module.py", line 678, in call_wrapped
return self._wrapped_call(self, *args, **kwargs)
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/fx/graph_module.py", line 284, in __call__
raise e
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/fx/graph_module.py", line 274, in __call__
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "<eval_with_key>.14", line 6, in forward
view_14 = torch.ops.aten.view.default(permute_15, [8, -1, 640]); permute_15 = None
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_ops.py", line 448, in __call__
return self._op(*args, **kwargs or {})
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/utils/_stats.py", line 20, in wrapper
return fn(*args, **kwargs)
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1250, in __torch_dispatch__
return self.dispatch(func, types, args, kwargs)
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1541, in dispatch
r = func(*args, **kwargs)
File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_ops.py", line 448, in __call__
return self._op(*args, **kwargs or {})
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
Ok - thanks for the update. Could you also try with transformers==4.33.2 and diffusers==0.21.4? I am trying to determine whether the transformers and diffusers versions may also be related to the issue.
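For reference, pinning those versions in a pip-based environment would look like this:
pip install transformers==4.33.2 diffusers==0.21.4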
Sorry for the delay. It still remains the same.
@gs-olive I had a chance to do the benchmarking for Stable Diffusion. Here are the timings:
With compilation: False, and TensorRT: False in 3.767 seconds
With compilation: True, and TensorRT: False in 3.045 seconds
With compilation: True, and TensorRT: True in 1.157 seconds
Env:
transformers==4.33.2 diffusers==0.21.4
Will see with SDXL too.
Also, it is a bit surprising that with the latest stable releases of diffusers and transformers, the timings are quite off:
With compilation: False, and TensorRT: False in 1.708 seconds
With compilation: True, and TensorRT: False in 1.447 seconds
With compilation: True, and TensorRT: True in 2.379 seconds
But if I pin the transformers version to 4.33.2 and use the latest stable diffusers release, then the timings are:
With compilation: False, and TensorRT: False in 3.512 seconds
With compilation: True, and TensorRT: False in 3.061 seconds
With compilation: True, and TensorRT: True in 1.181 seconds
Are you able to reproduce these numbers on an 80GB A100?
I also extended the above script to use SDXL, and here are the timings (updated here):
With compilation: False, and TensorRT: False in 6.713 seconds
With compilation: True, and TensorRT: False in 6.417 seconds
With compilation: True, and TensorRT: True in 5.537 seconds
I have pinned the transformers version to 4.33.2 and I am using the latest diffusers.
Thanks for the follow-up. It seems that in the newer transformers and diffusers versions, some of the operators have changed in the Stable Diffusion models, so our converter support is different, which can affect performance. We have issues filed to support these as well, so that performance can be more uniform across transformers and diffusers versions.
❓ Question
I am within the nvcr.io/nvidia/pytorch:23.09-py3 container, trying out some snippets from https://youtu.be/eGDMJ3MY4zk?si=MhkbgwAPVQSFZEha. Both the JIT and AoT examples failed. For JIT, it complained that the "tensorrt" backend isn't available; for AoT, it complained that "The user code is using a feature we don't support. Please try torchdynamo.explain() to get possible the reasons".
I am on an A100. What's going on?
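As a hedged aside on the AoT error above, torch._dynamo.explain can report what breaks the graph. A minimal sketch, where model and example_input are placeholders and the call convention may differ slightly across torch versions:

import torch

# Placeholders: substitute the actual module and a representative input tensor.
explanation = torch._dynamo.explain(model)(example_input)
print(explanation)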