NVIDIA / TensorRT

NVIDIA® TensorRT™ is an SDK for high-performance deep learning inference on NVIDIA GPUs. This repository contains the open source components of TensorRT.
https://developer.nvidia.com/tensorrt
Apache License 2.0

Inconsistencies in output when running inference on DeBERTa with disentangled attention plugin and dynamic sequence length for input #4005


jkim695 commented 2 months ago

Description

When running inference with TensorRT's disentangled attention plugin on Microsoft's implementation of DeBERTa, I get inconsistent output when the input sequence length is dynamic. This can be reproduced with input sequence lengths that are less than the maximum set in the optimization profile of the built TensorRT engine:

Running inference on engine ./test/deberta_plugin_V100_fp16.engine
Iteration 1: [[-0.015625, 0.0750732421875]]
Iteration 2: [[-0.015167236328125, 0.069580078125]]
Iteration 3: [[-0.01617431640625, 0.07696533203125]]
Iteration 4: [[-0.01483154296875, 0.08099365234375]]
Iteration 5: [[-0.0129241943359375, 0.0771484375]]
Iteration 6: [[-0.0179901123046875, 0.081298828125]]
Iteration 7: [[-0.0145416259765625, 0.0816650390625]]
Iteration 8: [[-0.0194091796875, 0.0721435546875]]
Iteration 9: [[-0.02618408203125, 0.06939697265625]]
Iteration 10: [[-0.0210113525390625, 0.07452392578125]]

I get consistent output when running inference on the model without the plugin, using the same script:

Running inference on engine ./test/deberta_original_V100_fp16.engine
Iteration 1: [[-0.2298583984375, 0.19677734375]]
Iteration 2: [[-0.2298583984375, 0.19677734375]]
Iteration 3: [[-0.2298583984375, 0.19677734375]]
Iteration 4: [[-0.2298583984375, 0.19677734375]]
Iteration 5: [[-0.2298583984375, 0.19677734375]]
Iteration 6: [[-0.2298583984375, 0.19677734375]]
Iteration 7: [[-0.2298583984375, 0.19677734375]]
Iteration 8: [[-0.2298583984375, 0.19677734375]]
Iteration 9: [[-0.2298583984375, 0.19677734375]]
Iteration 10: [[-0.2298583984375, 0.19677734375]]

These results suggest a bug in either the handling of the optimization profile or in the disentangled attention plugin itself.
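
A minimal determinism check along these lines (a sketch that reuses the TRTModel class and imports from the repro script below; the sequence lengths are arbitrary values at or below the profile max):

# Sketch: run identical inputs repeatedly at several dynamic lengths and flag drift.
# Relies on the TRTModel class defined in the repro script below.
def check_determinism(model, seq_lens=(128, 512, 1024, 4059), reps=5):
    for n in seq_lens:
        model.context.set_input_shape("input_ids", (1, n))
        model.context.set_input_shape("attention_mask", (1, n))
        ids = torch.randint(0, 128100, (1, n), dtype=torch.long, device='cuda')
        mask = torch.ones((1, n), dtype=torch.long, device='cuda')
        ref = model([ids, mask])[0].clone()  # clone: the pinned host buffer is reused between calls
        for _ in range(reps):
            if not torch.allclose(ref, model([ids, mask])[0]):
                print(f"seq_len={n}: output drifts across runs")
                break
        else:
            print(f"seq_len={n}: deterministic over {reps} runs")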

Environment

TensorRT Version: 10.2

NVIDIA GPU: Tesla V100

NVIDIA Driver Version: 535.171.04

CUDA Version: 12.4

CUDNN Version: N/A

Operating System: Ubuntu 20.04

Python Version (if applicable): 3.8.1

Tensorflow Version (if applicable): N/A

PyTorch Version (if applicable): 1.11

Baremetal or Container (if so, version): N/A

Relevant Files

Model link: https://huggingface.co/microsoft/deberta-v3-xsmall, pulled from transformers repository version 4.22.0

Steps To Reproduce

import torch
import tensorrt as trt
import os, sys, argparse
import numpy as np
import pycuda.driver as cuda
import pycuda.autoinit  # creates a CUDA context on import
from transformers import DebertaV2Tokenizer, DebertaV2Config, DebertaV2ForSequenceClassification
import onnx
import onnx_graphsurgeon as gs
from onnx import TensorProto

seq_len = 4059  # inference-time sequence length, used for the profile's opt shape and the test inputs

def export():
    # Make sure the output directory exists
    os.makedirs(os.path.dirname('./test/deberta.onnx'), exist_ok=True)
    deberta_model = DebertaV2ForSequenceClassification.from_pretrained('microsoft/deberta-v3-xsmall')
    batch_size = 1
    seq_len = 512  # fixed length used only for tracing; the exported axes below are dynamic
    vocab_size = 128100
    deberta_model.cuda().eval()
    gpu = torch.device('cuda')
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), dtype=torch.long, device=gpu)
    attention_mask = torch.randint(0, 2, (batch_size, seq_len), dtype=torch.long, device=gpu)
    input_names = ['input_ids', 'attention_mask']
    output_names = ['output']

    ## Set seq_len to dynamic
    dynamic_axes={'input_ids'   : {0 : 'batch_size', 1 : 'seq_len'},
                  'attention_mask'   : {0 : 'batch_size', 1 : 'seq_len'},
                  'output' : {0 : 'batch_size'}}

    ## Export model to onnx
    torch.onnx.export(deberta_model,
                     (input_ids, attention_mask),
                     './test/deberta.onnx',
                     export_params=True,
                     opset_version=13,
                     do_constant_folding=True,
                     input_names = input_names,
                     output_names = output_names,
                     dynamic_axes = dynamic_axes)

def remove_uint8_cast(graph):
    # TensorRT does not support Cast to UINT8; bypass each such Cast by wiring
    # its producer's output directly to the Cast's consumers
    nodes = [node for node in graph.nodes if node.op == 'Cast' and node.attrs["to"] == TensorProto.UINT8]
    for node in nodes:
        input_node = node.i()
        input_node.outputs = node.outputs
        node.outputs.clear()
    return graph

@gs.Graph.register()
def insert_disentangled_attention(self, inputs, outputs, factor, span):
    # Disconnect the original subgraph from its outputs, then splice in the plugin node
    [out.inputs.clear() for out in outputs]
    attrs = {
        "factor": 1/factor,  # the plugin expects the reciprocal of the attention scale
        "span": span         # relative-position span
    }
    self.layer(op='DisentangledAttention_TRT', inputs=inputs, outputs=outputs, attrs=attrs)

def insert_disentangled_attention_all(graph):
    # Each layer's disentangled attention block contains two GatherElements nodes
    # (content-to-position and position-to-content); pair them up per layer
    nodes = [node for node in graph.nodes if node.op == 'GatherElements']
    assert len(nodes) % 2 == 0, "Number of GatherElements nodes is not even!"
    layers = [(nodes[2*i+0], nodes[2*i+1]) for i in range(len(nodes)//2)]
    for l, (left, right) in enumerate(layers):
        print(f"Fusing layer {l}")
        # Plugin inputs: the content-to-content scores plus the two gathered score tensors
        inputs = list(left.o().o().o().o().i().inputs)[0:1] + list(left.inputs)[0:1] + list(right.inputs)[0:1]
        outputs = list(left.o().o().o().o().outputs)
        factor = 13.856406211853027  # attention scale sqrt(3 * head_size) = sqrt(192) for this model
        span = 256                   # relative-position bucket span
        graph.insert_disentangled_attention(inputs, outputs, factor, span)
    return graph

def insert_plugin():
    graph = gs.import_onnx(onnx.load('./test/deberta.onnx'))
    graph = remove_uint8_cast(graph)
    graph = insert_disentangled_attention_all(graph)
    graph.cleanup().toposort()
    onnx.save_model(gs.export_onnx(graph), './test/deberta_plugin.onnx')
    print("Saved modified model to ./test/deberta_plugin.onnx")

class TRTModel:
    class HostDeviceMem(object):
        # Pairs a pinned host buffer with its matching device allocation
        def __init__(self, host_mem, device_mem):
            self.host = host_mem
            self.device = device_mem

    def __init__(self, engine_path):
        self.engine_path = engine_path
        self.logger = trt.Logger(trt.Logger.WARNING)
        self.runtime = trt.Runtime(self.logger)
        self.engine = self.load_engine()
        self.inputs, self.outputs, self.bindings, self.stream = self.allocate_buffers(self.engine)
        self.context = self.engine.create_execution_context()
        self.numpy_to_torch_dtype_dict = {
            np.int64      : torch.int64,
        }

    def load_engine(self):
        with open(self.engine_path, 'rb') as f:
            engine = self.runtime.deserialize_cuda_engine(f.read())
        return engine

    def allocate_buffers(self, engine):
        inputs = []
        outputs = []
        bindings = []
        stream = cuda.Stream()
        for i in range(engine.num_io_tensors):
            tensor_name = engine.get_tensor_name(i)
            size = trt.volume(engine.get_tensor_shape(tensor_name))
            dtype = trt.nptype(engine.get_tensor_dtype(tensor_name))
            if size < 0:
                # Dynamic dims make the reported volume negative; allocate for the profile max
                size = 4096
            host_mem = cuda.pagelocked_empty(size, dtype)
            device_mem = cuda.mem_alloc(host_mem.nbytes)
            bindings.append(int(device_mem))
            if engine.get_tensor_mode(tensor_name) == trt.TensorIOMode.INPUT:
                inputs.append(self.HostDeviceMem(host_mem, device_mem))
            else:
                outputs.append(self.HostDeviceMem(host_mem, device_mem))
        return inputs, outputs, bindings, stream

    def __call__(self, model_inputs: list, timing=False):
        # All inputs are assumed to share the same batch size
        batch_size = np.unique(np.array([i.size(dim=0) for i in model_inputs]))
        batch_size = batch_size[0]
        for i, model_input in enumerate(model_inputs):
            binding_name = self.engine.get_tensor_name(i)
            binding_dtype = trt.nptype(self.engine.get_tensor_dtype(binding_name))
            model_input = model_input.to(self.numpy_to_torch_dtype_dict[binding_dtype])
            # Inputs already live on the GPU, so copy device-to-device
            cuda.memcpy_dtod_async(self.inputs[i].device, model_input.data_ptr(), model_input.element_size() * model_input.nelement(), self.stream)
        for i in range(self.engine.num_io_tensors):
            self.context.set_tensor_address(self.engine.get_tensor_name(i), self.bindings[i])
        self.context.execute_async_v3(stream_handle=self.stream.handle)
        # Copy outputs back to pinned host memory and wait for the stream to finish
        [cuda.memcpy_dtoh_async(out.host, out.device, self.stream) for out in self.outputs]
        self.stream.synchronize()
        return [torch.from_numpy(out.host.reshape(batch_size,-1)) for out in self.outputs]

def build_engine():
    TRT_LOGGER = trt.Logger(trt.Logger.INFO)
    TRT_BUILDER = trt.Builder(TRT_LOGGER)
    engine_filename = '_'.join(['./test/deberta_plugin', 'V100', 'fp16']) + '.engine'
    # TensorRT 10 networks are always explicit-batch, so no creation flags are needed
    network = TRT_BUILDER.create_network(0)
    onnx_parser = trt.OnnxParser(network, TRT_LOGGER)
    if not onnx_parser.parse_from_file('./test/deberta_plugin.onnx'):
        raise RuntimeError(f"Failed to parse ONNX model: {onnx_parser.get_error(0)}")
    config = TRT_BUILDER.create_builder_config()
    profile = TRT_BUILDER.create_optimization_profile()
    # Dynamic sequence length: min 1, opt seq_len (4059), max 4096
    profile.set_shape("input_ids", (1, 1), (1, seq_len), (1, 4096))
    profile.set_shape("attention_mask", (1, 1), (1, seq_len), (1, 4096))
    config.add_optimization_profile(profile)
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 4096 * (1 << 20))
    config.set_flag(trt.BuilderFlag.FP16)
    serialized_engine = TRT_BUILDER.build_serialized_network(network, config)
    with open(engine_filename, 'wb') as f:
        f.write(serialized_engine)

def test_engine():
    torch.manual_seed(42)
    engine_filename = '_'.join(['./test/deberta_plugin', 'V100', 'fp16']) + '.engine'
    model = TRTModel(engine_filename)
    batch_size = 1
    vocab = 128100  # match the model's vocab size so token ids stay in range
    gpu = torch.device('cuda')
    input_ids = torch.randint(0, vocab, (batch_size, seq_len), dtype=torch.long, device=gpu)
    attention_mask = torch.randint(0, 2, (batch_size, seq_len), dtype=torch.long, device=gpu)
    inputs = [input_ids, attention_mask]
    # Bind the runtime input shapes before execution
    model.context.set_input_shape("input_ids", (1, seq_len))
    model.context.set_input_shape("attention_mask", (1, seq_len))
    outputs = model(inputs)
    nreps = 10
    for i in range(nreps):
        outputs = model(inputs)
        print(f'Iteration {i + 1}: {outputs[0].tolist()}')

if __name__ == "__main__":
    export()
    insert_plugin()
    build_engine()
    test_engine()

Have you tried the latest release?: Yes

Can this model run on other frameworks? For example run ONNX model with ONNXRuntime (polygraphy run <model.onnx> --onnxrt): Yes, I've tested it with ONNX Runtime (cuDNN 8.9, CUDA 11.8) and still saw inconsistent output when running with the plugin.
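
For reference, a determinism check under ONNX Runtime can be run roughly like this (a sketch; the plugin op is TensorRT-only, so this runs the original exported model, and the provider choice is an assumption):

# Sketch: repeat identical inputs under ONNX Runtime and compare runs.
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession('./test/deberta.onnx', providers=['CUDAExecutionProvider'])
ids = np.random.randint(0, 128100, (1, 4059), dtype=np.int64)
mask = np.ones((1, 4059), dtype=np.int64)
runs = [sess.run(None, {'input_ids': ids, 'attention_mask': mask})[0] for _ in range(5)]
print([np.array_equal(runs[0], r) for r in runs[1:]])  # all True => deterministic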

lix19937 commented 1 month ago

Can you check the result before the plugin node?
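
For example, intermediate tensors can be promoted to graph outputs with onnx-graphsurgeon, roughly like this (a sketch; the names in debug_tensors are placeholders for the tensors feeding your plugin node):

import onnx
import onnx_graphsurgeon as gs

# Sketch: expose intermediate tensors as extra ONNX outputs for debugging.
graph = gs.import_onnx(onnx.load('./test/deberta_plugin.onnx'))
debug_tensors = {'query_layer', 'rel_pos_scores'}  # placeholder tensor names
for name, tensor in graph.tensors().items():
    if name in debug_tensors:
        # If export complains about a missing dtype, set tensor.dtype explicitly
        graph.outputs.append(tensor)
onnx.save_model(gs.export_onnx(graph), './test/deberta_plugin_debug.onnx')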

jkim695 commented 1 month ago

I inserted output nodes for the inputs of the plugin and found that one node (MatMul_173) produces inconsistent output. (screenshot: ONNX graph around MatMul_173)

When checking the inputs and outputs of this node, the inputs (query_layer & MatMul_423) are consistent across runs, but the output (Div_424) is not:

Iteration 1:
query_layer: tensor([ 0.6668, -0.0100,  0.7874,  ..., -0.0217, -0.1060,  0.1076])
MatMul_423: tensor([ 1.4557, -1.4124, -1.1299,  ...,  1.1094,  1.0816,  1.1094])
Div_424: tensor([-0.8541, -1.2864, -0.9897, -1.1466, -1.5251, -1.0459])

Iteration 2: 
query_layer: tensor([ 0.6668, -0.0100,  0.7874,  ..., -0.0217, -0.1060,  0.1076])
MatMul_423: tensor([ 1.4557, -1.4124, -1.1299,  ...,  1.1094,  1.0816,  1.1094])
Div_424: tensor([[-0.5895, -1.2377, -1.0294, -1.1611, -1.2238, -1.1667]])

Iteration 3:
query_layer: tensor([ 0.6668, -0.0100,  0.7874,  ..., -0.0217, -0.1060,  0.1076])
MatMul_423: tensor([ 1.4557, -1.4124, -1.1299,  ...,  1.1094,  1.0816,  1.1094])
Div_424: tensor([[-0.4970, -1.1129, -0.8866, -0.9455, -1.3411, -0.9709]])

I also observed that this node produces consistent output in the model without the plugin: (screenshot: ONNX graph around MatMul_173)

Iteration 1:
query_layer: tensor([ 0.6680, -0.0111,  0.7866,  ..., -0.0217, -0.1060,  0.1076])
MatMul_423: tensor([ 1.4551, -1.4111, -1.1299,  ...,  1.1094,  1.0820,  1.1094])
Div_424: tensor([[-0.5854, -1.2207, -1.0205, -1.0791, -0.9033, -0.9243]])

Iteration 2:
query_layer: tensor([ 0.6680, -0.0111,  0.7866,  ..., -0.0217, -0.1060,  0.1076])
MatMul_423: tensor([ 1.4551, -1.4111, -1.1299,  ...,  1.1094,  1.0820,  1.1094])
Div_424: tensor([[-0.5854, -1.2207, -1.0205, -1.0791, -0.9033, -0.9243]])

Iteration 3:
query_layer: tensor([ 0.6680, -0.0111,  0.7866,  ..., -0.0217, -0.1060,  0.1076])
MatMul_423: tensor([ 1.4551, -1.4111, -1.1299,  ...,  1.1094,  1.0820,  1.1094])
Div_424: tensor([[-0.5854, -1.2207, -1.0205, -1.0791, -0.9033, -0.9243]])

Is there an explanation for why inserting the plugin node causes inconsistencies in this MatMul node?