Bug Description
I would expect Torch-TensorRT to raise an exception whenever inference fails for any reason, such as a wrong input tensor shape or a wrong dtype. Instead, when the input shape is wrong, the error is only logged to the console and the program continues as if inference had succeeded. This can have serious implications in production environments.
To Reproduce
I have a model compiled for float16 that accepts a static input shape of (1, 3, 538, 538).
This is what happens if I pass a wrong shape
>>> out = model(torch.zeros((1, 3, 500, 500), dtype=torch.float16, device="cuda"))
ERROR: [Torch-TensorRT] - 3: [executionContext.cpp::setInputShape::2020] Error Code 3: API Usage Error (Parameter check failed at: runtime/api/executionContext.cpp::setInputShape::2020, condition: engineDims.d[i] == dims.d[i]. Static dimension mismatch while setting input shape.
)
>>> # NO EXCEPTION IS RAISED
This is what happens if I pass a wrong dtype
>>> out = model(torch.zeros((1, 3, 538, 538), dtype=torch.float32, device="cuda"))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/wizard/mambaforge/envs/remini/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
  File "code/__torch__/<model>.py", line 8, in forward
    input_0: Tensor) -> Tensor:
    __torch___<model>_trt_engine_ = self_1.__torch___<model>_trt_engine_
    _0 = ops.tensorrt.execute_engine([input_0], __torch___<model>_trt_engine_)
         ~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    _1, = _0
    return _1
Traceback of TorchScript, original code (most recent call last):
RuntimeError: [Error thrown at core/runtime/execute_engine.cpp:136] Expected inputs[i].dtype() == expected_type to be true but got false
Expected input tensors to have type Half, found type float
>>> # AN EXCEPTION IS RAISED
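Until the engine error for a shape mismatch surfaces as an exception, one possible user-side workaround is to validate inputs before they reach the compiled module. The sketch below assumes the module expects exactly the static shape and dtype it was compiled for; `InputSpec`, `InputSpecError`, and `guard_inputs` are hypothetical names, not part of Torch-TensorRT. It only reads `.shape` and `.dtype`, so it should work with `torch.Tensor` inputs (where `torch.Size` is a tuple subclass) without importing torch itself.

```python
from dataclasses import dataclass
from typing import Any, Callable, Sequence, Tuple


class InputSpecError(RuntimeError):
    """Hypothetical error: an input does not match what the engine was built for."""


@dataclass(frozen=True)
class InputSpec:
    shape: Tuple[int, ...]  # static shape the engine was compiled for
    dtype: Any              # e.g. torch.float16 (assumed, not checked here)


def guard_inputs(forward: Callable[..., Any], specs: Sequence[InputSpec]) -> Callable[..., Any]:
    """Wrap `forward` so mismatched inputs raise instead of only being logged."""

    def wrapper(*inputs: Any) -> Any:
        if len(inputs) != len(specs):
            raise InputSpecError(f"expected {len(specs)} inputs, got {len(inputs)}")
        for i, (x, spec) in enumerate(zip(inputs, specs)):
            shape = tuple(x.shape)
            if shape != spec.shape:
                raise InputSpecError(f"input {i}: expected shape {spec.shape}, got {shape}")
            if x.dtype != spec.dtype:
                raise InputSpecError(f"input {i}: expected dtype {spec.dtype}, got {x.dtype}")
        # Only call the compiled module once every input has been checked.
        return forward(*inputs)

    return wrapper
```

Usage would look like `safe_model = guard_inputs(model, [InputSpec((1, 3, 538, 538), torch.float16)])`; calling `safe_model` with the (1, 3, 500, 500) tensor from the repro should then raise `InputSpecError` before the engine runs. The trade-off is that the wrapper duplicates information the engine already has, so it can drift out of sync if the model is recompiled with a different shape.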
Expected behavior
An exception should be raised whenever the TensorRT engine reports an error, including a static input shape mismatch.
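The logged message in the repro comes from TensorRT's `setInputShape`, which signals failure through its return value rather than by throwing. Assuming the TensorRT Python binding `IExecutionContext.set_input_shape(name, shape)` likewise returns `False` when the check fails, the expected behavior amounts to checking that return value and raising. This is a minimal sketch with a hypothetical helper `set_input_shape_or_raise` and a hypothetical `EngineShapeError`; it is not the actual Torch-TensorRT runtime code, which is written in C++.

```python
from typing import Any, Sequence


class EngineShapeError(RuntimeError):
    """Hypothetical error: the execution context rejected an input shape."""


def set_input_shape_or_raise(context: Any, name: str, shape: Sequence[int]) -> None:
    # Assumption: `context.set_input_shape` returns False (after logging an
    # "API Usage Error") when the shape does not match the engine's static dims.
    dims = tuple(shape)
    if not context.set_input_shape(name, dims):
        raise EngineShapeError(f"TensorRT rejected shape {dims} for input '{name}'")
```

The point of the design is that a failed status becomes a Python exception at the call site, which is what a production caller can actually catch; a console log alone cannot be handled programmatically.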
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
Torch-TensorRT Version (e.g. 1.0.0): 1.4.0
PyTorch Version (e.g. 1.0): 2.0.1
CPU Architecture: x86_64
OS (e.g., Linux): Linux
How you installed PyTorch (conda, pip, libtorch, source): pip (custom whl)
Build command you used (if compiling from source): -
Are you using local sources or building from archives: -