Open yjjinjie opened 2 months ago
@narendasan can you help me solve this problem? I want to make the batch size and seq_len dimensions dynamic.
@narendasan when will torch_executed_modules be supported in dynamo mode?
Hi @yjjinjie, you can set the dynamic shapes and pass in the dynamic inputs using torch_tensorrt.Input, something like:
compile_spec = {
    "inputs": [
        torch_tensorrt.Input(
            min_shape=(1, 3, 224, 224),
            opt_shape=(8, 3, 224, 224),
            max_shape=(16, 3, 224, 224),
            dtype=torch.half,
        )
    ],
    "enabled_precisions": enabled_precisions,
    "ir": "dynamo",
}
trt_model = torch_tensorrt.compile(model, **compile_spec)
where model is the torch module you want to compile. You can refer to the example: https://github.com/pytorch/TensorRT/blob/main/examples/dynamo/torch_compile_resnet_example.py
Since you want to set batch_size and seq_len as dynamic, you need to pass their ranges, e.g.:
torch_tensorrt.Input(
    min_shape=(1, 1, 224, 224),
    opt_shape=(8, 2, 224, 224),
    max_shape=(16, 3, 224, 224),
    dtype=torch.half,
)
where the first two dimensions, (1, 8, 16) and (1, 2, 3), give the min/opt/max values of batch_size and seq_len respectively. Can you try with this and see if you get the same error as above?
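To make the min/opt/max convention concrete, here is a minimal plain-Python sketch (no torch_tensorrt needed) that checks a shape range and reports which dimensions end up dynamic. The helper name `dynamic_dims` is hypothetical and not part of the torch_tensorrt API; it only illustrates the rule that a dimension is dynamic when its min and max differ.

```python
from typing import List, Sequence


def dynamic_dims(
    min_shape: Sequence[int],
    opt_shape: Sequence[int],
    max_shape: Sequence[int],
) -> List[int]:
    """Return the indices of dimensions whose min and max sizes differ.

    Hypothetical helper for illustration only; not a torch_tensorrt function.
    """
    if not (len(min_shape) == len(opt_shape) == len(max_shape)):
        raise ValueError("min/opt/max shapes must have the same rank")
    dims = []
    for i, (lo, opt, hi) in enumerate(zip(min_shape, opt_shape, max_shape)):
        # Each dimension must satisfy min <= opt <= max.
        if not lo <= opt <= hi:
            raise ValueError(
                f"dim {i}: expected min <= opt <= max, got {lo}, {opt}, {hi}"
            )
        if lo != hi:
            dims.append(i)
    return dims


# The range from the comment above: batch_size and seq_len are dynamic.
print(dynamic_dims((1, 1, 224, 224), (8, 2, 224, 224), (16, 3, 224, 224)))
```

With the earlier range, where only the first dimension varies, the same helper would report only index 0.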
Yes, I have tried torch_tensorrt.Input, but I ran into a new bug:
import torch
import torch_tensorrt
from typing import Optional, Sequence,Dict,List
from torch.nn import functional as F
from tzrec.modules.mlp import MLP
from torch import nn
@torch.fx.wrap
def _get_dict(grouped_features_keys: List[str], args:List[torch.Tensor])->Dict[str, torch.Tensor]:
    if len(grouped_features_keys) != len(args):
        raise ValueError(
            "The number of grouped_features_keys must match "
            "the number of arguments."
        )
    grouped_features = {
        key: value for key, value in zip(grouped_features_keys, args)
    }
    return grouped_features
@torch.fx.wrap
def _arange(end: int, device: torch.device) -> torch.Tensor:
    return torch.arange(end, device=device)
class MatMul(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.keys = ["query","sequence","sequence_length"]
        attn_mlp= {'hidden_units': [256, 64], 'dropout_ratio': [], 'activation': 'nn.ReLU', 'use_bn': False}
        self.mlp = MLP(in_features=41 * 4, **attn_mlp)
        self.linear = nn.Linear(self.mlp.hidden_units[-1], 1)
    def forward(self, *args1: List[torch.Tensor]):
        """Forward the module."""
        # use predict to avoid trace error in self._output_to_prediction(y)
        return self.predict(args1)
    def predict(self, args: List[torch.Tensor]):
        grouped_features= _get_dict(self.keys, args)
        query = grouped_features["query"]
        sequence = grouped_features["sequence"]
        sequence_length = grouped_features["sequence_length"]
        max_seq_length = sequence.size(1)
        sequence_mask = _arange(
            max_seq_length, device=sequence_length.device
        ).unsqueeze(0) < sequence_length.unsqueeze(1)
        queries = query.unsqueeze(1).expand(-1, max_seq_length, -1)
        attn_input = torch.cat(
            [queries, sequence, queries - sequence, queries * sequence], dim=-1
        )
        return attn_input
model = MatMul().eval().cuda()
a=torch.randn(2, 41).cuda()
b=torch.randn(2, 2,41).cuda()
c=torch.randn(2).cuda()
d=torch.tensor(2)
# torch._dynamo.mark_dynamic(a, 0, min=1, max=8196)
# torch._dynamo.mark_dynamic(b, 0, min=1, max=8196)
# torch._dynamo.mark_dynamic(b, 1, min=1, max=50)
# torch._dynamo.mark_dynamic(c, 0, min=1, max=8196)
# torch._dynamo.mark_dynamic(d, 0, min=1, max=8196)
inputs = [a, b,c]
print(model(*inputs)[0][0][0])
# seq_len = torch.export.Dim("seq_len", min=1, max=10)
# dynamic_shapes=({2: seq_len}, {2: seq_len})
# Export the model first with custom dynamic shape constraints
from torchrec.fx import symbolic_trace
model = symbolic_trace(model)
inputs_dy = []
inputs_dy.append(
torch_tensorrt.Input(
min_shape=[1, 41],
opt_shape=[512, 41],
max_shape=[8196, 41],
name="query",
)
)
inputs_dy.append(
torch_tensorrt.Input(
min_shape=[1, 1,41],
opt_shape=[512, 2, 41],
max_shape=[8196,50, 41],
name="sequence",
)
)
inputs_dy.append(
torch_tensorrt.Input(
min_shape=[1],
opt_shape=[512],
max_shape=[8196],
name="sequence_length",
)
)
trt_gm = torch_tensorrt.compile(
model,
ir="dynamo",
inputs=[inputs_dy],min_block_size=1,
torch_executed_ops=["aten.expand"],)
print(trt_gm)
# exp_program = torch.export.export(model, (*inputs,))
# trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs,assume_dynamic_shape_support=True,
# allow_shape_tensors=True,min_block_size=2)
# Run inference
print(trt_gm(*inputs)[0][0][0])
# trt_gm = symbolic_trace(trt_gm)
trt_gm = torch.jit.trace(trt_gm,
example_inputs=(a,b,c),
strict=False)
scripted_model = torch.jit.script(trt_gm)
scripted_model.save("./scripted_model_trt.pt")
model_gpu = torch.jit.load(
"./scripted_model_trt.pt", map_location="cuda:0"
)
from torch.profiler import ProfilerActivity, profile, record_function
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
) as prof:
    with record_function("model_inference"):
        model_gpu(*inputs)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=100))
print("load:",model_gpu(*inputs)[0][0][0])
the error is:
Traceback (most recent call last):
File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 805, in step
self.dispatch_table[inst.opcode](self, inst)
File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1266, in RAISE_VARARGS
raise exc.ObservedException(f"raised exception {val}")
torch._dynamo.exc.ObservedException: raised exception ExceptionVariable()
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 805, in step
self.dispatch_table[inst.opcode](self, inst)
File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper
return inner_fn(self, inst)
^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2059, in CALL
self.call_function(fn, args, kwargs)
File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function
self.push(fn.call_function(self, args, kwargs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 293, in call_function
return super().call_function(tx, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 90, in call_function
return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 749, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2666, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2782, in inline_call_
tracer.run()
File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 893, in run
while self.step():
^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 808, in step
self.exception_handler()
File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1304, in exception_handler
raise exc.ObservedException
torch._dynamo.exc.ObservedException
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/larec/tzrec/tests/test_dy2.py", line 103, in <module>
trt_gm = torch_tensorrt.compile(
^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/_compile.py", line 248, in compile
exp_program = dynamo_trace(module, torchtrt_inputs, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/dynamo/_tracer.py", line 81, in trace
exp_program = export(mod, tuple(torch_inputs), dynamic_shapes=tuple(dynamic_shapes))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/export/__init__.py", line 174, in export
return _export(
^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/export/_trace.py", line 945, in wrapper
raise e
File "/opt/conda/lib/python3.11/site-packages/torch/export/_trace.py", line 928, in wrapper
ep = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/export/exported_program.py", line 89, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/export/_trace.py", line 1455, in _export
aten_export_artifact = export_func(
^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/export/_trace.py", line 1060, in _strict_export
gm_torch_level = _export_to_torch_ir(
^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/export/_trace.py", line 512, in _export_to_torch_ir
gm_torch_level, _ = torch._dynamo.export(
^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 1379, in inner
result_traced = opt_f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 433, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/fx/graph_module.py", line 738, in call_wrapped
return self._wrapped_call(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/fx/graph_module.py", line 316, in __call__
raise e
File "/opt/conda/lib/python3.11/site-packages/torch/fx/graph_module.py", line 303, in __call__
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1116, in __call__
return self._torchdynamo_orig_callable(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 472, in __call__
return _compile(
^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/_utils_internal.py", line 84, in wrapper_function
return StrobelightCompileTimeProfiler.profile_compile_time(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 817, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
r = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 636, in compile_inner
out_code = transform_code_object(code, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1185, in transform_code_object
transformations(instructions, code_options)
File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 178, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 582, in transform
tracer.run()
File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2451, in run
super().run()
File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 893, in run
while self.step():
^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 808, in step
self.exception_handler()
File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1303, in exception_handler
raise Unsupported("Observed exception")
torch._dynamo.exc.Unsupported: Observed exception
from user code:
File "<eval_with_key>.0 from /larec/tzrec/tests/test_dy2.py:33 in forward", line 7, in forward
_get_dict = __main____get_dict(['query', 'sequence', 'sequence_length'], _args1); _args1 = None
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
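One possible reading of this traceback, not confirmed anywhere in this thread: the ValueError raised inside _get_dict is what Dynamo reports as "Observed exception". That would happen if the variadic forward(*args1) received a single nested list (the script passes inputs=[inputs_dy]) instead of three separate inputs. The plain-Python sketch below only illustrates that length mismatch; `get_dict` and `forward` are simplified stand-ins, and the strings stand in for the torch_tensorrt.Input objects.

```python
from typing import Dict, List, Sequence

KEYS = ["query", "sequence", "sequence_length"]


def get_dict(keys: List[str], args: Sequence[object]) -> Dict[str, object]:
    """Simplified stand-in for _get_dict: pair each key with one argument."""
    if len(keys) != len(args):
        raise ValueError(
            "The number of grouped_features_keys must match "
            "the number of arguments."
        )
    return dict(zip(keys, args))


def forward(*args1: object) -> Dict[str, object]:
    """Simplified stand-in for the variadic MatMul.forward."""
    return get_dict(KEYS, args1)


# Placeholder strings standing in for the three input specs (assumption).
specs = ["query_spec", "sequence_spec", "sequence_length_spec"]

# Flat call: three positional arguments, so the lengths match.
ok = forward(*specs)

# Nested call: the whole list arrives as ONE argument, so len(args1) == 1.
try:
    forward(specs)
    nested_error = None
except ValueError as exc:
    nested_error = str(exc)

print(sorted(ok))
print(nested_error)
```

If that is indeed the cause, passing the three specs as a flat list (inputs=inputs_dy) might avoid the mismatch; this is an assumption to verify, not a confirmed fix.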
I also tried the dynamic_shapes approach from https://pytorch.org/TensorRT/user_guide/dynamic_shapes.html:
import torch
import torch_tensorrt
from typing import Optional, Sequence,Dict,List
from torch.nn import functional as F
from tzrec.modules.mlp import MLP
from torch import nn
@torch.fx.wrap
def _get_dict(grouped_features_keys: List[str], args:List[torch.Tensor])->Dict[str, torch.Tensor]:
    if len(grouped_features_keys) != len(args):
        raise ValueError(
            "The number of grouped_features_keys must match "
            "the number of arguments."
        )
    grouped_features = {
        key: value for key, value in zip(grouped_features_keys, args)
    }
    return grouped_features
@torch.fx.wrap
def _arange(end: int, device: torch.device) -> torch.Tensor:
    return torch.arange(end, device=device)
class MatMul(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.keys = ["query","sequence","sequence_length"]
        attn_mlp= {'hidden_units': [256, 64], 'dropout_ratio': [], 'activation': 'nn.ReLU', 'use_bn': False}
        self.mlp = MLP(in_features=41 * 4, **attn_mlp)
        self.linear = nn.Linear(self.mlp.hidden_units[-1], 1)
    def forward(self, *args1: List[torch.Tensor]):
        """Forward the module."""
        # use predict to avoid trace error in self._output_to_prediction(y)
        return self.predict(args1)
    def predict(self, args: List[torch.Tensor]):
        grouped_features= _get_dict(self.keys, args)
        query = grouped_features["query"]
        sequence = grouped_features["sequence"]
        sequence_length = grouped_features["sequence_length"]
        max_seq_length = sequence.size(1)
        sequence_mask = _arange(
            max_seq_length, device=sequence_length.device
        ).unsqueeze(0) < sequence_length.unsqueeze(1)
        queries = query.unsqueeze(1).expand(-1, max_seq_length, -1)
        attn_input = torch.cat(
            [queries, sequence, queries - sequence, queries * sequence], dim=-1
        )
        return attn_input
model = MatMul().eval().cuda()
a=torch.randn(2, 41).cuda()
b=torch.randn(2, 2,41).cuda()
c=torch.randn(2).cuda()
# torch._dynamo.mark_dynamic(a, 0,min=1,max=8196)
# torch._dynamo.mark_dynamic(b, 0,min=1,max=8196)
# # torch._dynamo.mark_dynamic(b, 1, min=1, max=50)
# torch._dynamo.mark_dynamic(c, 0,min=1,max=8196)
inputs = [a, b,c]
print(model(*inputs)[0][0][0])
batch = torch.export.Dim("batch",min=1,max=8196)
seq_len = torch.export.Dim("seq_len",min=1,max=50)
dynamic_shapes={"args1": ({0:batch},{0:batch,1:seq_len},{0:batch})}
# Export the model first with custom dynamic shape constraints
from torchrec.fx import symbolic_trace
model = symbolic_trace(model)
print(model.code)
exp_program = torch.export.export(model, (*inputs,),dynamic_shapes=dynamic_shapes)
trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs, assume_dynamic_shape_support=True,
allow_shape_tensors=True,min_block_size=2)
It fails with the same problem as when using torch._dynamo.mark_dynamic(a, 0,min=1,max=8196).
@apbose can you help me?
Yeah sure, let me take a look and get back on this.
Hi @yjjinjie, may I know where I can find tzrec? I get a module-not-found error for tzrec.
@apbose you can just delete the tzrec and MLP code, like this:
import torch
import torch_tensorrt
from typing import Optional, Sequence,Dict,List
from torch.nn import functional as F
from torch import nn
@torch.fx.wrap
def _get_dict(grouped_features_keys: List[str], args:List[torch.Tensor])->Dict[str, torch.Tensor]:
    if len(grouped_features_keys) != len(args):
        raise ValueError(
            "The number of grouped_features_keys must match "
            "the number of arguments."
        )
    grouped_features = {
        key: value for key, value in zip(grouped_features_keys, args)
    }
    return grouped_features
@torch.fx.wrap
def _arange(end: int, device: torch.device) -> torch.Tensor:
    return torch.arange(end, device=device)
class MatMul(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.keys = ["query","sequence","sequence_length"]
    def forward(self, *args1: List[torch.Tensor]):
        """Forward the module."""
        # use predict to avoid trace error in self._output_to_prediction(y)
        return self.predict(args1)
    def predict(self, args: List[torch.Tensor]):
        grouped_features= _get_dict(self.keys, args)
        query = grouped_features["query"]
        sequence = grouped_features["sequence"]
        sequence_length = grouped_features["sequence_length"]
        max_seq_length = sequence.size(1)
        sequence_mask = _arange(
            max_seq_length, device=sequence_length.device
        ).unsqueeze(0) < sequence_length.unsqueeze(1)
        queries = query.unsqueeze(1).expand(-1, max_seq_length, -1)
        attn_input = torch.cat(
            [queries, sequence, queries - sequence, queries * sequence], dim=-1
        )
        return attn_input
model = MatMul().eval().cuda()
a=torch.randn(2, 41).cuda()
b=torch.randn(2, 2,41).cuda()
c=torch.randn(2).cuda()
d=torch.tensor(2)
torch._dynamo.mark_dynamic(a, 0, min=1, max=8196)
torch._dynamo.mark_dynamic(b, 0, min=1, max=8196)
torch._dynamo.mark_dynamic(b, 1, min=1, max=50)
torch._dynamo.mark_dynamic(c, 0, min=1, max=8196)
# torch._dynamo.mark_dynamic(d, 0, min=1, max=8196)
inputs = [a, b,c]
print(model(*inputs)[0][0][0])
# seq_len = torch.export.Dim("seq_len", min=1, max=10)
# dynamic_shapes=({2: seq_len}, {2: seq_len})
# Export the model first with custom dynamic shape constraints
from torchrec.fx import symbolic_trace
model = symbolic_trace(model)
inputs_dy = []
inputs_dy.append(
torch_tensorrt.Input(
min_shape=[1, 41],
opt_shape=[512, 41],
max_shape=[8196, 41],
name="query",
)
)
inputs_dy.append(
torch_tensorrt.Input(
min_shape=[1, 1,41],
opt_shape=[512, 2, 41],
max_shape=[8196,50, 41],
name="sequence",
)
)
inputs_dy.append(
torch_tensorrt.Input(
min_shape=[1],
opt_shape=[512],
max_shape=[8196],
name="sequence_length",
)
)
trt_gm = torch_tensorrt.compile(
model,
ir="dynamo",
inputs=[*inputs],min_block_size=1,
torch_executed_ops=["aten.expand"],)
print(trt_gm)
# exp_program = torch.export.export(model, (*inputs,))
# trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs,assume_dynamic_shape_support=True,
# allow_shape_tensors=True,min_block_size=2)
# Run inference
print(trt_gm(*inputs)[0][0][0])
# trt_gm = symbolic_trace(trt_gm)
trt_gm = torch.jit.trace(trt_gm,
example_inputs=(a,b,c),
strict=False)
scripted_model = torch.jit.script(trt_gm)
scripted_model.save("./scripted_model_trt.pt")
model_gpu = torch.jit.load(
"./scripted_model_trt.pt", map_location="cuda:0"
)
from torch.profiler import ProfilerActivity, profile, record_function
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
) as prof:
    with record_function("model_inference"):
        model_gpu(*inputs)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=100))
print("load:",model_gpu(*inputs)[0][0][0])
I do not get the above error when I run the above code. Are you running on the latest branch? I did make a few modifications to the code, though:
import torch
import torch_tensorrt
from typing import Optional, Sequence,Dict,List
from torch.nn import functional as F
from torch import nn
@torch.fx.wrap
def _get_dict(grouped_features_keys: List[str], args:List[torch.Tensor])->Dict[str, torch.Tensor]:
    if len(grouped_features_keys) != len(args):
        raise ValueError(
            "The number of grouped_features_keys must match "
            "the number of arguments."
        )
    grouped_features = {
        key: value for key, value in zip(grouped_features_keys, args)
    }
    return grouped_features
@torch.fx.wrap
def _arange(end: int, device: torch.device) -> torch.Tensor:
    return torch.arange(end, device=device)
class MatMul(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.keys = ["query","sequence","sequence_length"]
    #def forward(self, *args1: List[torch.Tensor]):
    def forward(self, args0, args1, args2):
        """Forward the module."""
        # use predict to avoid trace error in self._output_to_prediction(y)
        #return self.predict(args1)
        return self.predict(args0, args1, args2)
    #def predict(self, args: List[torch.Tensor]):
    def predict(self, args0, args1, args2):
        #grouped_features= _get_dict(self.keys, args)
        #query = grouped_features["query"]
        #sequence = grouped_features["sequence"]
        #sequence_length = grouped_features["sequence_length"]
        query = args0
        sequence = args1
        sequence_length = args2
        max_seq_length = sequence.size(1)
        sequence_mask = _arange(
            max_seq_length, device=sequence_length.device
        ).unsqueeze(0) < sequence_length.unsqueeze(1)
        queries = query.unsqueeze(1).expand(-1, max_seq_length, -1)
        attn_input = torch.cat(
            [queries, sequence, queries - sequence, queries * sequence], dim=-1
        )
        return attn_input
model = MatMul().eval().cuda()
a=torch.randn(2, 41).cuda()
b=torch.randn(2, 2,41).cuda()
c=torch.randn(2).cuda()
d=torch.tensor(2)
torch._dynamo.mark_dynamic(a, 0, min=1, max=8196)
torch._dynamo.mark_dynamic(b, 0, min=1, max=8196)
torch._dynamo.mark_dynamic(b, 1, min=1, max=50)
torch._dynamo.mark_dynamic(c, 0, min=1, max=8196)
# torch._dynamo.mark_dynamic(d, 0, min=1, max=8196)
inputs = [a, b,c]
print(model(*inputs)[0][0][0])
# seq_len = torch.export.Dim("seq_len", min=1, max=10)
# dynamic_shapes=({2: seq_len}, {2: seq_len})
# Export the model first with custom dynamic shape constraints
from torch.fx import symbolic_trace
model = symbolic_trace(model)
inputs_dy = []
compile_spec = {
"inputs": [
torch_tensorrt.Input(
min_shape=(1, 41),
opt_shape=(512, 41),
max_shape=(8196, 41),
),
torch_tensorrt.Input(
min_shape=(1, 1,41),
opt_shape=(512, 2, 41),
max_shape=(8196,50, 41),
),
torch_tensorrt.Input(
min_shape=(1,),
opt_shape=(512,),
max_shape=(8196,),
)
],
"enabled_precisions": {torch.half},
"ir": "dynamo",
}
inputs_dy.append(
torch_tensorrt.Input(
min_shape=(1, 41),
opt_shape=(512, 41),
max_shape=(8196, 41),
name="query",
)
)
inputs_dy.append(
torch_tensorrt.Input(
min_shape=(1, 1,41),
opt_shape=(512, 2, 41),
max_shape=(8196,50, 41),
name="sequence",
)
)
inputs_dy.append(
torch_tensorrt.Input(
min_shape=(1,),
opt_shape=(512,),
max_shape=(8196,),
name="sequence_length",
)
)
print("the inputs_dy is!!!", inputs_dy)
print("the star inputs_dy", *inputs_dy)
trt_gm = torch_tensorrt.compile(
model,
**compile_spec, min_block_size=1,
torch_executed_ops=["aten.expand"],
cache_built_engines = False,
reuse_cached_engines = False)
print(trt_gm)
# exp_program = torch.export.export(model, (*inputs,))
# trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs,assume_dynamic_shape_support=True,
# allow_shape_tensors=True,min_block_size=2)
# Run inference
print(trt_gm(*inputs)[0][0][0])
# trt_gm = symbolic_trace(trt_gm)
trt_gm = torch.jit.trace(trt_gm,
example_inputs=(a,b,c),
strict=False)
scripted_model = torch.jit.script(trt_gm)
scripted_model.save("./scripted_model_trt.pt")
model_gpu = torch.jit.load(
"./scripted_model_trt.pt", map_location="cuda:0"
)
from torch.profiler import ProfilerActivity, profile, record_function
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
) as prof:
    with record_function("model_inference"):
        model_gpu(*inputs)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=100))
print("load:",model_gpu(*inputs)[0][0][0])
@apbose I am using torch_tensorrt 2.4.0, and with your code I still get the same error. Which torch_tensorrt version are you using?
My env is:
CPU(s): 104
On-line CPU(s) list: 0-103
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Platinum 8269CY CPU @ 2.50GHz
CPU family: 6
Model: 85
Thread(s) per core: 2
Core(s) per socket: 26
Socket(s): 2
Stepping: 7
CPU max MHz: 3800.0000
CPU min MHz: 1200.0000
BogoMIPS: 5000.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single intel_ppin ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb intel_pt avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts pku ospke avx512_vnni md_clear flush_l1d arch_capabilities
Virtualization: VT-x
L1d cache: 1.6 MiB (52 instances)
L1i cache: 1.6 MiB (52 instances)
L2 cache: 52 MiB (52 instances)
L3 cache: 71.5 MiB (2 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-103
Vulnerability Itlb multihit: KVM: Mitigation: Split huge pages
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling
Vulnerability Tsx async abort: Mitigation; TSX disabled
Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] optree==0.12.1
[pip3] torch==2.4.0
[pip3] torch_tensorrt==2.4.0
[pip3] torchaudio==2.4.0
[pip3] torchelastic==0.2.2
[pip3] torchmetrics==1.0.3
[pip3] torchrec==0.8.0+cu121
[pip3] torchvision==0.19.0
[pip3] triton==3.0.0
[conda] blas 1.0 mkl
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] libjpeg-turbo 2.0.0 h9bf148f_0 pytorch
[conda] mkl 2023.1.0 h213fc3f_46344
[conda] mkl-service 2.4.0 py311h5eee18b_1
[conda] mkl_fft 1.3.8 py311h5eee18b_0
[conda] mkl_random 1.2.4 py311hdb19cb5_0
[conda] numpy 1.26.4 py311h08b1b3b_0
[conda] numpy-base 1.26.4 py311hf175353_0
[conda] optree 0.12.1 pypi_0 pypi
[conda] pytorch 2.4.0 py3.11_cuda12.1_cudnn9.1.0_0 pytorch
[conda] pytorch-cuda 12.1 ha16c6d3_5 pytorch
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] torch-tensorrt 2.4.0 pypi_0 pypi
[conda] torchaudio 2.4.0 py311_cu121 pytorch
[conda] torchelastic 0.2.2 pypi_0 pypi
[conda] torchmetrics 1.0.3 pypi_0 pypi
[conda] torchrec 0.8.0+cu121 pypi_0 pypi
[conda] torchtriton 3.0.0 py311 pytorch
[conda] torchvision 0.19.0 py311_cu121 pytorch
@apbose I used pip install --pre torch-tensorrt --index-url https://download.pytorch.org/whl/nightly/cu124 to install torch_tensorrt 2.5.0.dev20240822+cu124.
With that version your code runs correctly. When will 2.5.0 be released?
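Because the behaviour reported here differs between 2.4.0 and the 2.5.0.dev nightly, it can help to check the installed version before running a reproduction. Below is a minimal standard-library sketch. The helper names are hypothetical, and the (2, 5) threshold is taken only from what was observed in this thread; it is not an official compatibility statement.

```python
import re
from importlib import metadata
from typing import Optional, Tuple


def release_tuple(version: str) -> Tuple[int, ...]:
    """Extract the leading numeric release from a version string.

    For example, '2.5.0.dev20240822+cu124' yields (2, 5, 0).
    """
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"cannot parse version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def installed_release(package: str) -> Optional[Tuple[int, ...]]:
    """Return the installed release tuple, or None if the package is absent."""
    try:
        return release_tuple(metadata.version(package))
    except metadata.PackageNotFoundError:
        return None


# Threshold assumed from this thread: 2.4.0 failed, the 2.5.0.dev nightly worked.
ASSUMED_MIN_RELEASE = (2, 5)

release = installed_release("torch_tensorrt")
if release is None:
    print("torch_tensorrt is not installed")
elif release < ASSUMED_MIN_RELEASE:
    print(f"torch_tensorrt {release} predates the nightly that worked in this thread")
else:
    print(f"torch_tensorrt {release} is at or above the assumed threshold")
```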
I cannot install the torch nightly with pip install https://download.pytorch.org/whl/nightly/cu124/torch-2.6.0.dev20241013%2Bcu124-cp311-cp311-linux_x86_64.whl, because of the error:
ERROR: Could not find a version that satisfies the requirement pytorch-triton==3.1.0+cf34004b8a; platform_system == "Linux" and platform_machine == "x86_64" and python_version < "3.13" (from torch) (from versions: 0.0.1)
ERROR: No matching distribution found for pytorch-triton==3.1.0+cf34004b8a; platform_system == "Linux" and platform_machine == "x86_64" and python_version < "3.13"
@apbose in my real code, there is another error:
With torch_tensorrt 2.5.0.dev20240822+cu124, the dynamic-shape case fails when running padding = torch.ones_like(attn_output) * (-(2**32) + 1). The static-shape case also fails, when running scores = torch.where(sequence_mask.unsqueeze(1), attn_output, padding).
With torch_tensorrt 2.4.0, the dynamic-shape case fails with the original error, and the static-shape case works correctly.
The error in the dynamic-shape case is:
Traceback (most recent call last):
File "/opt/conda/lib/python3.11/site-packages/torch/fx/graph_module.py", line 303, in __call__
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "<eval_with_key>.41", line 37, in forward
full_default = torch.ops.aten.full.default([s0, 1, s2], 1, pin_memory = False, device = device(type='cuda', index=0), dtype = torch.float32)
^^
NameError: name 's0' is not defined
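For readers following along, the two failing steps named above (the padding tensor and the torch.where call) together implement a masked softmax. Positions at or beyond each row's sequence length are replaced by a very large negative constant, so they get roughly zero weight after softmax. The plain-Python sketch below shows that computation for a single row. `masked_softmax` is a hypothetical helper, not code from the model, and it says nothing about why the TensorRT compilation fails.

```python
import math
from typing import List

# Same constant the model uses to build the padding tensor.
PAD_VALUE = float(-(2 ** 32) + 1)


def masked_softmax(scores: List[float], length: int) -> List[float]:
    """Softmax over one row, ignoring positions at index >= length.

    Mirrors torch.where(mask, scores, padding) followed by softmax.
    """
    # Keep valid positions, replace masked positions with the padding value.
    masked = [s if i < length else PAD_VALUE for i, s in enumerate(scores)]
    # Subtract the row maximum for numerical stability.
    peak = max(masked)
    exps = [math.exp(m - peak) for m in masked]
    total = sum(exps)
    return [e / total for e in exps]


# Three positions, but only the first two are valid for this row.
weights = masked_softmax([0.5, 1.5, 3.0], length=2)
print(weights)
```

The masked third position ends up with zero weight even though its raw score is the largest, which is the intent of the padding constant.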
The code is:
import torch
import torch_tensorrt
from typing import Optional, Sequence,Dict,List
from torch.nn import functional as F
from torch import nn
@torch.fx.wrap
def _get_dict(grouped_features_keys: List[str], args:List[torch.Tensor])->Dict[str, torch.Tensor]:
    if len(grouped_features_keys) != len(args):
        raise ValueError(
            "The number of grouped_features_keys must match "
            "the number of arguments."
        )
    grouped_features = {
        key: value for key, value in zip(grouped_features_keys, args)
    }
    return grouped_features
@torch.fx.wrap
def _arange(end: int, device: torch.device) -> torch.Tensor:
    return torch.arange(end, device=device)
class MatMul(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.keys = ["query","sequence","sequence_length"]
        attn_mlp= {'hidden_units': [256, 64], 'dropout_ratio': [], 'activation': 'nn.ReLU', 'use_bn': False}
        self.linear1 = nn.Linear(41*4, 256)
        self.linear2 = nn.Linear(256, 64)
        self.linear = nn.Linear(64, 1)
    #def forward(self, *args1: List[torch.Tensor]):
    def forward(self, args0, args1, args2):
        """Forward the module."""
        # use predict to avoid trace error in self._output_to_prediction(y)
        #return self.predict(args1)
        return self.predict(args0, args1, args2)
    #def predict(self, args: List[torch.Tensor]):
    def predict(self, args0, args1, args2):
        #grouped_features= _get_dict(self.keys, args)
        #query = grouped_features["query"]
        #sequence = grouped_features["sequence"]
        #sequence_length = grouped_features["sequence_length"]
        query = args0
        sequence = args1
        sequence_length = args2
        max_seq_length = sequence.size(1)
        sequence_mask = _arange(
            max_seq_length, device=sequence_length.device
        ).unsqueeze(0) < sequence_length.unsqueeze(1)
        queries = query.unsqueeze(1).expand(-1, max_seq_length, -1)
        attn_input = torch.cat(
            [queries, sequence, queries - sequence, queries * sequence], dim=-1
        )
        #attn_output = self.mlp(attn_input)
        attn_output = self.linear1(attn_input)
        attn_output = self.linear2(attn_output)
        print(attn_output.shape)
        attn_output = self.linear(attn_output)
        attn_output = attn_output.transpose(1, 2)
        padding = torch.ones_like(attn_output) * (-(2**32) + 1)
        scores = torch.where(sequence_mask.unsqueeze(1), attn_output, padding)
        scores = F.softmax(scores, dim=-1)
        return torch.matmul(scores, sequence).squeeze(1)
        #return padding
        #return attn_input
model = MatMul().eval().cuda()
a=torch.randn(2, 41).cuda()
b=torch.randn(2, 2,41).cuda()
c=torch.randn(2).cuda()
d=torch.tensor(2)
torch._dynamo.mark_dynamic(a, 0, min=1, max=8196)
torch._dynamo.mark_dynamic(b, 0, min=1, max=8196)
torch._dynamo.mark_dynamic(b, 1, min=1, max=50)
torch._dynamo.mark_dynamic(c, 0, min=1, max=8196)
# torch._dynamo.mark_dynamic(d, 0, min=1, max=8196)
inputs = [a, b,c]
print(model(*inputs))
# seq_len = torch.export.Dim("seq_len", min=1, max=10)
# dynamic_shapes=({2: seq_len}, {2: seq_len})
# Export the model first with custom dynamic shape constraints
from torch.fx import symbolic_trace
model = symbolic_trace(model)
inputs_dy = []
compile_spec = {
"inputs": [
torch_tensorrt.Input(
min_shape=(1, 41),
opt_shape=(512, 41),
max_shape=(8196, 41),
),
torch_tensorrt.Input(
min_shape=(1, 1,41),
opt_shape=(512, 2, 41),
max_shape=(8196,50, 41),
),
torch_tensorrt.Input(
min_shape=(1,),
opt_shape=(512,),
max_shape=(8196,),
)
],
"enabled_precisions": {torch.half},
"ir": "dynamo",
}
inputs_dy.append(
torch_tensorrt.Input(
min_shape=(1, 41),
opt_shape=(512, 41),
max_shape=(8196, 41),
name="query",
)
)
inputs_dy.append(
torch_tensorrt.Input(
min_shape=(1, 1,41),
opt_shape=(512, 2, 41),
max_shape=(8196,50, 41),
name="sequence",
)
)
inputs_dy.append(
torch_tensorrt.Input(
min_shape=(1,),
opt_shape=(512,),
max_shape=(8196,),
name="sequence_length",
)
)
print("the inputs_dy is!!!", inputs_dy)
print("the star inputs_dy", *inputs_dy)
with torch_tensorrt.logging.graphs():
trt_gm = torch_tensorrt.compile(
model,
**compile_spec, min_block_size=1,
cache_built_engines = False,
reuse_cached_engines = False)
print(trt_gm)
# exp_program = torch.export.export(model, (*inputs,))
# trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs,assume_dynamic_shape_support=True,
# allow_shape_tensors=True,min_block_size=2)
# Run inference
print(trt_gm(*inputs))
# trt_gm = symbolic_trace(trt_gm)
trt_gm = torch.jit.trace(trt_gm,
example_inputs=(a,b,c),
strict=False)
scripted_model = torch.jit.script(trt_gm)
scripted_model.save("./scripted_model_trt.pt")
model_gpu = torch.jit.load(
"./scripted_model_trt.pt", map_location="cuda:0"
)
from torch.profiler import ProfilerActivity, profile, record_function
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
record_shapes=True,
) as prof:
with record_function("model_inference"):
model_gpu(*inputs)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=100))
print("load:",model_gpu(*inputs))
can you help me solve this problem @apbose
When I use the nvcr.io/nvidia/pytorch:24.09-py3 image, the code works. That image contains:
torch 2.5.0a0+b465a5843b.nv24.9
torch_tensorrt 2.5.0a0
Which date of torch_tensorrt does 2.5.0a0 correspond to?
However, that Docker image's system is incompatible with my project. When will the new version 2.5.0 be released?
Hi @yjjinjie, you can find the release wheels here: https://download.pytorch.org/whl/test/torch-tensorrt/ (the torchTRT 2.5 release artifacts were officially pushed yesterday). If you want to work with the most recent torchTRT changes, which are on torchTRT 2.6 (you can find the current version here: https://github.com/pytorch/TensorRT/blob/main/version.txt), you can use the Docker image ghcr.io/pytorch/tensorrt/torch_tensorrt:nightly
@apbose Hello, when I install torch_tensorrt==2.5.0, I also get an error:
Traceback (most recent call last):
  File "/opt/conda/lib/python3.11/site-packages/torch/fx/graph_module.py", line 303, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<eval_with_key>.41", line 37, in forward
    full_default = torch.ops.aten.full.default([s0, 1, s2], 1, pin_memory = False, device = device(type='cuda', index=0), dtype = torch.float32)
                                                ^^
NameError: name 's0' is not defined
When I use nvcr.io/nvidia/pytorch:24.09-py3, the code works.
torch 2.5.0a0+b465a5843b.nv24.9, torch_tensorrt 2.5.0a0
Which date of torch_tensorrt does 2.5.0a0 correspond to? Can you update the 2.5.0 release? I want to install torch_tensorrt in my project.
Can you try a new virtual env and install the wheel torch_tensorrt-2.5.0+cu124-cp310-cp310-linux_x86_64.whl from https://download.pytorch.org/whl/test/torch-tensorrt/ ? This will give you torch-tensorrt 2.5 and torch 2.5. Then let me know what the error is.
@apbose I created a new virtual env and installed torch_tensorrt-2.5.0+cu124-cp310-cp310-linux_x86_64.whl. It has the same error.
I only ran:
conda create -n trt python=3.10
conda activate trt
pip install torch_tensorrt-2.5.0+cu124-cp310-cp310-linux_x86_64.whl
and run collect_env:
wget https://raw.githubusercontent.com/pytorch/pytorch/main/torch/utils/collect_env.py
python collect_env.py
the result:
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
CPU(s): 104
On-line CPU(s) list: 0-103
Thread(s) per core: 2
Core(s) per socket: 26
Socket(s): 2
NUMA node(s): 1
Vendor ID: GenuineIntel
CPU family: 6
Model: 85
Model name: Intel(R) Xeon(R) Platinum 8269CY CPU @ 2.50GHz
Stepping: 7
CPU MHz: 2500.019
CPU max MHz: 3800.0000
CPU min MHz: 1200.0000
BogoMIPS: 5000.00
Virtualization: VT-x
L1d cache: 32K
L1i cache: 32K
L2 cache: 1024K
L3 cache: 36608K
NUMA node0 CPU(s): 0-103
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single intel_ppin ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb intel_pt avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts pku ospke avx512_vnni md_clear flush_l1d arch_capabilities
Versions of relevant libraries:
[pip3] numpy==2.1.2
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] torch==2.5.0
[pip3] torch_tensorrt==2.5.0+cu124
[pip3] triton==3.1.0
[conda] numpy 2.1.2 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.4.5.8 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.6.77 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.1.0.70 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.2.1.3 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.5.147 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.6.1.9 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.3.1.170 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.21.5 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.4.127 pypi_0 pypi
[conda] torch 2.5.0 pypi_0 pypi
[conda] torch-tensorrt 2.5.0+cu124 pypi_0 pypi
[conda] triton 3.1.0 pypi_0 pypi
the code is:
import torch
import torch_tensorrt
from typing import Optional, Sequence, Dict, List
from torch import nn
from torch.fx import symbolic_trace
from torch.nn import functional as F
from torch.profiler import ProfilerActivity, profile, record_function
@torch.fx.wrap
def _get_dict(grouped_features_keys: List[str], args: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
    if len(grouped_features_keys) != len(args):
        raise ValueError(
            "The number of grouped_features_keys must match "
            "the number of arguments."
        )
    grouped_features = {
        key: value for key, value in zip(grouped_features_keys, args)
    }
    return grouped_features
@torch.fx.wrap
def _arange(end: int, device: torch.device) -> torch.Tensor:
    return torch.arange(end, device=device)
class MatMul(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.keys = ["query", "sequence", "sequence_length"]
        attn_mlp = {'hidden_units': [256, 64], 'dropout_ratio': [], 'activation': 'nn.ReLU', 'use_bn': False}
        self.linear1 = nn.Linear(41 * 4, 256)
        self.linear2 = nn.Linear(256, 64)
        self.linear = nn.Linear(64, 1)
    # def forward(self, *args1: List[torch.Tensor]):
    def forward(self, args0, args1, args2):
        """Forward the module."""
        # use predict to avoid trace error in self._output_to_prediction(y)
        # return self.predict(args1)
        return self.predict(args0, args1, args2)
    # def predict(self, args: List[torch.Tensor]):
    def predict(self, args0, args1, args2):
        # grouped_features = _get_dict(self.keys, args)
        # query = grouped_features["query"]
        # sequence = grouped_features["sequence"]
        # sequence_length = grouped_features["sequence_length"]
        query = args0
        sequence = args1
        sequence_length = args2
        max_seq_length = sequence.size(1)
        # Boolean mask of shape (batch, seq_len): True for valid positions
        sequence_mask = _arange(
            max_seq_length, device=sequence_length.device
        ).unsqueeze(0) < sequence_length.unsqueeze(1)
        queries = query.unsqueeze(1).expand(-1, max_seq_length, -1)
        attn_input = torch.cat(
            [queries, sequence, queries - sequence, queries * sequence], dim=-1
        )
        # attn_output = self.mlp(attn_input)
        attn_output = self.linear1(attn_input)
        attn_output = self.linear2(attn_output)
        print(attn_output.shape)
        attn_output = self.linear(attn_output)
        attn_output = attn_output.transpose(1, 2)
        # ones_like on a dynamically shaped tensor; this is the op named in the error below
        padding = torch.ones_like(attn_output) * (-(2**32) + 1)
        scores = torch.where(sequence_mask.unsqueeze(1), attn_output, padding)
        scores = F.softmax(scores, dim=-1)
        return torch.matmul(scores, sequence).squeeze(1)
        # return padding
        # return attn_output
model = MatMul().eval().cuda()
a = torch.randn(2, 41).cuda()
b = torch.randn(2, 2, 41).cuda()
c = torch.randn(2).cuda()
d = torch.tensor(2)
torch._dynamo.mark_dynamic(a, 0, min=1, max=8196)
torch._dynamo.mark_dynamic(b, 0, min=1, max=8196)
torch._dynamo.mark_dynamic(b, 1, min=1, max=50)
torch._dynamo.mark_dynamic(c, 0, min=1, max=8196)
# torch._dynamo.mark_dynamic(d, 0, min=1, max=8196)
inputs = [a, b, c]
print(model(*inputs))
# seq_len = torch.export.Dim("seq_len", min=1, max=10)
# dynamic_shapes=({2: seq_len}, {2: seq_len})
# Export the model first with custom dynamic shape constraints
model = symbolic_trace(model)
inputs_dy = []
compile_spec = {
    "inputs": [
        torch_tensorrt.Input(
            min_shape=(1, 41),
            opt_shape=(512, 41),
            max_shape=(8196, 41),
        ),
        torch_tensorrt.Input(
            min_shape=(1, 1, 41),
            opt_shape=(512, 2, 41),
            max_shape=(8196, 50, 41),
        ),
        torch_tensorrt.Input(
            min_shape=(1,),
            opt_shape=(512,),
            max_shape=(8196,),
        ),
    ],
    "enabled_precisions": {torch.half},
    "ir": "dynamo",
}
inputs_dy.append(
    torch_tensorrt.Input(
        min_shape=(1, 41),
        opt_shape=(512, 41),
        max_shape=(8196, 41),
        name="query",
    )
)
inputs_dy.append(
    torch_tensorrt.Input(
        min_shape=(1, 1, 41),
        opt_shape=(512, 2, 41),
        max_shape=(8196, 50, 41),
        name="sequence",
    )
)
inputs_dy.append(
    torch_tensorrt.Input(
        min_shape=(1,),
        opt_shape=(512,),
        max_shape=(8196,),
        name="sequence_length",
    )
)
print("the inputs_dy is!!!", inputs_dy)
print("the star inputs_dy", *inputs_dy)
with torch_tensorrt.logging.graphs():
    trt_gm = torch_tensorrt.compile(
        model,
        **compile_spec, min_block_size=1,
        cache_built_engines=False,
        reuse_cached_engines=False)
print(trt_gm)
# exp_program = torch.export.export(model, (*inputs,))
# trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs, assume_dynamic_shape_support=True,
#     allow_shape_tensors=True, min_block_size=2)
# Run inference
print(trt_gm(*inputs))
# trt_gm = symbolic_trace(trt_gm)
trt_gm = torch.jit.trace(trt_gm,
                         example_inputs=(a, b, c),
                         strict=False)
scripted_model = torch.jit.script(trt_gm)
scripted_model.save("./scripted_model_trt.pt")
model_gpu = torch.jit.load(
    "./scripted_model_trt.pt", map_location="cuda:0"
)
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
) as prof:
    with record_function("model_inference"):
        model_gpu(*inputs)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=100))
print("load:", model_gpu(*inputs))
the error:
File "<eval_with_key>.43", line 33, in forward
full_default = torch.ops.aten.full.default([s0, 1, s2], 1, pin_memory = False, device = device(type='cuda', index=0), dtype = torch.float32)
NameError: name 's0' is not defined
Call using an FX-traced Module, line 33 of the traced Module's generated forward function:
permute_3 = torch.ops.aten.permute.default(reshape_default_5, [0, 2, 1]); reshape_default_5 = None
full_default = torch.ops.aten.full.default([s0, 1, s2], 1, pin_memory = False, device = device(type='cuda', index=0), dtype = torch.float32)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
mul_59 = torch.ops.aten.mul.Tensor(full_default, -4294967295); full_default = None
unsqueeze_3 = torch.ops.aten.unsqueeze.default(lt, 1); lt = None
@apbose can you help me solve this problem?
Yes taking a look.
I did not get a chance to look at this one yet, but let me get back to you soon regarding this
I could repro the error-
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 487, in call_function
    return converter(self.ctx, target, args, kwargs, self._cur_node_name)
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/aten_ops_converters.py", line 1937, in aten_ops_sub
    return impl.elementwise.sub(
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py", line 492, in sub
    return convert_binary_elementwise(
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py", line 154, in convert_binary_elementwise
    lhs_val, rhs_val = broadcast(
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch_tensorrt/fx/converters/converter_utils.py", line 404, in broadcast
    a_shape = tuple(a.shape)
ValueError: __len__() should return >= 0
on torchTRT 2.4. I have yet to try torchTRT 2.5 and torchTRT 2.6; I will try those and update here. I wanted to know: do you see the same error as above, or is the error different in torchTRT 2.5, as below?
File "<eval_with_key>.43", line 33, in forward
full_default = torch.ops.aten.full.default([s0, 1, s2], 1, pin_memory = False, device = device(type='cuda', index=0), dtype = torch.float32)
NameError: name 's0' is not defined
Call using an FX-traced Module, line 33 of the traced Module's generated forward function:
permute_3 = torch.ops.aten.permute.default(reshape_default_5, [0, 2, 1]); reshape_default_5 = None
full_default = torch.ops.aten.full.default([s0, 1, s2], 1, pin_memory = False, device = device(type='cuda', index=0), dtype = torch.float32)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
mul_59 = torch.ops.aten.mul.Tensor(full_default, -4294967295); full_default = None
unsqueeze_3 = torch.ops.aten.unsqueeze.default(lt, 1); lt = None
Yes.
In torchTRT 2.4, the error is: ValueError: __len__() should return >= 0
In the torchTRT 2.5 release, the error is: NameError: name 's0' is not defined
Hmm, the thing is that in the torchTRT 2.5 Docker container I see it passing. It fails in 2.4 with the error ValueError: __len__() should return >= 0. This is the output I get in the 2.5 container:
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
model_inference 0.00% 0.000us 0.00% 0.000us 0.000us 13.056us 129.94% 13.056us 13.056us 1
model_inference 27.17% 686.966us 98.53% 2.491ms 2.491ms 0.000us 0.00% 10.048us 10.048us 1
forward 0.82% 20.720us 71.36% 1.804ms 1.804ms 0.000us 0.00% 10.048us 10.048us 1
tensorrt::execute_engine 6.07% 153.379us 70.54% 1.783ms 1.783ms 10.048us 100.00% 10.048us 10.048us 1
generatedNativePointwise 0.00% 0.000us 0.00% 0.000us 0.000us 3.808us 37.90% 3.808us 1.904us 2
void genericReformat::copyPackedKernel<float, float,... 0.00% 0.000us 0.00% 0.000us 0.000us 3.680us 36.62% 3.680us 1.840us 2
void cuSliceLayer::naiveSlice<int, (cuSliceLayer::Mo... 0.00% 0.000us 0.00% 0.000us 0.000us 2.560us 25.48% 2.560us 2.560us 1
aten::view 0.55% 14.019us 0.55% 14.019us 7.010us 0.000us 0.00% 0.000us 0.000us 2
aten::empty 61.55% 1.556ms 61.55% 1.556ms 1.556ms 0.000us 0.00% 0.000us 0.000us 1
aten::to 0.04% 1.080us 0.04% 1.080us 1.080us 0.000us 0.00% 0.000us 0.000us 1
cudaEventRecord 0.27% 6.820us 0.27% 6.820us 3.410us 0.000us 0.00% 0.000us 0.000us 2
cudaStreamWaitEvent 0.16% 4.040us 0.16% 4.040us 2.020us 0.000us 0.00% 0.000us 0.000us 2
cudaLaunchKernel 1.49% 37.730us 1.49% 37.730us 12.577us 0.000us 0.00% 0.000us 0.000us 3
cuLaunchKernel 0.41% 10.300us 0.41% 10.300us 5.150us 0.000us 0.00% 0.000us 0.000us 2
cudaDeviceSynchronize 1.47% 37.210us 1.47% 37.210us 37.210us 0.000us 0.00% 0.000us 0.000us 1
Self CPU time total: 2.528ms
Self CUDA time total: 10.048us
load: tensor(0.4938, device='cuda:0')
Can you please try one more thing: can you use this Docker container? docker pull ghcr.io/pytorch/tensorrt/torch_tensorrt:release_2.5
@apbose Hello, I used the image from docker pull ghcr.io/pytorch/tensorrt/torch_tensorrt:release_2.5, and it has the same error.
Please use the code from my previous comment (the version with the linear1, linear2 and linear layers). Your code may not be the same as mine, because my new code's output is multi-dimensional.
the error:
full_default = torch.ops.aten.full.default([s0, 1, s2], 1, pin_memory = False, device = device(type='cuda', index=0), dtype = torch.float32)
NameError: name 's0' is not defined
When I use nvcr.io/nvidia/pytorch:24.09-py3, the code runs correctly. The output is:
load: tensor([[-1.1872, -0.5101, 1.9891, 1.7680, 1.4139, -0.2162, 1.2833, -0.9097,
0.6203, 0.5390, -0.2642, 1.7545, -1.2082, -0.7723, -0.3190, 0.1017,
0.4799, 0.2186, -0.3029, -1.4194, 3.3411, -0.2459, -0.3860, 1.7662,
1.3203, -0.4731, -0.3768, -0.7993, 0.1499, -1.2849, 1.3602, 0.0561,
-0.8575, 0.1106, 1.5936, -0.5553, -0.0827, -1.0445, 0.0348, -1.1662,
-1.0570],
[-0.0209, -0.4864, 0.8198, 1.1385, -0.3014, 0.0324, 0.2430, 0.3191,
0.1529, -1.1248, 0.2166, -0.1728, 0.6455, -0.8241, 0.3455, 0.1014,
0.3104, -0.5890, -1.5751, 1.0247, -0.5266, 0.5779, 0.1120, -0.8913,
-1.2297, -0.3089, -1.2772, 0.7984, -0.3051, 1.1217, -1.8258, -0.2479,
0.1087, -0.0614, 0.3057, -1.4438, -1.1894, -0.1585, -0.2005, 0.6369,
0.2338]], device='cuda:0')
With torch-trt 2.5 and the image ghcr.io/pytorch/tensorrt/torch_tensorrt release_2.5 6f60df77ae91,
it fails. Please give me a release whl that I can install in my project.
@apbose Could you please help expedite the diagnosis? Our project has been delayed for a long time in introducing this TRT feature. Thanks!
@apbose Can you help me solve this problem? You may have run the original code rather than the newer code.
Are you using your own docker image and using torchtrt docker image as the base image?
I just did the following:
1) docker pull ghcr.io/pytorch/tensorrt/torch_tensorrt:release_2.5
2) run the code
3) get the error
@apbose
@apbose Can you look at the issue? I think you used the original code, not my newer code.
@peri044 Please run the code from my comment above (the version with the linear1, linear2 and linear layers) in the ghcr.io/pytorch/tensorrt/torch_tensorrt:release_2.5 image; it gets the error.
OK, trying now. I could repro the error with the additional layers. I was previously trying the old code, which was missing the MLP layers. The error seems to come from those.
Tried a couple of experiments:
seq_len_a_zero = torch.export.Dim("seq_len_a_zero", min=1, max=8196)
seq_len_b_zero = torch.export.Dim("seq_len_b_zero", min=1, max=8196)
seq_len_b_one = torch.export.Dim("seq_len_b_one", min=1, max=50)
seq_len_c_zero = torch.export.Dim("seq_len_c_zero", min=1, max=8196)
dynamic_shapes=({0:seq_len_a_zero}, {0:seq_len_b_zero, 1: seq_len_b_one}, {0:seq_len_c_zero})
exp_program = torch.export.export(model, tuple(inputs), dynamic_shapes=dynamic_shapes)
trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs)
# Run inference
out= trt_gm(*inputs)
This gives me -
The values of seq_len_b_zero = L['args1'].size()[0] and seq_len_a_zero = L['args0'].size()[0] must always be equal.
The values of seq_len_c_zero = L['args2'].size()[0] and seq_len_a_zero = L['args0'].size()[0] must always be equal
which means torch.export requires dimension 0 of all three inputs to be equal. The below
# seq_len_b_one = torch.export.Dim("seq_len_b_one", min=1, max=50)
# dynamic_shapes=({}, {1: seq_len_b_one}, {})
goes past the above but again results in
full_default = torch.ops.aten.full.default([2, 1, s0], 1, pin_memory = False, device = device(type='cuda', index=0), dtype = torch.float32)
NameError: name 's0' is not defined
full_default = torch.ops.aten.full.default([s0, 1, s2], 1, pin_memory = False, device = device(type='cuda', index=0), dtype = torch.float32)
NameError: name 's0' is not defined
Looking into this further.
@apbose Yes. I also tried dynamic_shapes; it has the same error: NameError: name 's0' is not defined.
You can resolve the first error (The values of seq_len_b_zero = L['args1'].size()[0] and seq_len_a_zero = L['args0'].size()[0] must always be equal.) by using the same Dim, seq_len_a_zero, for dimension 0 of every input: dynamic_shapes=({0:seq_len_a_zero}, {0:seq_len_a_zero, 1: seq_len_b_one}, {0:seq_len_a_zero})
I think you can see the difference between the trt2.5 and nvcr.io/nvidia/pytorch:24.09-py3 ,becase the nvcr.io/nvidia/pytorch:24.09-py3 trt is ok,but it has no release whl
Aah ok, thanks for pointing it out @yjjinjie . So you mean the above example passes for nvcr.io/nvidia/pytorch:24.09-py3?
Ok interesting looks like it is passing there
The issue is coming from the lowering pass replace_full_like_with_full, here:
https://github.com/pytorch/TensorRT/blob/main/py/torch_tensorrt/dynamo/lowering/passes/replace_full_like_with_full.py
The pass gets the shape from the input_tensor metadata. Since the shape is dynamic, it gets (s0, 1, s2),
whose symbols are undefined in the generated code. Their values are only determined at runtime.
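The failure mode can be illustrated in plain Python, without Torch-TensorRT. This is a minimal sketch, assuming (as described above) that the symbolic shape ends up written into the generated code as literal arguments. The names generate_forward, static_fn, and dynamic_fn are hypothetical, and this is an analogy rather than the actual lowering pass code:

```python
def generate_forward(shape_repr: str):
    """Build a function whose body contains the shape as literal source text,
    loosely mimicking how a graph is turned into Python code (assumption)."""
    src = f"def forward():\n    return {shape_repr}\n"
    namespace = {}
    exec(src, namespace)
    return namespace["forward"]


# Static shape: every entry is a concrete integer, so the code runs.
static_fn = generate_forward("[2, 1, 50]")
print(static_fn())

# Dynamic shape: s0 and s2 are symbols that nothing in the code defines,
# so calling the function fails the same way as the traceback above.
dynamic_fn = generate_forward("[s0, 1, s2]")
try:
    dynamic_fn()
except NameError as err:
    print(f"NameError: {err}")
```

The sketch only shows why a symbol copied from metadata cannot be used as a literal; a fix has to obtain the sizes from the input tensor at runtime instead.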
@apbose yes. When will this issue be fixed?
Working on the fix; I will raise a PR by next Monday.
@apbose thanks. When the PR is merged, can you give me a release whl that is compatible with torch 2.5.0?
Raised #3289. Yeah ok, I can help you with that. I can cherry-pick this PR locally on top of 2.5 to create the compatible wheel.
@apbose yes. I just manually modified the code, and it works.
@apbose in my real project, I use the TRT scripted model for prediction, and the accuracy is occasionally anomalous.
@apbose this is my project; I just made 2 PRs to add torch-tensorrt. What could be the possible reasons?
I use torch_tensorrt.runtime.set_multi_device_safe_mode(True). It reduces the frequency of errors, but they still occur occasionally.
https://github.com/alibaba/TorchEasyRec/pull/30/files https://github.com/alibaba/TorchEasyRec/pull/32
But when I predict in test_multi_tower_with_fg_train_eval_export_trt, the accuracy is occasionally anomalous.
Could you please provide a bit more context on the two PRs you have pulled? Also, could you give simple repro code for what test_multi_tower_with_fg_train_eval_export_trt is doing?
@apbose the program is very large, so I need some time to provide code. But I find that there is a random error, and the accuracy is also random.
I set the dynamic shapes as
batch = torch.export.Dim("batch", min=1, max=2048)
seq = torch.export.Dim("seq_len" + str(i), min=1, max=100)
but TRT reports a length of 50, and the error occurs randomly. What can lead to randomness within TRT?
ERROR: [Torch-TensorRT] - IExecutionContext::inferShapes: Error Code 7: Internal Error ([SLICE]-[aten_ops.expand.default]-[din_towers.1/expand_5]: ISliceLayer has out of bounds access on axis 0 Condition '<' violated: 67 >= 50. Instruction: CHECK_LESS 2048 50.)
RuntimeError: [Error thrown at core/runtime/execute_engine.cpp:253] Expected nbNames == 0 to be true but got false
[default0]:The shapes of the inputs:
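One way to narrow this down is to compare each runtime input shape against the ranges that were declared at export time, before calling the engine. This is a minimal sketch in plain Python. The helper check_shape and the example shape are hypothetical, and the ranges simply mirror the Dim definitions quoted above; it does not inspect the TensorRT engine itself:

```python
from typing import Dict, List, Sequence, Tuple

# Declared ranges for one input: {dimension index: (min, max)}.
DimRanges = Dict[int, Tuple[int, int]]


def check_shape(name: str, shape: Sequence[int], ranges: DimRanges) -> List[str]:
    """Return a list of human-readable problems; an empty list means in range."""
    problems = []
    for dim, (lo, hi) in ranges.items():
        if dim >= len(shape):
            problems.append(f"{name}: missing dimension {dim}")
            continue
        size = shape[dim]
        if not lo <= size <= hi:
            problems.append(f"{name}: dim {dim} is {size}, outside [{lo}, {hi}]")
    return problems


# Ranges as declared above: batch in [1, 2048], seq_len in [1, 100].
declared = {0: (1, 2048), 1: (1, 100)}
print(check_shape("sequence", (2048, 67), declared))

# The same shape against an upper bound of 50 on the sequence dimension.
print(check_shape("sequence", (2048, 67), {0: (1, 2048), 1: (1, 50)}))
```

If a shape passes against the declared ranges but the engine still reports a bound of 50, that would suggest the engine was built with a different range than the one declared, which is worth checking before looking for other sources of randomness.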
@apbose I ran some experiments on 4 machines and found the following:
echo 'options nvidia NVreg_EnableGpuFirmware=1' > /etc/modprobe.d/nvidia-gsp.conf
When I enable GSP, the accuracy is OK (GSP Firmware Version : 535.161.08).
When I disable GSP, the accuracy is incorrect (GSP Firmware Version : N/A).
Why is torch-tensorrt affected by GSP?
@apbose after I solved the GSP problem, the accuracy is still randomly incorrect, with about 1% probability. What could cause random accuracy? Can I disable multi-stream?
Hi @yjjinjie, some questions
Bug Description
When I use dynamic shapes in TRT, it raises an error;
static shapes are OK. Just delete these
To Reproduce
Steps to reproduce the behavior:
the env: