pytorch / TensorRT

PyTorch/TorchScript/FX compiler for NVIDIA GPUs using TensorRT
https://pytorch.org/TensorRT
BSD 3-Clause "New" or "Revised" License

πŸ› [Bug] Encountered bug when using Torch-TensorRT--list inputs #3143

Open yjjinjie opened 2 months ago

yjjinjie commented 2 months ago

Bug Description

In real-world scenarios, user features change constantly, so the forward function has to take a list of tensors as its input. But when I pass a list input, torch_tensorrt.dynamo.compile(exp_program, [inputs], min_block_size=2) raises an error.

To Reproduce

import torch
import torch_tensorrt
from typing import Dict, List
from tzrec.modules.mlp import MLP
from torch import nn

@torch.fx.wrap
def _get_dict(grouped_features_keys: List[str], args: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
    if len(grouped_features_keys) != len(args):
        raise ValueError(
            "The number of grouped_features_keys must match "
            "the number of arguments."
        )
    grouped_features = {
        key: value for key, value in zip(grouped_features_keys, args)
    }
    return grouped_features

@torch.fx.wrap
def _arange(end: int, device: torch.device) -> torch.Tensor:
    return torch.arange(end, device=device)

class MatMul2(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.keys = ["query","sequence","sequence_length"]
        attn_mlp = {'hidden_units': [256, 64], 'dropout_ratio': [], 'activation': 'nn.ReLU', 'use_bn': False}
        self.mlp = MLP(in_features=41 * 4, **attn_mlp)
        self.linear = nn.Linear(self.mlp.hidden_units[-1], 1)

    def forward(self, args1: List[torch.Tensor]):
        """Forward the module."""
        # use predict to avoid trace error in self._output_to_prediction(y)
        return self.predict(args1)

    def predict(self, args: List[torch.Tensor]):
        grouped_features = _get_dict(self.keys, args)
        query = grouped_features["query"]
        sequence = grouped_features["sequence"]
        sequence_length = grouped_features["sequence_length"]
        max_seq_length = sequence.size(1)
        sequence_mask = _arange(
            max_seq_length, device=sequence_length.device
        ).unsqueeze(0) < sequence_length.unsqueeze(1)

        queries = query.unsqueeze(1).expand(-1, max_seq_length, -1)

        # attn_input = torch.cat(
        #     [queries, sequence, queries - sequence, queries * sequence], dim=-1
        # )

        return queries

model = MatMul2().eval().cuda()
a1=torch.randn(2, 41).cuda()
b1=torch.randn(2, 50,41).cuda()
c1=torch.randn(2).cuda()
inputs=[a1,b1,c1]
exp_program = torch.export.export(model, (inputs,))
# # ERROR
trt_gm = torch_tensorrt.dynamo.compile(exp_program, [inputs], min_block_size=2)
# # Run inference
# print(trt_gm(*inputs))

ERROR

Traceback (most recent call last):
  File "/larec/tzrec/tests/test_2.py", line 64, in <module>
    trt_gm = torch_tensorrt.dynamo.compile(exp_program, [inputs],min_block_size=2)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/dynamo/_compiler.py", line 230, in compile
    trt_gm = compile_module(gm, inputs, settings)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/dynamo/_compiler.py", line 427, in compile_module
    sample_outputs = gm(
                     ^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/fx/graph_module.py", line 738, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/fx/graph_module.py", line 316, in __call__
    raise e
  File "/opt/conda/lib/python3.11/site-packages/torch/fx/graph_module.py", line 303, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1582, in _call_impl
    args_kwargs_result = hook(self, args, kwargs)  # type: ignore[misc]
                         ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/export/_unlift.py", line 33, in _check_input_constraints_pre_hook
    return _check_input_constraints_for_graph(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_export/utils.py", line 86, in _check_input_constraints_for_graph
    raise RuntimeError(
RuntimeError: Expected input at *args[0][0] to be a tensor, but got <class 'torch_tensorrt._Input.Input'>
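The path in the message, `*args[0][0]`, follows from the nesting. Because `forward` takes a single list, `[inputs]` is `[[a1, b1, c1]]`, and the first tensor sits at index `[0][0]`. The plain-Python sketch below illustrates one failure mode consistent with the message. This is an assumption, not a reading of the torch_tensorrt source: input-spec objects are converted back to sample values only at the top level, so specs nested inside a list reach the graph call unchanged. `FakeInput`, `flatten_with_paths` and `specs_to_samples_top_level` are hypothetical, and strings stand in for tensors.

```python
from typing import Any, List, Tuple


def flatten_with_paths(obj: Any, path: str = "args") -> List[Tuple[str, Any]]:
    """Recursively flatten nested lists/tuples, recording each leaf's index path.

    Hypothetical helper; it imitates the args[i][j] path style of the error
    message, not torch's real pytree implementation.
    """
    if isinstance(obj, (list, tuple)):
        leaves: List[Tuple[str, Any]] = []
        for i, item in enumerate(obj):
            leaves.extend(flatten_with_paths(item, f"{path}[{i}]"))
        return leaves
    return [(path, obj)]


class FakeInput:
    """Stand-in for an input-spec object (NOT torch_tensorrt.Input)."""

    def __init__(self, name: str) -> None:
        self.name = name


def specs_to_samples_top_level(specs: List[Any]) -> List[Any]:
    """Hypothetical converter that unwraps only top-level FakeInput entries;
    nested lists are passed through untouched."""
    return [s.name if isinstance(s, FakeInput) else s for s in specs]


# Three specs for a1, b1, c1, wrapped the way the reproduction does: [inputs].
nested_specs = [[FakeInput("a1"), FakeInput("b1"), FakeInput("c1")]]
samples = specs_to_samples_top_level(nested_specs)

for path, leaf in flatten_with_paths(samples):
    kind = "spec object" if isinstance(leaf, FakeInput) else "sample"
    print(path, kind)
# args[0][0] spec object
# args[0][1] spec object
# args[0][2] spec object
```

A converter that recursed into lists (as `flatten_with_paths` does) would not leave spec objects behind, which is why flat tensor inputs are a plausible way to sidestep the problem.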

Environment

CPU(s):                          104
On-line CPU(s) list:             0-103
Vendor ID:                       GenuineIntel
Model name:                      Intel(R) Xeon(R) Platinum 8269CY CPU @ 2.50GHz
CPU family:                      6
Model:                           85
Thread(s) per core:              2
Core(s) per socket:              26
Socket(s):                       2
Stepping:                        7
CPU max MHz:                     3800.0000
CPU min MHz:                     1200.0000
BogoMIPS:                        5000.00
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single intel_ppin ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb intel_pt avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts pku ospke avx512_vnni md_clear flush_l1d arch_capabilities
Virtualization:                  VT-x
L1d cache:                       1.6 MiB (52 instances)
L1i cache:                       1.6 MiB (52 instances)
L2 cache:                        52 MiB (52 instances)
L3 cache:                        71.5 MiB (2 instances)
NUMA node(s):                    1
NUMA node0 CPU(s):               0-103
Vulnerability Itlb multihit:     KVM: Mitigation: Split huge pages
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; Enhanced IBRS, IBPB conditional, RSB filling
Vulnerability Tsx async abort:   Mitigation; TSX disabled

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] optree==0.12.1
[pip3] torch==2.4.0
[pip3] torch_tensorrt==2.4.0
[pip3] torchaudio==2.4.0
[pip3] torchelastic==0.2.2
[pip3] torchmetrics==1.0.3
[pip3] torchrec==0.8.0+cu121
[pip3] torchvision==0.19.0
[pip3] triton==3.0.0
[conda] blas                      1.0                         mkl
[conda] ffmpeg                    4.3                  hf484d3e_0    pytorch
[conda] libjpeg-turbo             2.0.0                h9bf148f_0    pytorch
[conda] mkl                       2023.1.0         h213fc3f_46344
[conda] mkl-service               2.4.0           py311h5eee18b_1
[conda] mkl_fft                   1.3.8           py311h5eee18b_0
[conda] mkl_random                1.2.4           py311hdb19cb5_0
[conda] numpy                     1.26.4          py311h08b1b3b_0
[conda] numpy-base                1.26.4          py311hf175353_0
[conda] optree                    0.12.1                   pypi_0    pypi
[conda] pytorch                   2.4.0           py3.11_cuda12.1_cudnn9.1.0_0    pytorch
[conda] pytorch-cuda              12.1                 ha16c6d3_5    pytorch
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] torch-tensorrt            2.4.0                    pypi_0    pypi
[conda] torchaudio                2.4.0               py311_cu121    pytorch
[conda] torchelastic              0.2.2                    pypi_0    pypi
[conda] torchmetrics              1.0.3                    pypi_0    pypi
[conda] torchrec                  0.8.0+cu121              pypi_0    pypi
[conda] torchtriton               3.0.0                     py311    pytorch
[conda] torchvision               0.19.0              py311_cu121    pytorch

narendasan commented 2 months ago

@peri044 can you look at the export workflow here?