pytorch / TensorRT

PyTorch/TorchScript/FX compiler for NVIDIA GPUs using TensorRT
https://pytorch.org/TensorRT
BSD 3-Clause "New" or "Revised" License

❓ [Question] mlp running with torch_tensorrt slower than with inductor? #2606

Open johnzlli opened 8 months ago

johnzlli commented 8 months ago

❓ Question

I am running inside the nvcr.io/nvidia/pytorch:23.12-py3 container, and the performance of torch_tensorrt is worse than Inductor's. Example code:

import torch
import torch_tensorrt
import torch.nn as nn

class MLPBlocks(nn.Module):
    def __init__(self, window_dim, hidden_dim):
        super().__init__()

        self.mlp_1 = nn.Sequential(
            nn.Linear(window_dim, window_dim * 4),
            nn.ReLU(),
            nn.Linear(window_dim * 4, window_dim),
        )
        self.mlp_2 = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def forward(self, x):
        x = self.mlp_1(x.transpose(1, 2)).transpose(1, 2)
        x = self.mlp_2(x)
        return x

class MLP(nn.Module):
    def __init__(self, *_args):
        super().__init__()
        self.hidden_dim = 256
        self.window_dim = 50
        self.n_feature = 800

        self.fc_first = nn.Linear(self.n_feature, self.hidden_dim)
        self.fc_last = nn.Linear(self.hidden_dim, 1)
        self.blocks = nn.ModuleList([MLPBlocks(window_dim=self.window_dim, hidden_dim=self.hidden_dim) for _ in range(8)])

    def forward(self, input_x):
        net_x = self.fc_first(input_x.transpose(0, 1))
        for mlp_block in self.blocks:
            net_x = mlp_block(net_x)
        net_x = self.fc_last(torch.mean(net_x, dim=1))
        return net_x

def run_model(x, model):
    # Warm-up iterations: trigger compilation/engine building before timing
    for _ in range(10):
        with torch.no_grad():
            res = model(x)

    # Time 50 iterations with CUDA events for accurate on-device measurement
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()

    for _ in range(50):
        with torch.no_grad():
            res = model(x)

    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / 50  # average latency per iteration (ms)

def test_inductor(data, model):
    x = data.float().cuda()
    m = model.float().cuda()
    torch._dynamo.reset()
    opt_model = torch.compile(m)
    print(f"inductor fp32 time: {run_model(x, opt_model)}")

    x = x.half()
    m = m.half()
    torch._dynamo.reset()
    opt_model = torch.compile(m)
    print(f"inductor fp16 time: {run_model(x, opt_model)}")

def test_trt_script(data, model):
    x = data.float().cuda()
    m = model.float().cuda()
    script_model = torch.jit.trace(m, x)
    trt_ts_model = torch_tensorrt.compile(script_model, ir="torchscript", inputs=[x], enabled_precisions={torch.float})
    print(f"trt_script fp32 time: {run_model(x, trt_ts_model)}")

    x = x.half()
    m = m.half()
    script_model = torch.jit.trace(m, x)
    trt_ts_model = torch_tensorrt.compile(script_model, ir="torchscript", inputs=[x], enabled_precisions={torch.half})
    print(f"trt script fp16 time: {run_model(x, trt_ts_model)}")

def test_trt_dynamo(data, model):
    x = data.float().cuda()
    m = model.float().cuda()
    torch._dynamo.reset()
    opt_model = torch_tensorrt.compile(m, ir="torch_compile", inputs=[x], enabled_precisions={torch.float})
    print(f"trt_dynamo fp32 time: {run_model(x, opt_model)}")

    x = data.half().cuda()
    m = model.half().cuda()
    torch._dynamo.reset()
    opt_model = torch_tensorrt.compile(m, ir="torch_compile", inputs=[x], enabled_precisions={torch.half})
    print(f"trt_dynamo fp16 time: {run_model(x, opt_model)}")

if __name__ == "__main__":
    model = MLP()
    x = torch.randn(50, 5000, 800)
    test_inductor(x, model)
    test_trt_script(x, model)
    test_trt_dynamo(x, model)

Result: (timing output attached as an image)

gs-olive commented 8 months ago

Hello - since this model is traceable and doesn't appear to have graph breaks, I think ir="dynamo" can generally give a small boost over ir="torch_compile". Additionally, there is an optimization_level parameter for which the maximum is 5. I have added an adapted example below which could help boost performance:

    x = data.half().cuda()
    m = model.half().cuda()
    torch._dynamo.reset()
    opt_model = torch_tensorrt.compile(m, ir="dynamo", inputs=[x], enabled_precisions={torch.half}, optimization_level=5)
    print(f"trt_dynamo fp16 time: {run_model(x, opt_model)}")

Additionally, if you share the output logs of a (separate) run with debug=True, we can see if any operators in the model are unsupported, which can also affect performance.
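For example, a minimal sketch of such a debug run, reusing x and m from the snippet above (debug=True raises the Torch-TensorRT log verbosity during compilation):

    torch._dynamo.reset()
    # debug=True prints verbose compilation logs, including warnings about
    # any operators that could not be converted to TRT
    dbg_model = torch_tensorrt.compile(m, ir="dynamo", inputs=[x], enabled_precisions={torch.half}, debug=True)
    run_model(x, dbg_model)  # inspect the emitted logs for fallback warnings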

johnzlli commented 8 months ago

Thanks for your reply! I took your advice, but it seems that ir="dynamo" with optimization_level=5 gives even worse performance than before. I'm sorry, but due to the server's internet access controls I can't share the log file. However, my code is fully shown above; perhaps you can copy it and run it to reproduce the issue.

gs-olive commented 8 months ago

Thanks for the follow-up. It appears we have full coverage for that model and all of the operators are effectively converted to TRT. I would also suggest using the latest nightly version of Torch-TRT for the most up-to-date performance additions, which can be installed from source or via pip:

pip install --pre torch torchvision torch_tensorrt --index-url https://download.pytorch.org/whl/nightly/cu121
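
After installing, a quick sanity check (a minimal sketch) confirms that the nightly build is the one actually being imported:

import torch_tensorrt
# nightly builds report a .dev version string, e.g. 2.x.0.dev...
print(torch_tensorrt.__version__)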