yjjinjie opened this issue 2 weeks ago
@peri044 can you look at this in the context of dynamo, I think we are just waiting on UserObjects to be supported in trace
@yjjinjie this is unlikely to ever be supported in TS as it is in maintenance mode. Dynamo + ExportedProgram can support this pending some features from PyTorch
This is the example I would expect to work once user objects are supported
import torch.nn
import torch_tensorrt


class MySubmodule(torch.nn.Module):
    def __init__(self):
        super(MySubmodule, self).__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.layer(x)


class MyMod(torch.nn.Module):
    def __init__(self):
        super(MyMod, self).__init__()
        self.submod = MySubmodule()

    def forward(self, x):
        return self.submod(x)


def patch_submod(mod):
    mod.submod = torch_tensorrt.compile(
        mod.submod,
        ir="dynamo",
        inputs=[torch_tensorrt.Input(shape=(1, 10))],
        min_block_size=1,
    )


if __name__ == "__main__":
    model = MyMod()
    model.to("cuda")
    patch_submod(model)
    exported_program = torch_tensorrt.dynamo.trace(
        model, arg_inputs=[torch.zeros(1, 10).to("cuda")]
    )
    mod = exported_program.module()
    mod(torch.zeros(1, 10).cuda())
    print(exported_program.graph)
    torch.save(exported_program, "test.pt")
Currently fails with
torch._dynamo.exc.Unsupported: call_function args: ListVariable(length=1) UserDefinedObjectVariable(ScriptObject)
Yes. In our actual scenario, our code framework is quite complex and involves conditionals, so if we used dynamo mode directly in the TRT conversion stage we would also hit these kinds of conditional statements, such as:
def forward(self, *args: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Forward the module."""
    if len(self.grouped_features_keys) != len(args):
        raise ValueError(
            "The number of grouped_features_keys must match "
            "the number of arguments."
        )
    grouped_features = {
        key: value for key, value in zip(self.grouped_features_keys, args)
    }
    tower_outputs = []
    for k, tower_mlp in self.towers.items():
        tower_outputs.append(tower_mlp(grouped_features[k]))
    for tower_din in self.din_towers:
        tower_outputs.append(tower_din(grouped_features))
    tower_output = torch.cat(tower_outputs, dim=-1)
    if self._model_config.HasField("final"):
        tower_output = self.final_mlp(tower_output)
    y = self.output_mlp(tower_output)
    return self._output_to_prediction(y)
If we use dynamo, just calling torch_tensorrt.compile raises this error:
[default0]:[rank0]: Traceback (most recent call last):
[default0]:[rank0]: File "/larec/tzrec/export.py", line 29, in <module>
[default0]:[rank0]: export(
[default0]:[rank0]: File "/larec/tzrec/main.py", line 1008, in export
[default0]:[rank0]: _script_model(
[default0]:[rank0]: File "/larec/tzrec/main.py", line 838, in _script_model
[default0]:[rank0]: dense_layer_trt = trt_convert(dense, [*values_list_cuda])
[default0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]: File "/larec/tzrec/acc/utils.py", line 166, in trt_convert
[default0]:[rank0]: optimized_model = torch_tensorrt.compile(
[default0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/_compile.py", line 248, in compile
[default0]:[rank0]: exp_program = dynamo_trace(module, torchtrt_inputs, **kwargs)
[default0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/dynamo/_tracer.py", line 81, in trace
[default0]:[rank0]: exp_program = export(mod, tuple(torch_inputs), dynamic_shapes=tuple(dynamic_shapes))
[default0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/export/__init__.py", line 174, in export
[default0]:[rank0]: return _export(
[default0]:[rank0]: ^^^^^^^^
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/export/_trace.py", line 945, in wrapper
[default0]:[rank0]: raise e
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/export/_trace.py", line 928, in wrapper
[default0]:[rank0]: ep = fn(*args, **kwargs)
[default0]:[rank0]: ^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/export/exported_program.py", line 89, in wrapper
[default0]:[rank0]: return fn(*args, **kwargs)
[default0]:[rank0]: ^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/export/_trace.py", line 1455, in _export
[default0]:[rank0]: aten_export_artifact = export_func(
[default0]:[rank0]: ^^^^^^^^^^^^
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/export/_trace.py", line 1060, in _strict_export
[default0]:[rank0]: gm_torch_level = _export_to_torch_ir(
[default0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/export/_trace.py", line 512, in _export_to_torch_ir
[default0]:[rank0]: gm_torch_level, _ = torch._dynamo.export(
[default0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 1379, in inner
[default0]:[rank0]: result_traced = opt_f(*args, **kwargs)
[default0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[default0]:[rank0]: return self._call_impl(*args, **kwargs)
[default0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[default0]:[rank0]: return forward_call(*args, **kwargs)
[default0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 433, in _fn
[default0]:[rank0]: return fn(*args, **kwargs)
[default0]:[rank0]: ^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[default0]:[rank0]: return self._call_impl(*args, **kwargs)
[default0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[default0]:[rank0]: return forward_call(*args, **kwargs)
[default0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1116, in __call__
[default0]:[rank0]: return self._torchdynamo_orig_callable(
[default0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 472, in __call__
[default0]:[rank0]: return _compile(
[default0]:[rank0]: ^^^^^^^^^
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_utils_internal.py", line 84, in wrapper_function
[default0]:[rank0]: return StrobelightCompileTimeProfiler.profile_compile_time(
[default0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
[default0]:[rank0]: return func(*args, **kwargs)
[default0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]: File "/opt/conda/lib/python3.11/contextlib.py", line 81, in inner
[default0]:[rank0]: return func(*args, **kwds)
[default0]:[rank0]: ^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 817, in _compile
[default0]:[rank0]: guarded_code = compile_inner(code, one_graph, hooks, transform)
[default0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
[default0]:[rank0]: r = func(*args, **kwargs)
[default0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 636, in compile_inner
[default0]:[rank0]: out_code = transform_code_object(code, transform)
[default0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1185, in transform_code_object
[default0]:[rank0]: transformations(instructions, code_options)
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 178, in _fn
[default0]:[rank0]: return fn(*args, **kwargs)
[default0]:[rank0]: ^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 582, in transform
[default0]:[rank0]: tracer.run()
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2451, in run
[default0]:[rank0]: super().run()
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 893, in run
[default0]:[rank0]: while self.step():
[default0]:[rank0]: ^^^^^^^^^^^
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 805, in step
[default0]:[rank0]: self.dispatch_table[inst.opcode](self, inst)
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper
[default0]:[rank0]: return inner_fn(self, inst)
[default0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1500, in CALL_FUNCTION_EX
[default0]:[rank0]: self.call_function(fn, argsvars.items, kwargsvars)
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function
[default0]:[rank0]: self.push(fn.call_function(self, args, kwargs))
[default0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 344, in call_function
[default0]:[rank0]: return super().call_function(tx, args, kwargs)
[default0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 293, in call_function
[default0]:[rank0]: return super().call_function(tx, args, kwargs)
[default0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 90, in call_function
[default0]:[rank0]: return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
[default0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 749, in inline_user_function_return
[default0]:[rank0]: return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[default0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2666, in inline_call
[default0]:[rank0]: return cls.inline_call_(parent, func, args, kwargs)
[default0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2782, in inline_call_
[default0]:[rank0]: tracer.run()
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 893, in run
[default0]:[rank0]: while self.step():
[default0]:[rank0]: ^^^^^^^^^^^
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 805, in step
[default0]:[rank0]: self.dispatch_table[inst.opcode](self, inst)
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper
[default0]:[rank0]: return inner_fn(self, inst)
[default0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2059, in CALL
[default0]:[rank0]: self.call_function(fn, args, kwargs)
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function
[default0]:[rank0]: self.push(fn.call_function(self, args, kwargs))
[default0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/misc.py", line 680, in call_function
[default0]:[rank0]: return self.obj.call_method(tx, self.name, args, kwargs)
[default0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/user_defined.py", line 649, in call_method
[default0]:[rank0]: return super().call_method(tx, name, args, kwargs)
[default0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/base.py", line 320, in call_method
[default0]:[rank0]: unimplemented(f"call_method {self} {name} {args} {kwargs}")
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/exc.py", line 221, in unimplemented
[default0]:[rank0]: raise Unsupported(msg)
[default0]:[rank0]: torch._dynamo.exc.Unsupported: call_method UserDefinedObjectVariable(MultiTowerDIN) HasField [ConstantVariable()] {}
[default0]:
[default0]:[rank0]: from user code:
[default0]:[rank0]: File "/larec/tzrec/models/multi_tower_din.py", line 100, in forward
[default0]:[rank0]: return self.predict(*args)
[default0]:[rank0]: File "/larec/tzrec/models/multi_tower_din.py", line 90, in predict
[default0]:[rank0]: if self._model_config.HasField("final"):
[default0]:
[default0]:[rank0]: Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
[default0]:
Therefore, we use torch.jit.trace + trt_torch_script instead.
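One pattern that can avoid the HasField graph break above (a minimal sketch, not the project's actual code: DummyConfig is a hypothetical stand-in for the protobuf model config) is to resolve config-dependent branches once at construction time, so that forward only branches on a plain Python bool, which tracers like torch.export can specialize on:

```python
# Sketch only: DummyConfig is a hypothetical stand-in for the protobuf
# model config; the real code would read self._model_config in __init__.

class DummyConfig:
    def __init__(self, fields):
        self._fields = set(fields)

    def HasField(self, name):
        # Mimics the protobuf message API.
        return name in self._fields


class Tower:
    def __init__(self, config):
        # Resolve the config-dependent branch once, at construction time.
        self._has_final = bool(config.HasField("final"))

    def forward(self, x):
        # forward() now branches on a plain bool attribute instead of
        # calling config.HasField(...) during tracing.
        if self._has_final:
            return x * 2  # stand-in for self.final_mlp(tower_output)
        return x


print(Tower(DummyConfig(["final"])).forward(3))  # 6
print(Tower(DummyConfig([])).forward(3))         # 3
```

Whether this is feasible depends on the model; it only works for branches whose outcome is fixed once the config is loaded, not for branches on tensor data.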
@narendasan
I also encountered another bug:
1) When I use dynamic inputs, just like in https://github.com/pytorch/TensorRT/issues/2334:
inputs.append(
    torch_tensorrt.Input(
        min_shape=[1, 2, 41],
        opt_shape=[512, 40, 41],
        max_shape=[1024, 50, 41],
        name="seq.sequence",
    )
)
def forward(self, sequence_embedded: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Forward the module."""
    query = sequence_embedded[self._query_name]
    sequence = sequence_embedded[self._sequence_name]
    sequence_length = sequence_embedded[self._sequence_length_name]
    max_seq_length = sequence.size(1)
    sequence_mask = _arange(
        max_seq_length, device=sequence_length.device
    ).unsqueeze(0) < sequence_length.unsqueeze(1)
    if self._query_dim < self._sequence_dim:
        query = F.pad(query, (0, self._sequence_dim - self._query_dim))
    queries = query.unsqueeze(1).expand(-1, max_seq_length, -1)
    attn_input = torch.cat(
        [queries, sequence, queries - sequence, queries * sequence], dim=-1
    )
    attn_output = self.mlp(attn_input)
    attn_output = self.linear(attn_output)
    attn_output = attn_output.transpose(1, 2)
    padding = torch.ones_like(attn_output) * (-(2**32) + 1)
    scores = torch.where(sequence_mask.unsqueeze(1), attn_output, padding)
    scores = F.softmax(scores, dim=-1)
    return torch.matmul(scores, sequence).squeeze(1)
It raises the error:
TorchScript Conversion Context] - Evaluating %30 : int = aten::size(%sequence.1, %141), scope: __module.din_towers.0 # /larec/tzrec/modules/sequence.py:77:0
[default0]:WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size without setting allow_shape_tensors
[default0]:DEBUG: [Torch-TensorRT TorchScript Conversion Context] - Found the value to be: -1
[default0]:DEBUG: [Torch-TensorRT TorchScript Conversion Context] - Evaluating %32 : Tensor = aten::arange(%30, %33, %33, %34, %35), scope: __module.din_towers.0 # /larec/tzrec/modules/sequence.py:16:0
[default0]:[rank0]: Traceback (most recent call last):
[default0]:[rank0]: File "/larec/tzrec/export.py", line 29, in <module>
[default0]:[rank0]: export(
[default0]:[rank0]: File "/larec/tzrec/main.py", line 972, in export
[default0]:[rank0]: _script_model(
[default0]:[rank0]: File "/larec/tzrec/main.py", line 825, in _script_model
[default0]:[rank0]: dense_layer_trt = trt_convert(dense_layer, [*inputs])
[default0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]: File "/larec/tzrec/acc/utils.py", line 166, in trt_convert
[default0]:[rank0]: optimized_model = torch_tensorrt.compile(
[default0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/_compile.py", line 208, in compile
[default0]:[rank0]: compiled_ts_module: torch.jit.ScriptModule = torchscript_compile(
[default0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/ts/_compiler.py", line 156, in compile
[default0]:[rank0]: compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
[default0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]: RuntimeError: upper bound and larger bound inconsistent with step sign
When I use allow_shape_tensors=True:
:DEBUG: [Torch-TensorRT TorchScript Conversion Context] - Evaluating %32 : Tensor = aten::arange(%30, %33, %33, %34, %35), scope: __module.din_towers.0 # /larec/tzrec/modules/sequence.py:16:0
[default0]:[rank0]: Traceback (most recent call last):
[default0]:[rank0]: File "/larec/tzrec/export.py", line 29, in <module>
[default0]:[rank0]: export(
[default0]:[rank0]: File "/larec/tzrec/main.py", line 972, in export
[default0]:[rank0]: _script_model(
[default0]:[rank0]: File "/larec/tzrec/main.py", line 825, in _script_model
[default0]:[rank0]: dense_layer_trt = trt_convert(dense_layer, [*inputs])
[default0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]: File "/larec/tzrec/acc/utils.py", line 166, in trt_convert
[default0]:[rank0]: optimized_model = torch_tensorrt.compile(
[default0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/_compile.py", line 208, in compile
[default0]:[rank0]: compiled_ts_module: torch.jit.ScriptModule = torchscript_compile(
[default0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/ts/_compiler.py", line 156, in compile
[default0]:[rank0]: compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
[default0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]: RuntimeError: [Error thrown at core/conversion/var/Var.cpp:127] Expected isIValue() to be true but got false
[default0]:[rank0]: Requested IValue from Var, however Var type is nvinfer1::ITensor
Does aten::arange support nvinfer1::ITensor?
@narendasan I also encountered another bug:
With the same model, using static shapes:
when the input sequence_length=5, the TRT model accuracy matches the original model;
but when sequence_length=50, the TRT model accuracy does not match the original model (-0.720 vs -0.516).
I don't know whether it's caused by multi-stream or some other dynamic behavior. Can I disable multi-stream or dynamic shapes?
Would torch.compile work in your use case? It is able to support conditionals, and you can use engine caching to shortcut setup. It's unlikely we will add any improvements to TorchScript.
In TorchScript, if there are no dynamic inputs there should be no dynamic shapes. Multi-stream (at least how we use it, where TRT has a non-default execution stream) cannot be turned off, since TRT requires it.
You can file an issue for the accuracy problem with a repro and we can try to figure out what is going on.
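For such a repro, a quick way to quantify the mismatch (a plain-Python sketch; max_abs_diff is a hypothetical helper, not part of torch_tensorrt) is to report the largest elementwise deviation between the two models' outputs:

```python
# Hypothetical helper for an accuracy repro: compare two nested lists of
# floats (e.g. model outputs converted via tensor.tolist()) and return
# the largest absolute elementwise difference.

def max_abs_diff(a, b):
    if isinstance(a, (list, tuple)):
        assert len(a) == len(b), "outputs have different shapes"
        return max(max_abs_diff(x, y) for x, y in zip(a, b))
    return abs(a - b)


ref_out = [[-0.516, 0.25], [0.10, 0.0]]  # original model (example values)
trt_out = [[-0.720, 0.25], [0.10, 0.0]]  # TRT model (example values)
print(round(max_abs_diff(ref_out, trt_out), 3))  # 0.204
```

On tensors, (ref - trt).abs().max() is the equivalent one-liner; the plain-Python version is just framework-free.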
I found a solution in my code: use symbolic_trace(model) + torch.export.export + torch_tensorrt.compile(ir="dynamo") instead of torch.jit.trace(model) + torch_tensorrt.compile(ir="ts"). The new solution's accuracy is correct, and I can save emb + dense_trt in one model.
import torch.nn
import torch_tensorrt


class MySubmodule(torch.nn.Module):
    def __init__(self):
        super(MySubmodule, self).__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.layer(x)


class MyMod(torch.nn.Module):
    def __init__(self):
        super(MyMod, self).__init__()
        self.submod = MySubmodule()

    def forward(self, x):
        return self.submod(x)


if __name__ == "__main__":
    model = MyMod()
    model.to("cuda")

    from torchrec.fx import symbolic_trace

    model = symbolic_trace(model)
    exp_program = torch.export.export(model, (torch.zeros(1, 10).cuda(),))
    trt_gm = torch_tensorrt.dynamo.compile(
        exp_program, torch.zeros(1, 10).cuda(), min_block_size=1
    )
    trt_gm = torch.jit.trace(
        trt_gm, example_inputs=(torch.zeros(1, 10).cuda(),), strict=False
    )
    scripted_model = torch.jit.script(trt_gm)
    scripted_model.save("./scripted_model_trt.pt")

    model_gpu = torch.jit.load("./scripted_model_trt.pt", map_location="cuda:0")

    from torch.profiler import ProfilerActivity, profile, record_function

    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        record_shapes=True,
    ) as prof:
        with record_function("model_inference"):
            res = model_gpu(torch.zeros(1, 10).cuda())
            print("final res:", res)
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=100))
Simplified version of actual code:
import torch
import torch_tensorrt
from typing import Optional, Sequence, Dict, List


@torch.fx.wrap
def _get_dict(
    grouped_features_keys: List[str], args: List[torch.Tensor]
) -> Dict[str, torch.Tensor]:
    if len(grouped_features_keys) != len(args):
        raise ValueError(
            "The number of grouped_features_keys must match "
            "the number of arguments."
        )
    grouped_features = {
        key: value for key, value in zip(grouped_features_keys, args)
    }
    return grouped_features


class MatMul(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.keys = ["query", "key"]

    def forward(self, *args1: List[torch.Tensor]):
        """Forward the module."""
        # use predict to avoid trace error in self._output_to_prediction(y)
        return self.predict(args1)

    def predict(self, args: List[torch.Tensor]):
        grouped_features = _get_dict(self.keys, args)
        query = grouped_features["query"]
        key = grouped_features["key"]
        attn_weight = torch.matmul(query, key.transpose(-1, -2))
        return attn_weight


model = MatMul().eval().cuda()
inputs = [torch.randn(1, 12, 7, 64).cuda(), torch.randn(1, 12, 7, 64).cuda()]
print(model(*inputs)[0][0])

seq_len = torch.export.Dim("seq_len", min=1, max=10)
# dynamic_shapes = ({2: seq_len}, {2: seq_len})
# Export the model first with custom dynamic shape constraints
from torchrec.fx import symbolic_trace

model = symbolic_trace(model)
exp_program = torch.export.export(model, (*inputs,))
trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs)
# Run inference
print(trt_gm(*inputs)[0][0])

# trt_gm = symbolic_trace(trt_gm)
trt_gm = torch.jit.trace(
    trt_gm,
    example_inputs=(
        torch.randn(1, 12, 7, 64).cuda(),
        torch.randn(1, 12, 7, 64).cuda(),
    ),
    strict=False,
)
scripted_model = torch.jit.script(trt_gm)
scripted_model.save("./scripted_model_trt.pt")

model_gpu = torch.jit.load("./scripted_model_trt.pt", map_location="cuda:0")
print("load:", model_gpu(*inputs)[0][0])


class MatMul2(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, args: List[torch.Tensor]):
        query = args[0]
        key = args[1]
        attn_weight = torch.matmul(query, key.transpose(-1, -2))
        return attn_weight


model = MatMul2().eval().cuda()
inputs = [torch.randn(1, 12, 7, 64).cuda(), torch.randn(1, 12, 7, 64).cuda()]
seq_len = torch.export.Dim("seq_len", min=1, max=10)
# dynamic_shapes = ({2: seq_len}, {2: seq_len})
# Export the model first with custom dynamic shape constraints
exp_program = torch.export.export(model, (inputs,))
# ERROR
trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs)
# Run inference
print(trt_gm(*inputs))
When I save the model, I must use torch.jit.trace(model) + torch.save, but torch.jit.trace doesn't support torch.device. In my use case I want to use symbolic_trace + torch.save, but symbolic_trace doesn't support *args in a loop.
---> with *args as the forward signature, when I use symbolic_trace: @narendasan
File "/larec/tzrec/tests/test3.py", line 46, in <module>
trt_gm = symbolic_trace(trt_gm)
^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torchrec/fx/tracer.py", line 161, in symbolic_trace
graph = tracer.trace(root, concrete_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torchrec/fx/tracer.py", line 86, in trace
graph = super().trace(
^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 823, in trace
(self.create_arg(fn(*args)),),
^^^^^^^^^
File "<eval_with_key>.40", line 6, in forward
File "/opt/conda/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 800, in module_call_wrapper
return self.call_module(mod, forward, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 518, in call_module
ret_val = forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 793, in forward
return _orig_module_call(mod, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py", line 166, in forward
input_tensors: List[torch.Tensor] = [
^
File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py", line 167, in <listcomp>
(i if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
^^^^^^^^^^^^^^^
TypeError: `__cuda_array_interface__` must be a dict
I want to know: when I use args as the input, trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs) errors. How can I update the code? @narendasan The error is:
ValueError: Trying to flatten user inputs with exported input tree spec:
TreeSpec(tuple, None, [TreeSpec(tuple, None, [TreeSpec(list, None, [*,
*])]),
TreeSpec(dict, [], [])])
but actually got inputs with tree spec of:
TreeSpec(tuple, None, [TreeSpec(tuple, None, [*,
*]),
TreeSpec(dict, [], [])])
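The mismatch above reads as: the program was exported with forward taking a single list argument (the TreeSpec(list, None, [*, *]) inside the args tuple), but it was then called with two positional tensors. A torch-free sketch of the structural difference, using a toy structure() function loosely analogous to pytree's TreeSpec:

```python
# Toy illustration (no torch): why the input tree specs disagree.

def structure(obj):
    """Nested-container skeleton, loosely like torch pytree's TreeSpec."""
    if isinstance(obj, (list, tuple)):
        return (type(obj).__name__, tuple(structure(x) for x in obj))
    return "*"  # a leaf (a tensor, in the real TreeSpec)


exported_args = (["q", "k"],)  # exported as forward(args: List[Tensor])
call_args = ("q", "k")         # later called as forward(q, k)

print(structure(exported_args))  # ('tuple', (('list', ('*', '*')),))
print(structure(call_args))      # ('tuple', ('*', '*'))
print(structure(exported_args) == structure(call_args))  # False
```

If that reading is right, calling the compiled module with the list itself (trt_gm(inputs)) rather than unpacked (trt_gm(*inputs)) should match the exported spec; that is an inference from the error message, not a verified fix.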
I want to use symbolic_trace so I can support torch.device; can you help me solve it?
Bug Description
But it raises exception: RuntimeError: method.qualname() == QualifiedName(selfClass->name()->qualifiedName(), methodName)INTERNAL ASSERT FAILED at "../torch/csrc/jit/serialization/python_print.cpp":1105, please report a bug to PyTorch.
if I use dynamo:
error:
the env: