pytorch / TensorRT

PyTorch/TorchScript/FX compiler for NVIDIA GPUs using TensorRT
https://pytorch.org/TensorRT
BSD 3-Clause "New" or "Revised" License

šŸ› [Bug] Unable to freeze tensor of type Int64/Float64 into constant layer #2848

Open · supermeng opened this issue 4 months ago

supermeng commented 4 months ago

Unable to freeze tensor of type Int64/Float64 into constant layer, try to compile model with truncate_long_and_double enabled

When I test the Transformer Attention layer with TensorRT, I get the error above. I checked the sample input tensor and the inputs passed to torch_tensorrt.compile; there are no double tensors.
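
To double-check for stray wide dtypes, here is a minimal sketch using standard PyTorch APIs (it inspects only parameters and buffers, so it will not catch int64/float64 constants baked into the traced graph):

import torch

def report_wide_dtypes(module: torch.nn.Module):
    # Print any parameter or buffer held in int64/float64; these are the
    # tensors Torch-TensorRT would otherwise need to truncate.
    for name, t in list(module.named_parameters()) + list(module.named_buffers()):
        if t.dtype in (torch.int64, torch.float64):
            print(name, t.dtype)

If this prints nothing for the module in the reproducer below, that is consistent with the report: offending tensors can still come from constants recorded during tracing.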

To Reproduce

Steps to reproduce the behavior:

  1. Run the following test code:
import torch
from torch import nn
import torch_tensorrt

from diffusers.models.attention import Attention

class AttnModule(nn.Module):
    def __init__(self):
        super().__init__()
        num_attention_heads = 16
        attention_head_dim = 8
        dim = num_attention_heads * attention_head_dim
        dropout = 0.0
        attention_bias = False
        upcast_attention = False
        attention_out_bias = True

        self.attn = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=None,
            upcast_attention=upcast_attention,
            out_bias=attention_out_bias,
        )

    def forward(self, sample: torch.Tensor):
        return self.attn(sample)

model = AttnModule().to(device='cuda').eval()  # torch module needs to be in eval (not training) mode
model = model.half()

a = torch.randn((1, 128, 128), device='cuda').half()
traced_model = torch.jit.trace(model, a).half().cuda()

print('traced_model', traced_model.graph)

enabled_precisions = {torch.half}  # Run with fp16

with torch_tensorrt.logging.debug():
    trt_ts_module = torch_tensorrt.compile(
        traced_model,
        debug=True,
        ir="torchscript",
        inputs=[torch_tensorrt.Input([1, 128, 128], dtype=torch.half, name="sample")],
        enabled_precisions=enabled_precisions,
        # truncate_long_and_double=True,  # previously enabled, now commented out
    )

c = torch.randn((1, 128, 128), device='cuda').half()
# warm up
model(c)
traced_model(c)
trt_ts_module(c)

import time
start = time.time()
for i in range(100):
    with torch.no_grad():
        result = model(c)
torch.cuda.synchronize()
print('cost0:', time.time() - start)

start = time.time()
for i in range(100):
    with torch.no_grad():
        result = traced_model(c)
torch.cuda.synchronize()
print('cost1:', time.time() - start)

start = time.time()
for i in range(100):
    with torch.no_grad():
        result = trt_ts_module(c)
torch.cuda.synchronize()
print('cost2:', time.time() - start)

Expected behavior

The code runs correctly.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

Additional context

Attached log: tensor_rt_attn.log

narendasan commented 4 months ago

Looking at your reproducer, I noticed that you had truncate_long_and_double enabled earlier but now have it commented out. When I run it through the TorchScript frontend with that feature enabled on main, it seems to work fine. Also, if you are tracing to work around TorchScript limitations, you might want to use the Dynamo frontend; if you still need TorchScript at the end, you can torch.jit.trace the output of Torch-TensorRT, and you will have access to all the latest features we have been adding.
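
For reference, a minimal sketch of that compile call with the flag re-enabled, mirroring the commented-out variant in the reproducer (TorchScript frontend kwargs):

trt_ts_module = torch_tensorrt.compile(
    traced_model,
    ir="torchscript",
    inputs=[torch_tensorrt.Input([1, 128, 128], dtype=torch.half, name="sample")],
    enabled_precisions={torch.half},
    truncate_long_and_double=True,  # fold Int64/Float64 constants to Int32/Float32
)

With the Dynamo frontend the analogous option is named truncate_double in recent releases; the exact kwarg set varies by Torch-TensorRT version, so check the torch_tensorrt.compile signature for your install.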

supermeng commented 4 months ago

@narendasan Hi, thanks so much for your reply. If I enable truncate_long_and_double, I get another dtype mismatch (float with half) error. What confuses me is that there are no double dtypes in any of the tensor calculations. Also, the Dynamo frontend costs much more time than eager or torch.jit.trace mode.

supermeng commented 4 months ago
(screenshot attached)
narendasan commented 4 months ago

> @narendasan Hi, thanks so much for your reply. If I enable truncate_long_and_double, I get another dtype mismatch (float with half) error. What confuses me is that there are no double dtypes in any of the tensor calculations. Also, the Dynamo frontend costs much more time than eager or torch.jit.trace mode.

There may be int64 tensors in your code (including things like indices) which require the use of that setting.
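
A minimal illustration (not this model's exact code) of how int64 tensors can enter a traced graph even when no input is int64:

import torch

def gather_positions(x: torch.Tensor) -> torch.Tensor:
    # torch.arange defaults to int64; in a traced graph this shows up
    # as a Long-typed value that TensorRT cannot freeze into a constant
    # layer without truncate_long_and_double.
    idx = torch.arange(x.shape[1], device=x.device)
    return x[:, idx, :]

x = torch.randn(1, 128, 128, device="cuda").half()
traced = torch.jit.trace(gather_positions, x)
print(traced.graph)  # inspect the dump for Long(...) values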