pytorch / TensorRT

PyTorch/TorchScript/FX compiler for NVIDIA GPUs using TensorRT
https://pytorch.org/TensorRT
BSD 3-Clause "New" or "Revised" License
2.5k stars 345 forks source link

🐛 [Bug] Encountered bug when using Torch-TensorRT (convert part of the model) #3127

Open yjjinjie opened 2 weeks ago

yjjinjie commented 2 weeks ago

Bug Description

Running the script below raises an exception: RuntimeError: method.qualname() == QualifiedName(selfClass->name()->qualifiedName(), methodName)INTERNAL ASSERT FAILED at "../torch/csrc/jit/serialization/python_print.cpp":1105, please report a bug to PyTorch.

import torch.nn
import torch_tensorrt

class MySubmodule(torch.nn.Module):
    """Minimal submodule wrapping one square ``torch.nn.Linear`` (10 -> 10)."""

    def __init__(self):
        super().__init__()
        features = 10
        self.layer = torch.nn.Linear(features, features)

    def forward(self, x):
        """Apply the wrapped linear layer to ``x``."""
        return self.layer(x)

class MyMod(torch.nn.Module):
    """Parent module whose submodule is compiled with Torch-TensorRT (``ir="ts"``).

    Repro case: the submodule is compiled in place inside ``__init__``, and
    the parent is afterwards passed to ``torch.jit.script`` and saved.
    """

    def __init__(self):
        super(MyMod, self).__init__()
        self.submod = MySubmodule()
        # Replace the eager submodule with its Torch-TensorRT compiled form
        # (TorchScript frontend, static input shape (1, 10)).
        self.submod = torch_tensorrt.compile(self.submod, ir="ts",inputs=[
            torch_tensorrt.Input(shape=(1, 10)),
        ])
        # Debug output: print the ``__qualname__`` of the compiled submodule.
        print(self.submod.__qualname__)

    def forward(self, x):
        """Run ``x`` through the compiled submodule."""
        return self.submod(x)

if __name__ == "__main__":
    model = MyMod()
    # Script the parent module, which holds the TRT-compiled submodule.
    scripted = torch.jit.script(model)
    # Run once on a CUDA input of shape (1, 10).
    scripted(torch.zeros(1, 10).cuda())

    # Serialize the scripted module to disk.
    scripted.save("test.pt")

If I use the dynamo IR (`ir="dynamo"`) instead:

import torch.nn
import torch_tensorrt

class MySubmodule(torch.nn.Module):
    """Minimal submodule wrapping one square ``torch.nn.Linear`` (10 -> 10)."""

    def __init__(self):
        super().__init__()
        features = 10
        self.layer = torch.nn.Linear(features, features)

    def forward(self, x):
        """Apply the wrapped linear layer to ``x``."""
        return self.layer(x)

class MyMod(torch.nn.Module):
    """Parent module whose submodule is compiled with Torch-TensorRT (``ir="dynamo"``).

    Repro case: same as the TorchScript variant, but the submodule is
    compiled through the dynamo/torch.export path inside ``__init__``.
    """

    def __init__(self):
        super(MyMod, self).__init__()
        self.submod = MySubmodule()
        # NOTE(review): MySubmodule's parameters are still on the CPU at this
        # point. The reported traceback fails while tracing ``self.layer`` with
        # "found two different devices cuda:0, cpu". The maintainer's example
        # in this issue moves the model to CUDA before compiling.
        self.submod = torch_tensorrt.compile(self.submod, ir="dynamo",inputs=[
            torch_tensorrt.Input(shape=(1, 10)),
        ])
        # Debug output: print the ``__qualname__`` of the compiled submodule.
        print(self.submod.__qualname__)

    def forward(self, x):
        """Run ``x`` through the compiled submodule."""
        return self.submod(x)

if __name__ == "__main__":
    # Per the traceback in this issue, construction already fails (inside
    # torch_tensorrt.compile in MyMod.__init__), so the lines below are not reached.
    model = MyMod()
    # Script the parent module, which holds the TRT-compiled submodule.
    scripted = torch.jit.script(model)
    # Run once on a CUDA input of shape (1, 10).
    scripted(torch.zeros(1, 10).cuda())

    # Serialize the scripted module to disk.
    scripted.save("test.pt")

This fails with the following error:

File "/larec/tzrec/tests/test.py", line 28, in <module>
    model = MyMod()
            ^^^^^^^
  File "/larec/tzrec/tests/test.py", line 18, in __init__
    self.submod = torch_tensorrt.compile(self.submod, ir="dynamo",inputs=[
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/_compile.py", line 248, in compile
    exp_program = dynamo_trace(module, torchtrt_inputs, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/dynamo/_tracer.py", line 81, in trace
    exp_program = export(mod, tuple(torch_inputs), dynamic_shapes=tuple(dynamic_shapes))
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/export/__init__.py", line 174, in export
    return _export(
           ^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/export/_trace.py", line 945, in wrapper
    raise e
  File "/opt/conda/lib/python3.11/site-packages/torch/export/_trace.py", line 928, in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/export/exported_program.py", line 89, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/export/_trace.py", line 1455, in _export
    aten_export_artifact = export_func(
                           ^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/export/_trace.py", line 1060, in _strict_export
    gm_torch_level = _export_to_torch_ir(
                     ^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/export/_trace.py", line 512, in _export_to_torch_ir
    gm_torch_level, _ = torch._dynamo.export(
                        ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 1379, in inner
    result_traced = opt_f(*args, **kwargs)
                    ^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 433, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1116, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 472, in __call__
    return _compile(
           ^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_utils_internal.py", line 84, in wrapper_function
    return StrobelightCompileTimeProfiler.profile_compile_time(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 817, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 636, in compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1185, in transform_code_object
    transformations(instructions, code_options)
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 178, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 582, in transform
    tracer.run()
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2451, in run
    super().run()
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 893, in run
    while self.step():
          ^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 805, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2059, in CALL
    self.call_function(fn, args, kwargs)
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function
    self.push(fn.call_function(self, args, kwargs))
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/nn_module.py", line 409, in call_function
    return wrap_fx_proxy(
           ^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 1713, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 1798, in wrap_fx_proxy_cls
    example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1853, in get_fake_value
    raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1785, in get_fake_value
    ret_val = wrap_fake_exception(
              ^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1300, in wrap_fake_exception
    return fn()
           ^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1786, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1921, in run_node
    raise RuntimeError(make_error_message(e)).with_traceback(
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1908, in run_node
    return nnmodule(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py", line 117, in forward
    return F.linear(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1061, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1450, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1145, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1690, in _dispatch_impl
    return decomposition_table[func](*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_prims_common/wrappers.py", line 266, in _fn
    result = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_decomp/decompositions.py", line 79, in inner
    r = f(*tree_map(increase_prec, args), **tree_map(increase_prec, kwargs))
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_decomp/decompositions.py", line 1437, in addmm
    out = alpha * torch.mm(mat1, mat2)
                  ^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1061, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1450, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1145, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1765, in _dispatch_impl
    self.wrap_meta_outputs_with_default_device_logic(
  File "/opt/conda/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1875, in wrap_meta_outputs_with_default_device_logic
    return tree_map(wrap, r)
           ^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/_pytree.py", line 948, in tree_map
    return treespec.unflatten(map(func, *flat_args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/_pytree.py", line 787, in unflatten
    leaves = list(leaves)
             ^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1853, in wrap
    ) = FakeTensor._find_common_device(func, flat_args)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 775, in _find_common_device
    merge_devices(arg)
  File "/opt/conda/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 770, in merge_devices
    raise RuntimeError(
torch._dynamo.exc.TorchRuntimeError: Failed running call_module L__self___layer(*(FakeTensor(..., device='cuda:0', size=(1, 10)),), **{}):
Unhandled FakeTensor Device Propagation for aten.mm.default, found two different devices cuda:0, cpu

from user code:
   File "/larec/tzrec/tests/test.py", line 11, in forward
    return self.layer(x)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

the env:

CPU(s):                          104
On-line CPU(s) list:             0-103
Vendor ID:                       GenuineIntel
Model name:                      Intel(R) Xeon(R) Platinum 8269CY CPU @ 2.50GHz
CPU family:                      6
Model:                           85
Thread(s) per core:              2
Core(s) per socket:              26
Socket(s):                       2
Stepping:                        7
CPU max MHz:                     3800.0000
CPU min MHz:                     1200.0000
BogoMIPS:                        5000.00
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single intel_ppin ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb intel_pt avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts pku ospke avx512_vnni md_clear flush_l1d arch_capabilities
Virtualization:                  VT-x
L1d cache:                       1.6 MiB (52 instances)
L1i cache:                       1.6 MiB (52 instances)
L2 cache:                        52 MiB (52 instances)
L3 cache:                        71.5 MiB (2 instances)
NUMA node(s):                    1
NUMA node0 CPU(s):               0-103
Vulnerability Itlb multihit:     KVM: Mitigation: Split huge pages
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; Enhanced IBRS, IBPB conditional, RSB filling
Vulnerability Tsx async abort:   Mitigation; TSX disabled

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] optree==0.12.1
[pip3] torch==2.4.0
[pip3] torch_tensorrt==2.4.0
[pip3] torchaudio==2.4.0
[pip3] torchelastic==0.2.2
[pip3] torchmetrics==1.0.3
[pip3] torchrec==0.8.0+cu121
[pip3] torchvision==0.19.0
[pip3] triton==3.0.0
[conda] blas                      1.0                         mkl
[conda] ffmpeg                    4.3                  hf484d3e_0    pytorch
[conda] libjpeg-turbo             2.0.0                h9bf148f_0    pytorch
[conda] mkl                       2023.1.0         h213fc3f_46344
[conda] mkl-service               2.4.0           py311h5eee18b_1
[conda] mkl_fft                   1.3.8           py311h5eee18b_0
[conda] mkl_random                1.2.4           py311hdb19cb5_0
[conda] numpy                     1.26.4          py311h08b1b3b_0
[conda] numpy-base                1.26.4          py311hf175353_0
[conda] optree                    0.12.1                   pypi_0    pypi
[conda] pytorch                   2.4.0           py3.11_cuda12.1_cudnn9.1.0_0    pytorch
[conda] pytorch-cuda              12.1                 ha16c6d3_5    pytorch
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] torch-tensorrt            2.4.0                    pypi_0    pypi
[conda] torchaudio                2.4.0               py311_cu121    pytorch
[conda] torchelastic              0.2.2                    pypi_0    pypi
[conda] torchmetrics              1.0.3                    pypi_0    pypi
[conda] torchrec                  0.8.0+cu121              pypi_0    pypi
[conda] torchtriton               3.0.0                     py311    pytorch
[conda] torchvision               0.19.0              py311_cu121    pytorch
narendasan commented 2 weeks ago

@peri044, can you look at this in the context of dynamo? I think we are just waiting on UserObjects to be supported in trace.

@yjjinjie, this is unlikely to ever be supported in TS, as it is in maintenance mode. Dynamo + ExportedProgram can support this, pending some features from PyTorch.

narendasan commented 2 weeks ago

This is the example I would expect to work once user objects are supported:

import torch.nn
import torch_tensorrt

class MySubmodule(torch.nn.Module):
    """Minimal submodule wrapping one square ``torch.nn.Linear`` (10 -> 10)."""

    def __init__(self):
        super().__init__()
        features = 10
        self.layer = torch.nn.Linear(features, features)

    def forward(self, x):
        """Apply the wrapped linear layer to ``x``."""
        return self.layer(x)

class MyMod(torch.nn.Module):
    """Parent module holding a ``MySubmodule`` under the attribute ``submod``.

    Compilation is not done here. The caller may swap ``submod`` for a
    compiled module after construction, as ``patch_submod`` does.
    """

    def __init__(self):
        super().__init__()
        self.submod = MySubmodule()

    def forward(self, x):
        """Delegate ``x`` to the module currently stored in ``submod``."""
        return self.submod(x)

def patch_submod(mod):
    """Replace ``mod.submod`` in place with a Torch-TensorRT compiled version.

    Uses the dynamo IR with a static input shape of (1, 10) and
    ``min_block_size=1``. Returns ``None``; ``mod`` is mutated.
    """
    original = mod.submod
    trt_inputs = [
        torch_tensorrt.Input(shape=(1, 10)),
    ]
    compiled = torch_tensorrt.compile(
        original,
        ir="dynamo",
        inputs=trt_inputs,
        min_block_size=1,
    )
    mod.submod = compiled

if __name__ == "__main__":
    model = MyMod()
    # Move the whole model to CUDA *before* compiling, so weights and the
    # CUDA example inputs live on the same device.
    model.to("cuda")
    # Swap model.submod for its Torch-TensorRT compiled version.
    patch_submod(model)
    # Export the parent module, which now contains the compiled submodule.
    # Per this issue, this currently fails with torch._dynamo.exc.Unsupported
    # (a UserDefinedObjectVariable(ScriptObject) argument), presumably here.
    exported_program = torch_tensorrt.dynamo.trace(model, arg_inputs=[torch.zeros(1, 10).to("cuda")])
    mod = exported_program.module()
    mod(torch.zeros(1, 10).cuda())

    print(exported_program.graph)

    torch.save(exported_program, "test.pt")

It currently fails with:

torch._dynamo.exc.Unsupported: call_function args: ListVariable(length=1) UserDefinedObjectVariable(ScriptObject)
yjjinjie commented 2 weeks ago

Yes. Our actual code framework is quite complex and contains conditionals, so if we used dynamo directly in the TensorRT conversion stage, we would also run into conditional statements like the following:

def forward(self, *args: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Forward the module.

        Pairs each positional argument with the entry at the same index in
        ``self.grouped_features_keys``. Runs every tower on its input and
        concatenates the tower outputs along the last dim. Then optionally
        applies ``self.final_mlp``, applies ``self.output_mlp``, and converts
        the result with ``self._output_to_prediction``.

        Raises:
            ValueError: if the number of positional arguments differs from
                ``len(self.grouped_features_keys)``.
        """
        # Guard first: zip() below would otherwise silently drop unmatched inputs.
        if len(self.grouped_features_keys) != len(args):
            raise ValueError(
                "The number of grouped_features_keys must match "
                "the number of arguments."
            )
        grouped_features = {
            key: value for key, value in zip(self.grouped_features_keys, args)
        }
        # Towers in self.towers each consume the single group named by their key.
        tower_outputs = []
        for k, tower_mlp in self.towers.items():
            tower_outputs.append(tower_mlp(grouped_features[k]))

        # DIN towers receive the whole grouped-feature dict.
        for tower_din in self.din_towers:
            tower_outputs.append(tower_din(grouped_features))

        tower_output = torch.cat(tower_outputs, dim=-1)
        # NOTE(review): _model_config looks like a protobuf message (HasField).
        # In this issue, torch.export/dynamo reports this HasField call as
        # Unsupported, which blocks the dynamo path.
        if self._model_config.HasField("final"):
            tower_output = self.final_mlp(tower_output)

        y = self.output_mlp(tower_output)
        return self._output_to_prediction(y)

If we use dynamo, the `torch_tensorrt.compile` call alone raises this error:

[default0]:[rank0]: Traceback (most recent call last):
[default0]:[rank0]:   File "/larec/tzrec/export.py", line 29, in <module>
[default0]:[rank0]:     export(
[default0]:[rank0]:   File "/larec/tzrec/main.py", line 1008, in export
[default0]:[rank0]:     _script_model(
[default0]:[rank0]:   File "/larec/tzrec/main.py", line 838, in _script_model
[default0]:[rank0]:     dense_layer_trt = trt_convert(dense, [*values_list_cuda])
[default0]:[rank0]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]:   File "/larec/tzrec/acc/utils.py", line 166, in trt_convert
[default0]:[rank0]:     optimized_model = torch_tensorrt.compile(
[default0]:[rank0]:                       ^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/_compile.py", line 248, in compile
[default0]:[rank0]:     exp_program = dynamo_trace(module, torchtrt_inputs, **kwargs)
[default0]:[rank0]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/dynamo/_tracer.py", line 81, in trace
[default0]:[rank0]:     exp_program = export(mod, tuple(torch_inputs), dynamic_shapes=tuple(dynamic_shapes))
[default0]:[rank0]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/export/__init__.py", line 174, in export
[default0]:[rank0]:     return _export(
[default0]:[rank0]:            ^^^^^^^^
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/export/_trace.py", line 945, in wrapper
[default0]:[rank0]:     raise e
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/export/_trace.py", line 928, in wrapper
[default0]:[rank0]:     ep = fn(*args, **kwargs)
[default0]:[rank0]:          ^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/export/exported_program.py", line 89, in wrapper
[default0]:[rank0]:     return fn(*args, **kwargs)
[default0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/export/_trace.py", line 1455, in _export
[default0]:[rank0]:     aten_export_artifact = export_func(
[default0]:[rank0]:                            ^^^^^^^^^^^^
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/export/_trace.py", line 1060, in _strict_export
[default0]:[rank0]:     gm_torch_level = _export_to_torch_ir(
[default0]:[rank0]:                      ^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/export/_trace.py", line 512, in _export_to_torch_ir
[default0]:[rank0]:     gm_torch_level, _ = torch._dynamo.export(
[default0]:[rank0]:                         ^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 1379, in inner
[default0]:[rank0]:     result_traced = opt_f(*args, **kwargs)
[default0]:[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[default0]:[rank0]:     return self._call_impl(*args, **kwargs)
[default0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[default0]:[rank0]:     return forward_call(*args, **kwargs)
[default0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 433, in _fn
[default0]:[rank0]:     return fn(*args, **kwargs)
[default0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[default0]:[rank0]:     return self._call_impl(*args, **kwargs)
[default0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[default0]:[rank0]:     return forward_call(*args, **kwargs)
[default0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1116, in __call__
[default0]:[rank0]:     return self._torchdynamo_orig_callable(
[default0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 472, in __call__
[default0]:[rank0]:     return _compile(
[default0]:[rank0]:            ^^^^^^^^^
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_utils_internal.py", line 84, in wrapper_function
[default0]:[rank0]:     return StrobelightCompileTimeProfiler.profile_compile_time(
[default0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
[default0]:[rank0]:     return func(*args, **kwargs)
[default0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/contextlib.py", line 81, in inner
[default0]:[rank0]:     return func(*args, **kwds)
[default0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 817, in _compile
[default0]:[rank0]:     guarded_code = compile_inner(code, one_graph, hooks, transform)
[default0]:[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
[default0]:[rank0]:     r = func(*args, **kwargs)
[default0]:[rank0]:         ^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 636, in compile_inner
[default0]:[rank0]:     out_code = transform_code_object(code, transform)
[default0]:[rank0]:                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1185, in transform_code_object
[default0]:[rank0]:     transformations(instructions, code_options)
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 178, in _fn
[default0]:[rank0]:     return fn(*args, **kwargs)
[default0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 582, in transform
[default0]:[rank0]:     tracer.run()
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2451, in run
[default0]:[rank0]:     super().run()
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 893, in run
[default0]:[rank0]:     while self.step():
[default0]:[rank0]:           ^^^^^^^^^^^
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 805, in step
[default0]:[rank0]:     self.dispatch_table[inst.opcode](self, inst)
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper
[default0]:[rank0]:     return inner_fn(self, inst)
[default0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1500, in CALL_FUNCTION_EX
[default0]:[rank0]:     self.call_function(fn, argsvars.items, kwargsvars)
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function
[default0]:[rank0]:     self.push(fn.call_function(self, args, kwargs))
[default0]:[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 344, in call_function
[default0]:[rank0]:     return super().call_function(tx, args, kwargs)
[default0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 293, in call_function
[default0]:[rank0]:     return super().call_function(tx, args, kwargs)
[default0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 90, in call_function
[default0]:[rank0]:     return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
[default0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 749, in inline_user_function_return
[default0]:[rank0]:     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[default0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2666, in inline_call
[default0]:[rank0]:     return cls.inline_call_(parent, func, args, kwargs)
[default0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2782, in inline_call_
[default0]:[rank0]:     tracer.run()
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 893, in run
[default0]:[rank0]:     while self.step():
[default0]:[rank0]:           ^^^^^^^^^^^
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 805, in step
[default0]:[rank0]:     self.dispatch_table[inst.opcode](self, inst)
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper
[default0]:[rank0]:     return inner_fn(self, inst)
[default0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2059, in CALL
[default0]:[rank0]:     self.call_function(fn, args, kwargs)
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function
[default0]:[rank0]:     self.push(fn.call_function(self, args, kwargs))
[default0]:[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/misc.py", line 680, in call_function
[default0]:[rank0]:     return self.obj.call_method(tx, self.name, args, kwargs)
[default0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/user_defined.py", line 649, in call_method
[default0]:[rank0]:     return super().call_method(tx, name, args, kwargs)
[default0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/base.py", line 320, in call_method
[default0]:[rank0]:     unimplemented(f"call_method {self} {name} {args} {kwargs}")
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/exc.py", line 221, in unimplemented
[default0]:[rank0]:     raise Unsupported(msg)
[default0]:[rank0]: torch._dynamo.exc.Unsupported: call_method UserDefinedObjectVariable(MultiTowerDIN) HasField [ConstantVariable()] {}
[default0]:
[default0]:[rank0]: from user code:
[default0]:[rank0]:    File "/larec/tzrec/models/multi_tower_din.py", line 100, in forward
[default0]:[rank0]:     return self.predict(*args)
[default0]:[rank0]:   File "/larec/tzrec/models/multi_tower_din.py", line 90, in predict
[default0]:[rank0]:     if self._model_config.HasField("final"):
[default0]:
[default0]:[rank0]: Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
[default0]:

Therefore, we use torch.jit.trace + trt_torch_script instead.

yjjinjie commented 2 weeks ago

@narendasan

I also encountered another bug:

1) When I use a dynamic input shape (similar to https://github.com/pytorch/TensorRT/issues/2334):

# Dynamic-shape input spec named "seq.sequence": dim 0 ranges 1..1024
# (presumably batch), dim 1 ranges 2..50 (presumably sequence length), and
# the last dim is fixed at 41. opt_shape is the shape TensorRT tunes for.
inputs.append(
                torch_tensorrt.Input(
                    min_shape=[1, 2, 41],
                    opt_shape=[512, 40, 41],
                    max_shape=[1024, 50, 41],
                    name="seq.sequence",

                )
            )
def forward(self, sequence_embedded: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Attention-pool a padded sequence against a query (DIN-style attention).

        Reads three entries of ``sequence_embedded``: the query, the padded
        sequence and the per-row valid lengths. Positions at or beyond a row's
        length are masked out before the softmax.

        NOTE(review): shapes are inferred from the indexing below, not declared;
        confirm against the caller: query ~ (B, query_dim), sequence ~
        (B, T, sequence_dim), sequence_length ~ (B,). The result is presumably
        (B, sequence_dim), assuming ``self.linear`` emits one score per position.
        """
        query = sequence_embedded[self._query_name]
        sequence = sequence_embedded[self._sequence_name]
        sequence_length = sequence_embedded[self._sequence_length_name]
        # T = padded sequence length. Under the TorchScript frontend with a
        # dynamic dim 1 this size was evaluated to -1, so the arange below failed
        # ("upper bound and larger bound inconsistent with step sign").
        max_seq_length = sequence.size(1)
        # (1, T) position indices < (B, 1) lengths -> (B, T) boolean validity mask.
        # ``_arange`` is defined elsewhere; presumably a thin torch.arange wrapper.
        sequence_mask = _arange(
            max_seq_length, device=sequence_length.device
        ).unsqueeze(0) < sequence_length.unsqueeze(1)

        # Zero-pad the end of the query's last dim up to sequence_dim so it can
        # be combined element-wise with the sequence below.
        if self._query_dim < self._sequence_dim:
            query = F.pad(query, (0, self._sequence_dim - self._query_dim))
        # Repeat the query at every sequence position.
        queries = query.unsqueeze(1).expand(-1, max_seq_length, -1)

        # DIN-style attention features concatenated on the last dim.
        attn_input = torch.cat(
            [queries, sequence, queries - sequence, queries * sequence], dim=-1
        )
        attn_output = self.mlp(attn_input)
        attn_output = self.linear(attn_output)
        # Move positions to the last dim so the softmax normalizes over T.
        attn_output = attn_output.transpose(1, 2)

        # Large negative fill (about -4.29e9) for masked positions.
        # NOTE(review): this exceeds the fp16 range (max ~65504). Under half
        # precision it becomes -inf, and a fully masked row (sequence_length == 0)
        # would softmax to NaN. Consider torch.finfo(attn_output.dtype).min if
        # fp16 engines are built.
        padding = torch.ones_like(attn_output) * (-(2**32) + 1)
        scores = torch.where(sequence_mask.unsqueeze(1), attn_output, padding)
        scores = F.softmax(scores, dim=-1)
        # Attention-weighted sum over positions; squeeze(1) drops the
        # (presumed) singleton score dim.
        return torch.matmul(scores, sequence).squeeze(1)

it raises the following error:

TorchScript Conversion Context] - Evaluating %30 : int = aten::size(%sequence.1, %141), scope: __module.din_towers.0 # /larec/tzrec/modules/sequence.py:77:0
[default0]:WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size without setting allow_shape_tensors
[default0]:DEBUG: [Torch-TensorRT TorchScript Conversion Context] - Found the value to be: -1
[default0]:DEBUG: [Torch-TensorRT TorchScript Conversion Context] - Evaluating %32 : Tensor = aten::arange(%30, %33, %33, %34, %35), scope: __module.din_towers.0 # /larec/tzrec/modules/sequence.py:16:0
[default0]:[rank0]: Traceback (most recent call last):
[default0]:[rank0]:   File "/larec/tzrec/export.py", line 29, in <module>
[default0]:[rank0]:     export(
[default0]:[rank0]:   File "/larec/tzrec/main.py", line 972, in export
[default0]:[rank0]:     _script_model(
[default0]:[rank0]:   File "/larec/tzrec/main.py", line 825, in _script_model
[default0]:[rank0]:     dense_layer_trt = trt_convert(dense_layer, [*inputs])
[default0]:[rank0]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]:   File "/larec/tzrec/acc/utils.py", line 166, in trt_convert
[default0]:[rank0]:     optimized_model = torch_tensorrt.compile(
[default0]:[rank0]:                       ^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/_compile.py", line 208, in compile
[default0]:[rank0]:     compiled_ts_module: torch.jit.ScriptModule = torchscript_compile(
[default0]:[rank0]:                                                  ^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/ts/_compiler.py", line 156, in compile
[default0]:[rank0]:     compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
[default0]:[rank0]:                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]: RuntimeError: upper bound and larger bound inconsistent with step sign

When I set allow_shape_tensors=True, I get a different error:

:DEBUG: [Torch-TensorRT TorchScript Conversion Context] - Evaluating %32 : Tensor = aten::arange(%30, %33, %33, %34, %35), scope: __module.din_towers.0 # /larec/tzrec/modules/sequence.py:16:0
[default0]:[rank0]: Traceback (most recent call last):
[default0]:[rank0]:   File "/larec/tzrec/export.py", line 29, in <module>
[default0]:[rank0]:     export(
[default0]:[rank0]:   File "/larec/tzrec/main.py", line 972, in export
[default0]:[rank0]:     _script_model(
[default0]:[rank0]:   File "/larec/tzrec/main.py", line 825, in _script_model
[default0]:[rank0]:     dense_layer_trt = trt_convert(dense_layer, [*inputs])
[default0]:[rank0]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]:   File "/larec/tzrec/acc/utils.py", line 166, in trt_convert
[default0]:[rank0]:     optimized_model = torch_tensorrt.compile(
[default0]:[rank0]:                       ^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/_compile.py", line 208, in compile
[default0]:[rank0]:     compiled_ts_module: torch.jit.ScriptModule = torchscript_compile(
[default0]:[rank0]:                                                  ^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/ts/_compiler.py", line 156, in compile
[default0]:[rank0]:     compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
[default0]:[rank0]:                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[default0]:[rank0]: RuntimeError: [Error thrown at core/conversion/var/Var.cpp:127] Expected isIValue() to be true but got false
[default0]:[rank0]: Requested IValue from Var, however Var type is nvinfer1::ITensor

Does the aten::arange converter support an nvinfer1::ITensor (i.e. a dynamic, non-constant) input?

yjjinjie commented 2 weeks ago

@narendasan I also encountered another bug:

With the same model and static input shapes:

with sequence_length=5, the TensorRT model's result matches the original model;

but with sequence_length=50, the TensorRT model's result differs from the original model (-0.720 vs. -0.516).

I don't know whether this is caused by multi-stream execution, dynamic shapes, or something else. Can I disable multi-stream execution or dynamic shapes?

narendasan commented 2 weeks ago

Would torch.compile work for your use case? It supports conditionals, and you can use engine caching to shortcut setup. It is unlikely that we will add further improvements to the TorchScript frontend.

In TorchScript, if there are no dynamic inputs there should be no dynamic shapes. Multi-stream execution (at least as we use it, where TRT runs on a non-default stream) cannot be turned off, since TRT requires it.

You can file a separate issue for the accuracy problem with a repro, and we can try to figure out what is going on.

yjjinjie commented 1 week ago

I found a solution for my code: use symbolic_trace(model) + torch.export.export + torch_tensorrt.compile(ir="dynamo") instead of torch.jit.trace(model) + torch_tensorrt.compile(ir="ts"). With the new approach the results are correct, and I can save the embedding part and the TensorRT-compiled dense part (emb + dense_trt) in one model.

import torch.nn
import torch_tensorrt

class MySubmodule(torch.nn.Module):
    """Minimal module: a single ``Linear(10, 10)`` stored as ``self.layer``."""

    def __init__(self):
        super().__init__()
        features = 10  # in_features == out_features
        self.layer = torch.nn.Linear(features, features)

    def forward(self, x):
        """Apply ``self.layer`` to ``x`` (its last dim must be 10)."""
        out = self.layer(x)
        return out

class MyMod(torch.nn.Module):
    """Top-level wrapper whose ``forward`` simply delegates to ``self.submod``."""

    def __init__(self):
        super().__init__()
        # Registered as a child module under the attribute name "submod".
        self.submod = MySubmodule()

    def forward(self, x):
        """Return the wrapped ``MySubmodule``'s output for input ``x``."""
        inner = self.submod
        return inner(x)

if __name__ == "__main__":
    # Script-local imports: torchrec and the profiler are only needed here.
    from torch.profiler import ProfilerActivity, profile, record_function
    from torchrec.fx import symbolic_trace

    # One example input shared by export, TensorRT compilation and tracing.
    example_input = torch.zeros(1, 10).cuda()

    model = MyMod()
    model.to("cuda")
    # Trace with torchrec's FX tracer first, then export the resulting GraphModule.
    model = symbolic_trace(model)
    exp_program = torch.export.export(model, (example_input,))
    # `inputs` is a sequence of example inputs; pass a list rather than a bare tensor.
    trt_gm = torch_tensorrt.dynamo.compile(
        exp_program, [example_input], min_block_size=1
    )

    # `(x)` is just `x`, not a tuple; `(x,)` makes the 1-tuple explicit.
    # strict=False lets jit.trace record mutable container (list/dict) outputs.
    trt_gm = torch.jit.trace(trt_gm, example_inputs=(example_input,), strict=False)

    scripted_model = torch.jit.script(trt_gm)
    scripted_model.save("./scripted_model_trt.pt")

    model_gpu = torch.jit.load("./scripted_model_trt.pt", map_location="cuda:0")

    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        record_shapes=True,
    ) as prof:
        with record_function("model_inference"):
            res = model_gpu(torch.zeros(1, 10).cuda())
    # Print outside the profiled region so formatting the CUDA tensor (which
    # copies it to the host) is not attributed to "model_inference".
    print("final res:", res)
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=100))
yjjinjie commented 1 week ago

Simplified version of the actual code:

import torch
import torch_tensorrt
from typing import Optional, Sequence,Dict,List

@torch.fx.wrap
def _get_dict(grouped_features_keys: List[str], args:List[torch.Tensor])->Dict[str, torch.Tensor]:
    if len(grouped_features_keys) != len(args):
            raise ValueError(
                "The number of grouped_features_keys must match "
                "the number of arguments."
            )
    grouped_features = {
        key: value for key, value in zip(grouped_features_keys, args)
    }
    return grouped_features

class MatMul(torch.nn.Module):
    """Toy attention-score module that takes its inputs as positional varargs.

    ``forward(query, key)`` maps the positional tensors to names via
    ``_get_dict`` and returns ``query @ key^T`` over the last two dimensions.
    """

    def __init__(self):
        super().__init__()
        # Names assigned, in order, to the positional inputs of ``forward``.
        self.keys = ["query","key"]

    def forward(self, *args1: torch.Tensor):
        """Forward the module.

        Args:
            *args1: one tensor per entry in ``self.keys`` (query, then key).

        Returns:
            ``torch.matmul(query, key.transpose(-1, -2))``.
        """
        # Varargs arrive as a tuple; hand the whole tuple to ``predict``.
        # NOTE(review): presumably mirrors the forward -> predict split of the
        # real model being exported; confirm.
        return self.predict(args1)

    def predict(self, args: Sequence[torch.Tensor]):
        """Compute attention scores from positional inputs.

        Args:
            args: tensors ordered like ``self.keys``. ``forward`` passes its
                varargs tuple here, hence ``Sequence`` rather than ``List``.

        Raises:
            ValueError: from ``_get_dict`` when ``len(args) != len(self.keys)``.
        """
        grouped_features = _get_dict(self.keys, args)
        query = grouped_features["query"]
        key = grouped_features["key"]
        attn_weight = torch.matmul(query, key.transpose(-1, -2))
        return attn_weight

# Build the varargs model and two random (1, 12, 7, 64) CUDA inputs.
model = MatMul().eval().cuda()
inputs = [torch.randn(1, 12, 7, 64).cuda(), torch.randn(1, 12, 7, 64).cuda()]
# Eager reference output (first [0][0] slice) to compare against TRT below.
print(model(*inputs)[0][0])
# NOTE(review): seq_len is defined but unused, because dynamic_shapes is
# commented out; torch.export therefore specializes on the example shapes.
seq_len = torch.export.Dim("seq_len", min=1, max=10)
# dynamic_shapes=({2: seq_len}, {2: seq_len})
# Export the model (static shapes, since no dynamic_shapes is passed)
from torchrec.fx import symbolic_trace
# FX-trace with torchrec's tracer first; _get_dict is registered with
# @torch.fx.wrap, so it presumably stays a single call node in the graph.
model = symbolic_trace(model)
exp_program = torch.export.export(model, (*inputs,))
# Inputs passed as a flat list, matching forward(*args1).
trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs)
# Run inference
print(trt_gm(*inputs)[0][0])
# Re-tracing trt_gm with symbolic_trace fails inside the TRT module's forward
# ("`__cuda_array_interface__` must be a dict"), per the traceback in this thread.
# trt_gm = symbolic_trace(trt_gm)
# jit.trace converts the TRT GraphModule to TorchScript so it can be saved;
# strict=False lets the tracer record mutable container (list/dict) outputs.
trt_gm = torch.jit.trace(trt_gm,
                                     example_inputs=(torch.randn(1, 12, 7, 64).cuda(), torch.randn(1, 12, 7, 64).cuda()), 
                                     strict=False)

# Script the traced module and serialize it.
scripted_model = torch.jit.script(trt_gm)
scripted_model.save("./scripted_model_trt.pt")

# Reload on cuda:0 and run on the same `inputs` as the prints above.
model_gpu = torch.jit.load(
    "./scripted_model_trt.pt", map_location="cuda:0"
)
print("load:",model_gpu(*inputs)[0][0])

class MatMul2(torch.nn.Module):
    """Attention scores from a single list argument ``[query, key]``.

    Same computation as ``MatMul``, but ``forward`` takes one positional
    list instead of varargs.
    """

    def __init__(self):
        super().__init__()

    def forward(self, args: List[torch.Tensor]):
        """Return ``args[0] @ args[1]^T`` over the last two dimensions."""
        query, key = args[0], args[1]
        key_t = key.transpose(-1, -2)
        return torch.matmul(query, key_t)

# Same computation, but forward takes ONE positional argument: a list.
model = MatMul2().eval().cuda()
inputs = [torch.randn(1, 12, 7, 64).cuda(), torch.randn(1, 12, 7, 64).cuda()]
# NOTE(review): unused while dynamic_shapes stays commented out.
seq_len = torch.export.Dim("seq_len", min=1, max=10)
# dynamic_shapes=({2: seq_len}, {2: seq_len})
# Export with `(inputs,)`: a single list argument, so the exported input tree
# spec is ((list[*, *],), {}).
exp_program = torch.export.export(model, (inputs,))
# ERROR: the same two tensors are passed here as a flat sequence and treated
# as two positional args ((*, *), {}), hence the "Trying to flatten user
# inputs with exported input tree spec" mismatch.
# NOTE(review): untested. Either pass the inputs nested to match the export
# (e.g. [inputs], if torch_tensorrt.dynamo.compile accepts nested inputs) or
# make forward take *args as MatMul does.
trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs)
# Run inference
# NOTE(review): `*inputs` also unpacks the list; the exported signature
# expects a single list, i.e. trt_gm(inputs).
print(trt_gm(*inputs))

When I save the model, I have to use torch.jit.trace(model) + torch.save, but torch.jit.trace does not support torch.device. In my use case I want to use symbolic_trace + torch.save instead, but symbolic_trace does not support *args used in a loop.

---> With *args in forward, calling symbolic_trace on the compiled module fails as follows, @narendasan:

File "/larec/tzrec/tests/test3.py", line 46, in <module>
    trt_gm = symbolic_trace(trt_gm)
             ^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torchrec/fx/tracer.py", line 161, in symbolic_trace
    graph = tracer.trace(root, concrete_args)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torchrec/fx/tracer.py", line 86, in trace
    graph = super().trace(
            ^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 823, in trace
    (self.create_arg(fn(*args)),),
                     ^^^^^^^^^
  File "<eval_with_key>.40", line 6, in forward
  File "/opt/conda/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 800, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 518, in call_module
    ret_val = forward(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 793, in forward
    return _orig_module_call(mod, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py", line 166, in forward
    input_tensors: List[torch.Tensor] = [
                                        ^
  File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py", line 167, in <listcomp>
    (i if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
                                           ^^^^^^^^^^^^^^^
TypeError: `__cuda_array_interface__` must be a dict

When forward takes a list (args) as its input, trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs) fails. How should I update the code, @narendasan? The error is:

ValueError: Trying to flatten user inputs with exported input tree spec:
TreeSpec(tuple, None, [TreeSpec(tuple, None, [TreeSpec(list, None, [*,
      *])]),
  TreeSpec(dict, [], [])])
but actually got inputs with tree spec of:
TreeSpec(tuple, None, [TreeSpec(tuple, None, [*,
    *]),
  TreeSpec(dict, [], [])])

I want to use symbolic_trace so that torch.device is supported. Can you help me solve this?