pytorch / TensorRT

PyTorch/TorchScript/FX compiler for NVIDIA GPUs using TensorRT
https://pytorch.org/TensorRT
BSD 3-Clause "New" or "Revised" License

🐛 [Bug] Failed to compile an FP16 QAT ResNet #1468

Closed (Njuapp closed this issue 1 year ago)

Njuapp commented 2 years ago

Bug Description

After training a Q/DQ-inserted, fake-quantized ResNet, we called model.half() and then torch.jit.trace() to produce the TorchScript model below. However, compiling that model with Torch-TensorRT fails.
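
For context, a minimal sketch of the export path described above (the QAT/Q-DQ insertion step, e.g. via pytorch_quantization, is elided here; the script and names are illustrative, not the exact code used to produce the checkpoint):

    # Illustrative export flow: fake-quantized ResNet-50 -> FP16 -> TorchScript trace
    import torch
    import torchvision

    model = torchvision.models.resnet50().cuda().eval()
    # ... QAT with Q/DQ (fake-quant) nodes inserted would happen before this point ...
    model = model.half()

    example = torch.zeros(32, 3, 224, 224, device="cuda", dtype=torch.half)
    traced = torch.jit.trace(model, example)
    traced.save("resnet50_quant.ts")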

To Reproduce

Steps to reproduce the behavior:

  1. Download the FP16 QAT checkpoint `resnet50_quant.ts`
  2. Run the following script
    
    # Load the traced FP16 QAT model and run it once as plain TorchScript
    import torch
    import torch_tensorrt

    imgs = torch.zeros(32, 3, 224, 224).cuda().half()
    resnet50_model = torch.jit.load("resnet50_quant.ts").cuda()
    out = resnet50_model(imgs)
    for k, v in resnet50_model.state_dict().items():
        print(k, v.dtype, v.shape)

    # Compile with Torch-TensorRT: dynamic batch 1-32, FP16 inputs, INT8 enabled
    trt_model_int8 = torch_tensorrt.ts.compile(
        resnet50_model,
        inputs=[
            torch_tensorrt.Input(
                min_shape=(1, 3, 224, 224),
                opt_shape=(16, 3, 224, 224),
                max_shape=(32, 3, 224, 224),
                dtype=torch.float16,
            )
        ],
        enabled_precisions=torch.int8,
        workspace_size=1 << 32,
    )

    # Save, reload, and compare against the TorchScript reference
    trt_model_int8.save("trt_64.pt")
    trt_model_int8 = torch.jit.load("trt_64.pt")
    trt_output = trt_model_int8(imgs)
    resnet50_outputs = resnet50_model(imgs)
    print(trt_output)
    print(resnet50_outputs)

    print(trt_output[1].dtype)
    diff = abs(trt_output - resnet50_outputs)
    print('diff:', diff.mean())


## Expected behavior
Compilation succeeds without error, and the TensorRT output matches the TorchScript model's output; a sketch of the check I would expect to pass is below.
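
Continuing from the repro script above (the tolerances here are illustrative, not from the original report):

    # Compare in FP32 so the comparison itself is not limited by FP16 rounding;
    # rtol/atol values below are illustrative, not from the original report.
    ref = resnet50_model(imgs).float()
    trt = trt_model_int8(imgs).float()
    print("max abs diff:", (trt - ref).abs().max().item())
    assert torch.allclose(trt, ref, rtol=1e-2, atol=1e-2)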

## Environment

> Build information about Torch-TensorRT can be found by turning on debug messages
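
Debug messages can be enabled via the logging API, for example (a sketch assuming the 1.x `torch_tensorrt.logging` interface):

    import torch_tensorrt

    # Raise the reportable log level so build information is printed during compile
    # (API as of the 1.x releases; check the installed version).
    torch_tensorrt.logging.set_reportable_log_level(torch_tensorrt.logging.Level.Debug)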

 - Torch-TensorRT Version (e.g. 1.0.0):  1.3.0a0+975f6387
 - PyTorch Version (e.g. 1.0): 1.13.0.dev20220921+cu116
 - CPU Architecture:
 - OS (e.g., Linux):
 - How you installed PyTorch (`conda`, `pip`, `libtorch`, source):
 - Build command you used (if compiling from source):
 - Are you using local sources or building from archives:
 - Python version: 3.9
 - CUDA version: 11.6
 - GPU models and configuration:
 - Any other relevant information:

## Additional context

Njuapp commented 1 year ago

Fixed after switching to TensorRT 8.5.1.
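
For anyone hitting the same error, one way to confirm which TensorRT build is actually in use (a sketch; both packages expose a `__version__` attribute):

    import tensorrt
    import torch_tensorrt

    # Print the TensorRT and Torch-TensorRT versions actually loaded
    print("TensorRT:", tensorrt.__version__)
    print("Torch-TensorRT:", torch_tensorrt.__version__)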