pytorch / TensorRT

PyTorch/TorchScript/FX compiler for NVIDIA GPUs using TensorRT
BSD 3-Clause "New" or "Revised" License

🐛 [Bug] ValueError: Trying to flatten user inputs with exported input tree spec #3020

Status: Open (opened by Hukongtao 3 months ago)

Hukongtao commented 3 months ago

Bug Description

To Reproduce

The code comes from the official documentation:

import torch
import torch_tensorrt

class MatMul(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, query, key):
        attn_weight = torch.matmul(query, key.transpose(-1, -2))
        return attn_weight

model = MatMul().eval().cuda()
inputs = [torch.randn(1, 12, 7, 64).cuda(), torch.randn(1, 12, 7, 64).cuda()]
seq_len = torch.export.Dim("seq_len", min=1, max=10)
dynamic_shapes=({2: seq_len}, {2: seq_len})
# Export the model first with custom dynamic shape constraints
exp_program = torch.export.export(model, tuple(inputs), dynamic_shapes=dynamic_shapes)
trt_gm = torch_tensorrt.dynamo.compile(exp_program, [inputs])
# Run inference

Expected behavior

The script runs successfully.


Build information about Torch-TensorRT can be found by turning on debug messages

Additional context

INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageStats] Peak memory usage of TRT CPU/GPU memory allocators: CPU 0 MiB, GPU 1 MiB
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageStats] Peak memory usage during Engine building and serialization: CPU: 7909 MiB
INFO:torch_tensorrt [TensorRT Conversion Context]:Serialized 26 bytes of code generator cache.
INFO:torch_tensorrt [TensorRT Conversion Context]:Serialized 81 timing cache entries
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:00:00.846883
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT Engine uses: 123972 bytes of Memory
WARNING: [Torch-TensorRT] - CUDA lazy loading is not enabled. Enabling it can significantly reduce device memory usage and speed up TensorRT initialization. See "Lazy Loading" section of CUDA documentation
Traceback (most recent call last):
  File "/mnt/bn/hukongtao-infer-speed/mlx/users/", line 18, in <module>
    trt_gm = torch_tensorrt.dynamo.compile(exp_program, [inputs])
  File "/usr/local/lib/python3.9/dist-packages/torch_tensorrt/dynamo/", line 227, in compile
    trt_gm = compile_module(gm, inputs, settings)
  File "/usr/local/lib/python3.9/dist-packages/torch_tensorrt/dynamo/", line 421, in compile_module
    sample_outputs = gm(
  File "/usr/local/lib/python3.9/dist-packages/torch/fx/", line 737, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/fx/", line 317, in __call__
    raise e
  File "/usr/local/lib/python3.9/dist-packages/torch/fx/", line 304, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/", line 1561, in _call_impl
    args_kwargs_result = hook(self, args, kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/", line 451, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/", line 36, in inner
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/export/", line 25, in _check_input_constraints_pre_hook
    raise ValueError(  # noqa: TRY200
ValueError: Trying to flatten user inputs with exported input tree spec: 
TreeSpec(tuple, None, [TreeSpec(tuple, None, [*,
  TreeSpec(dict, [], [])])
but actually got inputs with tree spec of: 
TreeSpec(tuple, None, [TreeSpec(tuple, None, [TreeSpec(list, None, [*,
  TreeSpec(dict, [], [])])
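The spec mismatch can be reproduced without a GPU or TensorRT. The sketch below assumes torch.utils._pytree.tree_flatten, a private PyTorch utility whose printed TreeSpec format varies between releases, and uses plain strings as stand-ins for the two tensors. The specs in the traceback have an (args, kwargs) shape: the exported program recorded two positional leaves, while the received spec is consistent with a single positional argument that is itself a list, which is what calling the module with *[inputs] produces.

```python
# Sketch: reproduce the tree-spec mismatch with PyTorch's (private) pytree
# utilities. Strings stand in for the query/key tensors; any non-container
# object is treated as a leaf.
from torch.utils._pytree import tree_flatten

query, key = "query", "key"  # placeholders for torch.Tensor leaves
inputs = [query, key]

# What torch.export recorded for forward(query, key): args=(query, key), kwargs={}
expected_leaves, expected_spec = tree_flatten(((query, key), {}))

# What the module receives when [inputs] is unpacked into the call:
# args=([query, key],), i.e. one positional argument that is a list.
received_leaves, received_spec = tree_flatten(((inputs,), {}))

# Correct call: unpacking the list itself gives args=(query, key).
fixed_leaves, fixed_spec = tree_flatten(((*inputs,), {}))

print(expected_spec)
print(received_spec)
print(expected_spec == received_spec)  # False: one extra list level
print(expected_spec == fixed_spec)     # True
```

Both calls flatten to the same two leaves; only the nesting differs, which is why the error reports tree specs rather than tensor shapes.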
Hukongtao commented 3 months ago

Does the code example in the official documentation contain an error as well?

peri044 commented 2 months ago

Thanks for catching this. The correct usage is as follows:

trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs)
# Run inference

inputs is already a list, so passing [inputs] wraps it in an extra list, which is incorrect. Raised a PR here:
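If the example-input list is built programmatically, a small guard can catch this mistake before compilation. check_flat_inputs below is a hypothetical helper, not part of the Torch-TensorRT API. It assumes a model like MatMul whose forward takes only tensor positional arguments; models that legitimately accept nested lists or tuples would need a different check. It deliberately does not reject non-tensor items, since Torch-TensorRT can also take input specifications in this list.

```python
# Hypothetical helper (not part of torch_tensorrt): fail fast if the example
# inputs are nested one level too deep, e.g. [inputs] instead of inputs.
from typing import Sequence

import torch


def check_flat_inputs(inputs: Sequence[object]) -> None:
    """Raise TypeError if any element of `inputs` is itself a list or tuple."""
    for i, item in enumerate(inputs):
        if isinstance(item, (list, tuple)):
            raise TypeError(
                f"inputs[{i}] is a {type(item).__name__}; pass the tensors "
                "directly, e.g. compile(exp_program, inputs), not [inputs]"
            )


# CPU tensors are enough to exercise the check itself.
inputs = [torch.randn(1, 12, 7, 64), torch.randn(1, 12, 7, 64)]
check_flat_inputs(inputs)  # passes: a flat list of two tensors

try:
    check_flat_inputs([inputs])  # the pattern from the original report
except TypeError as err:
    print(err)
```

Called right before torch_tensorrt.dynamo.compile(exp_program, inputs), it replaces the tree-spec ValueError with a message that names the offending argument.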