pytorch / TensorRT

PyTorch/TorchScript/FX compiler for NVIDIA GPUs using TensorRT
https://pytorch.org/TensorRT
BSD 3-Clause "New" or "Revised" License

❓ [Question] How do you properly deploy a quantized model with tensorrt #3267

Open Urania880519 opened 1 month ago

Urania880519 commented 1 month ago

❓ Question

I have a PTQ model and a QAT model, both trained with the official PyTorch API following the quantization tutorial, and I wish to deploy them on TensorRT for inference. The model is metaformer-like, using convolution layers as the token mixer. One part of the quantized model looks like this: (screenshot attached)

What you have already tried

I have tried different ways to make things work:

  1. The torch2trt package: there's a major problem with dynamic inputs. The dataset consists of inputs of shape (B, C, H, W) where H and W vary between samples. There is a torch2trt-dynamic package, but I think its plugins have bugs. The code basically looks like this:

       model_trt = torch2trt(
           model_fp32,
           [torch.randn(1, 11, 64, 64).to('cuda')],
           max_batch_size=batch_size,
           fp16_mode=False,
           int8_mode=True,
           calibrator=trainLoader,
           input_shapes=[(None, 11, None, None)],
       )

  2. torch.compile() with backend="tensorrt". When compiling the PTQ model I get RuntimeError: quantized::conv2d (ONEDNN): data type of input should be QUint8, and with the QAT model I get W1029 14:21:17.640402 139903289382080 torch/_dynamo/utils.py:1195] [2/0] Unsupported: quantized nyi in meta tensors with fake tensor propagation. Here's the code I used:

       trt_gm = torch.compile(
           model,
           dynamic=True,
           backend="tensorrt",
       )

  3. Converting the torch model to an ONNX model, then converting that into a TRT engine. There are several problems in this case:
    • The ONNX model runs surprisingly slowly with ONNX Runtime. Furthermore, the loss calculated is extremely high. Here's an example: (screenshot attached)


Additional context

Personally I think the torch.compile() API is the most promising route for converting the quantized model, since there's no performance drop. Does anyone have relevant experience handling quantized models?

narendasan commented 4 weeks ago

Did you follow this tutorial? https://pytorch.org/TensorRT/tutorials/_rendered_examples/dynamo/vgg16_ptq.html

Urania880519 commented 4 weeks ago

@narendasan I've followed both the tutorial you provided and this one: https://pytorch.org/TensorRT/user_guide/dynamic_shapes.html#dynamic-shapes However, I hit this error after calibration finished (the calibration seemed successful and the loss was quite low): (screenshots attached) This is the code I used:

  quant_cfg = mtq.INT8_DEFAULT_CFG
  mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
  with torch.no_grad():
      with export_torch_mode():
          input_tensor = torch.randn((1, channels, 35, 35), dtype=torch.float32).to('cuda')
          height_dim = torch.export.Dim("height_dim", min=25, max=64)
          width_dim = torch.export.Dim("width_dim", min=25, max=64)
          dynamic_shapes = ({2: height_dim, 3: width_dim},)
          from torch.export._trace import _export
          exp_program = _export(model, (input_tensor,), dynamic_shapes=dynamic_shapes)
          trt_Qmodel = torchtrt.dynamo.compile(
              exp_program,
              inputs=[input_tensor],
              enabled_precisions={torch.int8},
              min_block_size=1,
              debug=False,
              assume_dynamic_shape_support=True,
          )
narendasan commented 3 weeks ago

@lanluo-nvidia or @peri044 can you provide additional guidance here?

lanluo-nvidia commented 3 weeks ago

@Urania880519
If you could paste the full code, I can try to reproduce it on my side to determine the exact issue you are facing. Also, INT8 quantization support was introduced in version 2.5.0, so please try with PyTorch 2.5.0 and torch_tensorrt 2.5.0.
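A quick sanity check of the installed versions (the torch_tensorrt part is commented out in case it is not installed in the current environment):

```python
import torch

# Print the installed PyTorch version to confirm it is >= 2.5.0,
# where INT8 quantization support landed.
print(torch.__version__)

# If torch_tensorrt is installed, its version can be checked the same way:
# import torch_tensorrt
# print(torch_tensorrt.__version__)
```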

In terms of dynamic shape support in torch_tensorrt: if you have custom dynamic shape constraints, please refer to this tutorial, which goes through torch.export.export(): https://pytorch.org/TensorRT/user_guide/dynamic_shapes.html