NVIDIA / TensorRT

NVIDIA® TensorRT™ is an SDK for high-performance deep learning inference on NVIDIA GPUs. This repository contains the open source components of TensorRT.
https://developer.nvidia.com/tensorrt
Apache License 2.0

How to enable FP8 convolution in TensorRT 10.2 #3987

Closed junstar92 closed 2 months ago

junstar92 commented 3 months ago

Hello,

I am using TensorRT 10.2 and noticed that normal FP8 convolution support has been updated. However, when I build a simple Q/DQ + Conv model in ONNX, the FP8 convolution is not selected. FP8 tactics are not even timed.

Here is the model I used. It was quantized with TensorRT-Model-Optimizer, and I ran it on an H100 device.

[image]

lix19937 commented 3 months ago

How was this file (simple_conv_fp8.onnx) generated?

yuanyao-nv commented 2 months ago

@junstar92 You might have to add the --stronglyTyped flag as well. cc: @nvpohanh

nvpohanh commented 2 months ago

Sorry, this is a bug in TRT 10.2. Please enable --stronglyTyped for now.

We will try to fix this issue in TRT 10.3
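
For illustration, the workaround amounts to adding one flag to the trtexec invocation. The helper below is a minimal sketch that only assembles the command line; `trtexec_command` is a hypothetical name, the file name is the one mentioned earlier in this thread, and the flags are the ones that appear in this thread.

```python
import shlex


def trtexec_command(onnx_path, strongly_typed=True, verbose=True):
    """Assemble a trtexec argument list (hypothetical helper, not part of TensorRT)."""
    args = ["trtexec", f"--onnx={onnx_path}"]
    if strongly_typed:
        # Workaround suggested in this thread for TRT 10.2.
        args.append("--stronglyTyped")
    if verbose:
        args.append("--verbose")
    return args


cmd = trtexec_command("simple_conv_fp8.onnx")
print(shlex.join(cmd))
# prints: trtexec --onnx=simple_conv_fp8.onnx --stronglyTyped --verbose
```

The resulting list can be passed to `subprocess.run(cmd, check=True)` on a machine where trtexec is installed.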

junstar92 commented 2 months ago

@nvpohanh Thank you for checking this issue. With the --stronglyTyped flag, FP8 tactics are enabled.

However, I have another question about FP8 convolution. I tried to build ResNet18 and ResNet50, but TensorRT cannot find any implementation for the first conv operation of ResNet. It seems there are no convolution tactics for 3 input channels. Does TensorRT 10.2 support the full ResNet18 or ResNet50 in FP8?

Here is the error log.

[07/12/2024-05:24:18] [V] [TRT] =============== Computing costs for {ForeignNode[/fake_quantizer_7c435f9c02917de57484db91f86bbbaf/QuantizeLinear.../fake_quantizer_5d22154b57438817000ba0a1ea6159ca/DequantizeLinear]}
[07/12/2024-05:24:18] [V] [TRT] *************** Autotuning format combination: Float(150528,50176,224,1) -> Float(802816,12544,112,1) ***************
[07/12/2024-05:24:18] [V] [TRT] --------------- Timing Runner: {ForeignNode[/fake_quantizer_7c435f9c02917de57484db91f86bbbaf/QuantizeLinear.../fake_quantizer_5d22154b57438817000ba0a1ea6159ca/DequantizeLinear]} (Myelin[0x80000023])
[07/12/2024-05:24:18] [V] [TRT] [MemUsageChange] Subgraph create: CPU +0, GPU +0, now: CPU 2390, GPU 844 (MiB)
[07/12/2024-05:24:18] [E] Error[9]: Error Code: 9: Skipping tactic 0x0000000000000000 due to exception [autotuner.cpp:get_best_tactics:2061] Autotuner: no tactics to implement operation:
  131: corrltn: /L__self___conv1/Conv_output_0'_before_bias.1-(f32[16,64,112,112][]so[], mem_prop=0) | /fake_quantizer_7c435f9c02917de57484db91f86bbbaf/QuantizeLinear_output_0'.1-(fp8[16,3,224,224][]so[], mem_prop=0), __mye150_dconst-{-2.75, -1.625, -0.46875, 20, 15, 4.5, -3.25, 3, ...}(fp8[64,3,7,7][147,49,7,1]so[3,2,1,0], mem_prop=0)<entry>, __mye126-1.10786e-05F:(f32[][]so[], mem_prop=0)<entry>, __mye91/L__self___conv1/Conv_beta-0F:(f32[][]so[], mem_prop=0)<entry>, stream = 0 // __mye130_conv
         | n_groups: 1  lpad: {3, 3}  rpad: {3, 3}  pad_mode: 0 strides: {2, 2}  dilations: {1, 1}
[07/12/2024-05:24:18] [V] [TRT] {ForeignNode[/fake_quantizer_7c435f9c02917de57484db91f86bbbaf/QuantizeLinear.../fake_quantizer_5d22154b57438817000ba0a1ea6159ca/DequantizeLinear]} (Myelin[0x80000023]) profiling completed in 0.118439 seconds. Fastest Tactic: 0xd15ea5edd15ea5ed Time: inf
[07/12/2024-05:24:18] [E] Error[10]: IBuilder::buildSerializedNetwork: Error Code 10: Internal Error (Could not find any implementation for node {ForeignNode[/fake_quantizer_7c435f9c02917de57484db91f86bbbaf/QuantizeLinear.../fake_quantizer_5d22154b57438817000ba0a1ea6159ca/DequantizeLinear]}.)
[07/12/2024-05:24:18] [E] Engine could not be created from network
[07/12/2024-05:24:18] [E] Building engine failed
[07/12/2024-05:24:18] [E] Failed to create engine from model or file.
[07/12/2024-05:24:18] [E] Engine set up failed
&&&& FAILED TensorRT.trtexec [TensorRT v100200] # trtexec --onnx=ResNet18_batch16_fp8.onnx --fp16 --fp8 --stronglyTyped --verbose --profilingVerbosity=detailed

nvpohanh commented 2 months ago

Did you insert the Q/DQ ops by using the TensorRT Model Optimizer toolkit? https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/onnx_ptq

It should have avoided the Q/DQ ops before Convs whose C and K are not multiples of 16.
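
As a rough illustration of this rule (a sketch only, not TensorRT's or ModelOpt's actual check), the functions below test whether a convolution's channel counts are multiples of 16. The names `fp8_conv_eligible` and `channels_from_weight_shape` are hypothetical, and the weight layout is assumed to be [K, C, R, S] with a single group, as in the log above.

```python
def fp8_conv_eligible(in_channels, out_channels, multiple=16):
    """Return True if both C (input) and K (output) channels are multiples of `multiple`."""
    return in_channels % multiple == 0 and out_channels % multiple == 0


def channels_from_weight_shape(weight_shape):
    """Extract (C, K) from a conv weight shape laid out as [K, C, R, S] (groups == 1 assumed)."""
    out_channels, in_channels = weight_shape[0], weight_shape[1]
    return in_channels, out_channels


# First ResNet conv from the error log: weights fp8[64,3,7,7], so C=3 and K=64.
c, k = channels_from_weight_shape((64, 3, 7, 7))
print(fp8_conv_eligible(c, k))    # False: C=3 is not a multiple of 16

# A 64 -> 64 channel 3x3 conv satisfies the rule.
print(fp8_conv_eligible(64, 64))  # True
```

Under this rule the first ResNet convolution would be left unquantized, while later convolutions with 64 or more channels would qualify.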

nvpohanh commented 2 months ago

But thanks for pointing this out. I will add this limitation to our release notes.

junstar92 commented 2 months ago

> Did you insert the Q/DQ ops by using the TensorRT Model Optimizer toolkit? https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/onnx_ptq
>
> It should have avoided the Q/DQ ops before Convs whose C and K are not multiples of 16.

@nvpohanh Thanks for the quick answer. I inserted the Q/DQ ops using modelopt, but the same error occurs with native Q/DQ ops.

With modelopt: [image]

With native Q/DQ ops: [image]

nvpohanh commented 2 months ago

Filed an internal tracker: id 4744383

We will debug this and find out how this is different from our FP8 ResNet50 testing in our CI/CD.

junstar92 commented 2 months ago

This is my quantization and ONNX export code.

import torch
import torchvision
import modelopt.torch.quantization as mtq

# FP8 (E4M3) config: per-tensor quantization of weights and inputs.
FP8_DEFAULT_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {"num_bits": (4, 3), "axis": None},
        "*input_quantizer": {"num_bits": (4, 3), "axis": None},
        "*output_quantizer": {"enable": False},
        "*block_sparse_moe.gate*": {"enable": False},  # Skip the MOE router
        "default": {"num_bits": (4, 3), "axis": None},
    },
    "algorithm": "max",
}

def calib_loop(model):
    # Calibrate on random inputs; forward_loop is called with the model as its argument.
    for _ in range(10):
        model(torch.randn(16, 3, 224, 224, device='cuda'))

model = torchvision.models.resnet18(pretrained=True).cuda()
# Insert the quantizers and calibrate them in place.
mtq.quantize(model, FP8_DEFAULT_CFG, forward_loop=calib_loop)
# Export the quantized model to ONNX with a fixed batch size of 16.
torch.onnx.export(
    model,
    torch.randn(16, 3, 224, 224, device='cuda'),
    'resnet18_fp8.onnx',
    input_names=['input'],
    output_names=['output'],
)

nvpohanh commented 2 months ago

@junstar92 Oh I see, you are using modelopt.torch.quantization while I was referring to modelopt.onnx.quantization. Could you first export the original model to ONNX and then use modelopt.onnx.quantization to add Q/DQ nodes?

I will check internally about modelopt.torch.quantization vs modelopt.onnx.quantization differences.

junstar92 commented 2 months ago

@nvpohanh Okay, here is the ONNX model quantized with modelopt.onnx.quantization; the first convolution is not quantized. [image]

This ONNX model built successfully. So it does seem that FP8 is not implemented for the first conv op.

junstar92 commented 2 months ago

@nvpohanh My question has been resolved, so I am closing this issue. I appreciate your support.

lishicheng1996 commented 1 month ago

@junstar92 Could you please share the versions of ModelOpt and PyTorch you used?

I'm trying FP8 ResNet too. Thank you very much!

junstar92 commented 1 month ago

@lishicheng1996 PyTorch: 2.2, ModelOpt: 0.13

nvpohanh commented 1 month ago

We will fix this in the TRT 10.4 release.