In real-world scenarios, user features change constantly, so I must pass a list of tensors as the input to the `forward` function.
However, when I use a list input, `torch_tensorrt.dynamo.compile(exp_program, [inputs], min_block_size=2)` raises an error.
To Reproduce
import torch
import torch_tensorrt
from typing import Optional, Sequence,Dict,List
from torch.nn import functional as F
from tzrec.modules.mlp import MLP
from torch import nn
@torch.fx.wrap
def _get_dict(
    grouped_features_keys: List[str], args: List[torch.Tensor]
) -> Dict[str, torch.Tensor]:
    """Pair each key with the value at the same position in ``args``.

    Decorated with ``torch.fx.wrap`` so that FX symbolic tracing treats this
    call as a leaf function rather than tracing into its body.

    Args:
        grouped_features_keys: Names to assign, in order.
        args: Values to name; must have the same length as the keys.

    Returns:
        A dict mapping ``grouped_features_keys[i]`` to ``args[i]``. If a key
        repeats, the later value wins (plain ``dict`` construction semantics).

    Raises:
        ValueError: If the number of keys differs from the number of values.
    """
    if len(grouped_features_keys) != len(args):
        raise ValueError(
            "The number of grouped_features_keys must match "
            "the number of arguments."
        )
    # zip already yields (key, value) pairs, so dict() consumes them directly.
    return dict(zip(grouped_features_keys, args))
@torch.fx.wrap
def _arange(end: int, device: torch.device) -> torch.Tensor:
    """Return the positions ``0, 1, ..., end - 1`` as a tensor on ``device``.

    Decorated with ``torch.fx.wrap`` so that FX symbolic tracing treats this
    call as a leaf function rather than tracing into ``torch.arange``.
    """
    positions = torch.arange(end, device=device)
    return positions
class MatMul2(torch.nn.Module):
    """Minimal module whose ``forward`` takes a single list of tensors.

    The list is named positionally using ``self.keys``
    (``"query"``, ``"sequence"``, ``"sequence_length"``). ``self.mlp`` and
    ``self.linear`` are constructed but not used by :meth:`predict` yet.
    """

    def __init__(self):
        super().__init__()
        # Names assigned, in order, to the tensors in the forward() list.
        self.keys = ["query", "sequence", "sequence_length"]
        mlp_kwargs = dict(
            hidden_units=[256, 64],
            dropout_ratio=[],
            activation="nn.ReLU",
            use_bn=False,
        )
        # NOTE(review): 41 * 4 presumably matches the planned attention input
        # (four concatenated terms of width 41, see note in predict) -- confirm.
        self.mlp = MLP(in_features=41 * 4, **mlp_kwargs)
        self.linear = nn.Linear(self.mlp.hidden_units[-1], 1)

    def forward(self, args1: List[torch.Tensor]):
        """Delegate to :meth:`predict` with the same list of tensors."""
        # Per the original author, routing through predict() avoids a trace
        # error in self._output_to_prediction(y); that method is not part of
        # this snippet.
        return self.predict(args1)

    def predict(self, args: List[torch.Tensor]):
        """Broadcast the query along the sequence axis.

        Args:
            args: Tensors in the order given by ``self.keys``.

        Returns:
            ``query`` with a new axis 1 expanded to ``sequence.size(1)``.
        """
        features = _get_dict(self.keys, args)
        query = features["query"]
        sequence = features["sequence"]
        lengths = features["sequence_length"]

        seq_len = sequence.size(1)
        # A row of positions 0..seq_len-1 compared against a column of lengths;
        # assumes sequence_length is 1-D (one length per row) -- TODO confirm.
        positions = _arange(seq_len, device=lengths.device).unsqueeze(0)
        sequence_mask = positions < lengths.unsqueeze(1)  # not consumed yet

        queries = query.unsqueeze(1).expand(-1, seq_len, -1)
        # NOTE: the attention input is not assembled yet; the planned form is
        # torch.cat([queries, sequence, queries - sequence, queries * sequence], dim=-1)
        return queries
# Repro driver: build the model, export it with ONE positional argument that is
# a list of tensors, then compile the exported program with Torch-TensorRT.
model = MatMul2().eval().cuda()
# Inputs in the order of MatMul2.keys: query, sequence, sequence_length.
a1=torch.randn(2, 41).cuda()
b1=torch.randn(2, 50,41).cuda()
# NOTE(review): sequence_length is compared against integer positions from
# _arange in predict(); random floats suffice for the repro, but real lengths
# are presumably integers in [0, 50] -- confirm.
c1=torch.randn(2).cuda()
inputs=[a1,b1,c1]
# The whole list is passed as a single positional argument, matching
# MatMul2.forward(self, args1: List[torch.Tensor]).
exp_program = torch.export.export(model, (inputs,))
# ERROR: raises "Expected input at *args[0][0] to be a tensor, but got
# <class 'torch_tensorrt._Input.Input'>" (see the traceback). `[inputs]` nests
# the tensors one level deep, and those nested tensors reach the exported
# program as torch_tensorrt Input objects instead of tensors.
# NOTE(review): presumably compile() does not unwrap nested (list-of-list)
# inputs back into tensors before its sample run -- confirm against the
# Torch-TensorRT version in use. A possible workaround is to flatten: have
# forward take the three tensors as separate positional arguments, export with
# (a1, b1, c1), and compile with `inputs` instead of `[inputs]`.
trt_gm = torch_tensorrt.dynamo.compile(exp_program, [inputs],min_block_size=2)
# Run inference. forward takes a single list argument, so the call is
# presumably trt_gm(inputs) rather than trt_gm(*inputs):
# print(trt_gm(inputs))
ERROR
Traceback (most recent call last):
File "/larec/tzrec/tests/test_2.py", line 64, in <module>
trt_gm = torch_tensorrt.dynamo.compile(exp_program, [inputs],min_block_size=2)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/dynamo/_compiler.py", line 230, in compile
trt_gm = compile_module(gm, inputs, settings)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/dynamo/_compiler.py", line 427, in compile_module
sample_outputs = gm(
^^^
File "/opt/conda/lib/python3.11/site-packages/torch/fx/graph_module.py", line 738, in call_wrapped
return self._wrapped_call(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/fx/graph_module.py", line 316, in __call__
raise e
File "/opt/conda/lib/python3.11/site-packages/torch/fx/graph_module.py", line 303, in __call__
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1582, in _call_impl
args_kwargs_result = hook(self, args, kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/export/_unlift.py", line 33, in _check_input_constraints_pre_hook
return _check_input_constraints_for_graph(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/_export/utils.py", line 86, in _check_input_constraints_for_graph
raise RuntimeError(
RuntimeError: Expected input at *args[0][0] to be a tensor, but got <class 'torch_tensorrt._Input.Input'>
Bug Description
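In real-world scenarios, user features change constantly, so I must pass a list of tensors as the input to the `forward` function. However, when I use a list input, `torch_tensorrt.dynamo.compile(exp_program, [inputs], min_block_size=2)` raises an error: the traceback shows that the tensors inside the nested list reach the exported program as `torch_tensorrt.Input` objects instead of tensors.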
To Reproduce
ERROR
Environment