pytorch / TensorRT

PyTorch/TorchScript/FX compiler for NVIDIA GPUs using TensorRT
https://pytorch.org/TensorRT
BSD 3-Clause "New" or "Revised" License

❓ [Question] how to specify dynamic shape when using torch_tensorrt.save #3109

Closed Qi-Zha0 closed 1 month ago

Qi-Zha0 commented 2 months ago

❓ Question

I was following the documentation on compiling a model with dynamic input shapes. When saving the compiled graph module (following this), the new torch_tensorrt.save(module, path, inputs) API requires all inputs to be tensors. How do I pass dynamic shapes to torch_tensorrt.save? Error: https://github.com/pytorch/TensorRT/blob/77278fe395d6ffdd456fd7a8a94852cd27ee63a9/py/torch_tensorrt/_compile.py#L420

import torch
import torch_tensorrt

model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)
model.eval().cuda()
inputs = [torch_tensorrt.Input(min_shape=[1, 3, 224, 224],
                               opt_shape=[4, 3, 224, 224],
                               max_shape=[8, 3, 224, 224],
                               dtype=torch.float32)]
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
torch_tensorrt.save(trt_gm, "trt_gm.ep", inputs=inputs)
WARNING:torch_tensorrt.dynamo.conversion.aten_ops_converters:Unable to import quantization op. Please install modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling quantized models
INFO:torch_tensorrt.dynamo._compiler:Compilation Settings: CompilationSettings(enabled_precisions={<dtype.f32: 7>}, debug=False, workspace_size=0, min_block_size=5, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, refit=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='/tmp/timing_cache.bin')

INFO:torch_tensorrt.dynamo._compiler:Partitioning the graph via the fast partitioner
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init CUDA: CPU +1, GPU +0, now: CPU 449, GPU 1622 (MiB)
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init builder kernel library: CPU +1622, GPU +288, now: CPU 2218, GPU 1910 (MiB)
WARNING:torch_tensorrt.dynamo.conversion.converter_utils:Detected unparseable type in node formatting: <class 'torch.SymInt'>
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.609398
INFO:torch_tensorrt [TensorRT Conversion Context]:Global timing cache in use. Profiling results in this builder pass will be stored.
INFO:torch_tensorrt [TensorRT Conversion Context]:Detected 1 inputs and 1 output network tensors.
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Host Persistent Memory: 343968
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Device Persistent Memory: 7168
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Scratch Memory: 6424576
INFO:torch_tensorrt [TensorRT Conversion Context]:[BlockAssignment] Started assigning block shifts. This will take 86 steps to complete.
INFO:torch_tensorrt [TensorRT Conversion Context]:[BlockAssignment] Algorithm ShiftNTopDown took 0.644934ms to assign 4 blocks to 86 nodes requiring 65830912 bytes.
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Activation Memory: 65830912
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Weights Memory: 127383968
INFO:torch_tensorrt [TensorRT Conversion Context]:Engine generation completed in 0.553365 seconds.
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageStats] Peak memory usage of TRT CPU/GPU memory allocators: CPU 16 MiB, GPU 129 MiB
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageStats] Peak memory usage during Engine building and serialization: CPU: 4064 MiB
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:00:00.649827
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT Engine uses: 129675836 bytes of Memory
INFO:torch_tensorrt [TensorRT Conversion Context]:Serialized 26 bytes of code generator cache.
INFO:torch_tensorrt [TensorRT Conversion Context]:Serialized 292352 bytes of compilation cache.
INFO:torch_tensorrt [TensorRT Conversion Context]:Serialized 3744 timing cache entries
WARNING: [Torch-TensorRT] - Detected this engine is being instantitated in a multi-GPU system with multi-device safe mode disabled. For more on the implications of this as well as workarounds, see the linked documentation (https://pytorch.org/TensorRT/user_guide/runtime.html#multi-device-safe-mode)
Traceback (most recent call last):
  File "test.py", line 11, in <module>
    torch_tensorrt.save(trt_gm, "trt_gm.ep", inputs=inputs)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "path/.venv/lib/python3.11/site-packages/torch_tensorrt/_compile.py", line 420, in save
    raise ValueError(
ValueError: Not all inputs provided are torch.tensors. Please provide torch.tensors as inputs

What you have already tried

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

Additional context

peri044 commented 2 months ago

Hello @Qi-Zha0 ,

You should pass concrete torch.Tensor inputs to the save API. Their shapes should fall within the (min_shape, opt_shape, max_shape) range the module was compiled with, e.g. torch_tensorrt.save(trt_gm, "trt_gm.ep", inputs=[torch.randn(4, 3, 224, 224).cuda()]). We will update the documentation to make this clearer. Thanks!
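To make the fix concrete: save() wants real tensors, and any shape inside the compiled dynamic range is acceptable. A minimal sketch below illustrates the range constraint with a pure-Python helper (shape_in_range is illustrative only, not part of the torch_tensorrt API); the actual save call requires a GPU and is shown in a comment.

```python
def shape_in_range(shape, min_shape, max_shape):
    # Illustrative helper: every dimension of `shape` must lie within
    # the corresponding [min, max] bounds of the compiled dynamic range.
    return all(lo <= d <= hi for d, lo, hi in zip(shape, min_shape, max_shape))

# Dynamic range from the torch_tensorrt.Input spec in the question above.
min_shape = (1, 3, 224, 224)
max_shape = (8, 3, 224, 224)

# Any concrete shape inside the range is a valid input for save(),
# e.g. the opt_shape batch size of 4:
sample_shape = (4, 3, 224, 224)
assert shape_in_range(sample_shape, min_shape, max_shape)

# With a CUDA device available, the corrected call from this thread would be:
#   import torch, torch_tensorrt
#   torch_tensorrt.save(trt_gm, "trt_gm.ep",
#                       inputs=[torch.randn(*sample_shape).cuda()])
```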

Qi-Zha0 commented 1 month ago

@peri044 Thank you for clarifying!