NVIDIA / TensorRT

NVIDIA® TensorRT™ is an SDK for high-performance deep learning inference on NVIDIA GPUs. This repository contains the open source components of TensorRT.
https://developer.nvidia.com/tensorrt
Apache License 2.0

Issue to build `deberta-v3-base` due to missing `validUnaryType &&` datatype on `TensorRT 8.6` #3587

Closed VibhuJawa closed 8 months ago

VibhuJawa commented 9 months ago

Description

I am running into a conversion issue while trying to convert deberta-v3-base into a TensorRT engine. The ONNX parser fails with:

 UNSUPPORTED_NODE: Assertion failed: validUnaryType && "This version of TensorRT does not support the given operator with the given input data type."

The full trace follows the minimal reproducible example (MRE) below.

Environment

TensorRT Version:

8.6

NVIDIA GPU:

V100

NVIDIA Driver Version: 525.105.17

CUDA Version: 12.0

CUDNN Version:

Operating System:

Python Version (if applicable): 3.10

Steps To Reproduce

from onnx import TensorProto  # needed by remove_uint8_cast below
from transformers import DebertaV2ForSequenceClassification
import tensorrt as trt
import torch

def remove_uint8_cast(graph):
    nodes = [node for node in graph.nodes if node.op == 'Cast' and node.attrs["to"] == TensorProto.UINT8]

    for node in nodes:
        input_node = node.i()
        input_node.outputs = node.outputs
        node.outputs.clear()

    return graph

deberta_model = DebertaV2ForSequenceClassification.from_pretrained("microsoft/deberta-v3-base").cuda()
deberta_model.eval()
vocab_size = deberta_model.config.vocab_size
batch_size = 32
seq_len = 12
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), dtype=torch.long).to('cuda')
attention_mask = torch.randint(0, 2, (batch_size, seq_len), dtype=torch.long).to('cuda')
input_names = ['input_ids', 'attention_mask']
output_names = ['output']
dynamic_axes={'input_ids'   : {0 : 'batch_size', 1: 'seq_len'},
              'attention_mask'   : {0 : 'batch_size', 1: 'seq_len'},
              'output' : {0 : 'batch_size'}}

torch.onnx.export(deberta_model,
                  (input_ids, attention_mask),
                  "model.onnx",
                  export_params=True,
                  opset_version=13,
                  do_constant_folding=True,
                  input_names = input_names,
                  output_names = output_names,
                  dynamic_axes = dynamic_axes
                 )

TRT_LOGGER = trt.Logger(trt.Logger.INFO)
TRT_BUILDER = trt.Builder(TRT_LOGGER)
network = TRT_BUILDER.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
onnx_parser = trt.OnnxParser(network, TRT_LOGGER)
parse_success = onnx_parser.parse_from_file("model.onnx")

for idx in range(onnx_parser.num_errors):
    print(onnx_parser.get_error(idx))
Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at microsoft/deberta-v3-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
/datasets/vjawa/miniconda3/envs/TensorRT_Transformers/lib/python3.10/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:554: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  torch.tensor(mid - 1).type_as(relative_pos),
/datasets/vjawa/miniconda3/envs/TensorRT_Transformers/lib/python3.10/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:558: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  torch.ceil(torch.log(abs_pos / mid) / torch.log(torch.tensor((max_position - 1) / mid)) * (mid - 1)) + mid
/datasets/vjawa/miniconda3/envs/TensorRT_Transformers/lib/python3.10/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:717: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
/datasets/vjawa/miniconda3/envs/TensorRT_Transformers/lib/python3.10/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:717: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
/datasets/vjawa/miniconda3/envs/TensorRT_Transformers/lib/python3.10/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:792: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  scale = torch.sqrt(torch.tensor(pos_key_layer.size(-1), dtype=torch.float) * scale_factor)
/datasets/vjawa/miniconda3/envs/TensorRT_Transformers/lib/python3.10/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:792: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  scale = torch.sqrt(torch.tensor(pos_key_layer.size(-1), dtype=torch.float) * scale_factor)
/datasets/vjawa/miniconda3/envs/TensorRT_Transformers/lib/python3.10/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:804: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  scale = torch.sqrt(torch.tensor(pos_query_layer.size(-1), dtype=torch.float) * scale_factor)
/datasets/vjawa/miniconda3/envs/TensorRT_Transformers/lib/python3.10/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:804: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  scale = torch.sqrt(torch.tensor(pos_query_layer.size(-1), dtype=torch.float) * scale_factor)
/datasets/vjawa/miniconda3/envs/TensorRT_Transformers/lib/python3.10/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:805: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if key_layer.size(-2) != query_layer.size(-2):
/datasets/vjawa/miniconda3/envs/TensorRT_Transformers/lib/python3.10/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:112: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min))
[01/10/2024-08:55:53] [TRT] [I] The logger passed into createInferBuilder differs from one already provided for an existing builder, runtime, or refitter. Uses of the global logger, returned by nvinfer1::getLogger(), will return the existing value.
[01/10/2024-08:55:53] [TRT] [I] [MemUsageChange] Init CUDA: CPU +0, GPU +0, now: CPU 3739, GPU 3915 (MiB)
[libprotobuf WARNING google/protobuf/io/coded_stream.cc:604] Reading dangerously large protocol message.  If the message turns out to be larger than 2147483647 bytes, parsing will be halted for security reasons.  To increase the limit (or to disable these warnings), see CodedInputStream::SetTotalBytesLimit() in google/protobuf/io/coded_stream.h.
[libprotobuf WARNING google/protobuf/io/coded_stream.cc:81] The total number of bytes read was 738637429
[01/10/2024-08:55:54] [TRT] [I] ----------------------------------------------------------------
[01/10/2024-08:55:54] [TRT] [I] Input filename:   model.onnx
[01/10/2024-08:55:54] [TRT] [I] ONNX IR version:  0.0.7
[01/10/2024-08:55:54] [TRT] [I] Opset version:    13
[01/10/2024-08:55:54] [TRT] [I] Producer name:    pytorch
[01/10/2024-08:55:54] [TRT] [I] Producer version: 2.1.0
[01/10/2024-08:55:54] [TRT] [I] Domain:           
[01/10/2024-08:55:54] [TRT] [I] Model version:    0
[01/10/2024-08:55:54] [TRT] [I] Doc string:       
[01/10/2024-08:55:54] [TRT] [I] ----------------------------------------------------------------
[01/10/2024-08:55:55] [TRT] [W] onnx2trt_utils.cpp:374: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
[01/10/2024-08:55:55] [TRT] [E] ModelImporter.cpp:771: While parsing node number 66 [Sign -> "/deberta/encoder/Sign_output_0"]:
[01/10/2024-08:55:55] [TRT] [E] ModelImporter.cpp:772: --- Begin node ---
[01/10/2024-08:55:55] [TRT] [E] ModelImporter.cpp:773: input: "/deberta/encoder/Sub_output_0"
output: "/deberta/encoder/Sign_output_0"
name: "/deberta/encoder/Sign"
op_type: "Sign"

[01/10/2024-08:55:55] [TRT] [E] ModelImporter.cpp:774: --- End node ---
[01/10/2024-08:55:55] [TRT] [E] ModelImporter.cpp:777: ERROR: onnx2trt_utils.cpp:1779 In function unaryHelper:
[8] Assertion failed: validUnaryType && "This version of TensorRT does not support the given operator with the given input data type."
In node 66 (unaryHelper): UNSUPPORTED_NODE: Assertion failed: validUnaryType && "This version of TensorRT does not support the given operator with the given input data type."
[libprotobuf WARNING google/protobuf/io/coded_stream.cc:604] Reading dangerously large protocol message.  If the message turns out to be larger than 2147483647 bytes, parsing will be halted for security reasons.  To increase the limit (or to disable these warnings), see CodedInputStream::SetTotalBytesLimit() in google/protobuf/io/coded_stream.h.
[libprotobuf WARNING google/protobuf/io/coded_stream.cc:81] The total number of bytes read was 738637429
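
In a trace this long, the failing node is easy to miss. The following is a minimal sketch for pulling it out of the captured log text, assuming the parser keeps the `While parsing node number N [Op -> "output"]` wording seen above. The helper name `failing_nodes` is hypothetical and not part of TensorRT.

```python
import re

# Matches parser error lines such as:
#   ... While parsing node number 66 [Sign -> "/deberta/encoder/Sign_output_0"]:
NODE_RE = re.compile(r'While parsing node number (\d+) \[(\w+) -> "([^"]+)"\]')


def failing_nodes(log_text):
    """Return (node_index, op_type, output_name) for each parser failure found in log_text."""
    return [(int(idx), op, out) for idx, op, out in NODE_RE.findall(log_text)]


# One line copied from the trace above.
sample = (
    '[01/10/2024-08:55:55] [TRT] [E] ModelImporter.cpp:771: While parsing node number 66 '
    '[Sign -> "/deberta/encoder/Sign_output_0"]:'
)
print(failing_nodes(sample))
```

Here the only failure is the `Sign` node at index 66, which is the operator the `validUnaryType` assertion refers to.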

I also tried the workaround from https://github.com/NVIDIA/TensorRT/issues/3124, but to no avail.

# graph = gs.import_onnx(onnx.load("model.onnx"))
# graph = remove_uint8_cast(graph)

# graph.cleanup().toposort()
# onnx.save_model(gs.export_onnx(graph), "model_updated.onnx")

# model = fold_constants(onnx.load("model_updated.onnx"))
# onnx.save(model, "model_updated_folded.onnx")
zerollzeng commented 9 months ago

Looks like a known limitation, could you please try latest TRT 9.2? Thanks!

VibhuJawa commented 9 months ago

> Looks like a known limitation, could you please try latest TRT 9.2? Thanks!

Can you point me to how I can get access to the latest TRT 9.2, please?

I don't see any nightly wheels.

VibhuJawa commented 9 months ago

I tested it with tensorrt==9.0.1.post12.dev4 and I can create the engine now.

I want to support dynamic batch and sequence sizes, but I am running into the warnings below, which, as I understand them, mean that dynamic shapes will fail. Can you suggest how to get that working?

You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
/datasets/vjawa/miniconda3/envs/TensorRT_Transformers/lib/python3.10/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:554: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  torch.tensor(mid - 1).type_as(relative_pos),
/datasets/vjawa/miniconda3/envs/TensorRT_Transformers/lib/python3.10/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:558: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  torch.ceil(torch.log(abs_pos / mid) / torch.log(torch.tensor((max_position - 1) / mid)) * (mid - 1)) + mid
/datasets/vjawa/miniconda3/envs/TensorRT_Transformers/lib/python3.10/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:717: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
/datasets/vjawa/miniconda3/envs/TensorRT_Transformers/lib/python3.10/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:717: UserWarning: 
.................
  if key_layer.size(-2) != query_layer.size(-2):
/datasets/vjawa/miniconda3/envs/TensorRT_Transformers/lib/python3.10/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:112: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min))
zerollzeng commented 9 months ago

This is a warning from transformers (PyTorch). I think it means that using dynamic shapes may cause unexpected behavior during ONNX export (e.g. some tensors still become constants). You can ask for help in the transformers repo.
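
Independently of the tracer warnings, a network with dynamic input dimensions also needs an optimization profile on the TensorRT side before an engine can be built. The following is a rough sketch under stated assumptions, not a verified fix for this model: the helper names `shape_ranges` and `build_dynamic_engine` are hypothetical, the batch range 1..32 and sequence range 1..512 are placeholder values, and the input names are taken from the export call in the original post.

```python
def shape_ranges(min_batch, max_batch, min_seq, max_seq):
    """Return (min, opt, max) shapes for a [batch, seq_len] input; opt is the midpoint."""
    opt = ((min_batch + max_batch) // 2, (min_seq + max_seq) // 2)
    return (min_batch, min_seq), opt, (max_batch, max_seq)


def build_dynamic_engine(onnx_path, ranges):
    """Parse onnx_path and build a serialized engine with one optimization profile."""
    import tensorrt as trt  # imported lazily so shape_ranges works without TensorRT installed

    logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(logger)
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    parser = trt.OnnxParser(network, logger)
    if not parser.parse_from_file(onnx_path):
        errors = [str(parser.get_error(i)) for i in range(parser.num_errors)]
        raise RuntimeError("ONNX parse failed:\n" + "\n".join(errors))

    config = builder.create_builder_config()
    profile = builder.create_optimization_profile()
    # Assumption: both inputs share the same [batch, seq_len] range.
    for name in ("input_ids", "attention_mask"):
        profile.set_shape(name, *ranges)  # min, opt, max
    config.add_optimization_profile(profile)
    return builder.build_serialized_network(network, config)


# Placeholder ranges; pick values that match the real workload.
ranges = shape_ranges(min_batch=1, max_batch=32, min_seq=1, max_seq=512)
print(ranges)
```

Calling `build_dynamic_engine("model.onnx", ranges)` would then attempt the build. Whether the resulting engine is numerically correct at every shape still depends on the export not having baked shape-dependent values in as constants, which is what the warnings above are about.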

ttyio commented 8 months ago

Closing since there has been no activity for more than 3 weeks. Thanks all!

copasseron commented 8 months ago

Is there any release or tag of TensorRT 8.6.1 that resolves this issue?

I could build it with TRT 9.2.0, but I want to deploy this model on NVIDIA Triton Inference Server.

However, the Triton TensorRT backend does not yet support TensorRT versions newer than 8.6.1, and the TensorRT runtime version should be the same as the one used to build the engine.

What would be my best solution for this problem?
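
The version constraint described above can be checked before deployment. This is a minimal sketch, assuming plain dotted version strings such as those reported by `tensorrt.__version__`. The helper name `same_trt_version` is hypothetical, and comparing the first three components is an assumption about how strict the match needs to be.

```python
def same_trt_version(build_version, runtime_version, parts=3):
    """Return True if the first `parts` numeric components of both versions match."""
    def head(version):
        # "9.0.1.post12.dev4" -> (9, 0, 1) when parts == 3
        return tuple(int(p) for p in version.split(".")[:parts])

    return head(build_version) == head(runtime_version)


# Engine built with 9.2.0, runtime limited to 8.6.1 (the situation described above).
print(same_trt_version("9.2.0", "8.6.1"))
# Engine built and served with the same release.
print(same_trt_version("8.6.1", "8.6.1"))
```

A check like this only detects the mismatch. It does not make an engine built with 9.2.0 loadable by an 8.6.1 runtime.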