ELS-RD / transformer-deploy

Efficient, scalable and enterprise-grade CPU/GPU inference server for 🤗 Hugging Face transformer models 🚀
https://els-rd.github.io/transformer-deploy/
Apache License 2.0

Speed difference ONNX vs TensorRT with samples sorted by sequence length #55

Open v1nc3nt27 opened 2 years ago

v1nc3nt27 commented 2 years ago

I noticed something unexpected when comparing two scenarios for a model converted via ONNX and TensorRT (distilroberta with classification head):

  1. Scenario 1: I use a dataset with varying sentence lengths (~20-60 tokens) and run it through both models in random order
  2. Scenario 2: I use the same dataset but sort the sentences by length (decreasing) before running them through both models

Result: The TensorRT model does not seem to care about the sequence lengths and keeps the same speed for both scenarios. The ONNX model, however, gets almost twice as fast when I use the second scenario.

I was wondering whether TensorRT's optimization somehow requires padding to the max length internally. I searched for a parameter or a reason for this behavior but couldn't find anything useful. For the conversion, I set the seq-len parameter to 1 60 60.

I was wondering if perhaps someone else has already observed this and knows the reason / a solution.
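
To make the comparison concrete, here is a minimal, self-contained sketch of the two scenarios (distilroberta-base and the synthetic sentences are stand-ins, not my actual model or data), assuming the server-side tokenizer pads each batch only to its longest member:

import random
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")
batch_size = 64

# ~20-60 "words" per sentence; scenario 1 shuffles them, scenario 2 sorts by length (decreasing)
sentences = ["word " * random.randint(20, 60) for _ in range(4 * batch_size)]
scenario_1 = list(sentences)
random.shuffle(scenario_1)
scenario_2 = sorted(sentences, key=len, reverse=True)

def padded_widths(samples):
    # padding="longest" pads every batch to its own longest member
    widths = []
    for i in range(0, len(samples), batch_size):
        batch = tokenizer(samples[i:i + batch_size], padding="longest", return_tensors="np")
        widths.append(batch["input_ids"].shape[1])
    return widths

print("scenario 1 (random order):", padded_widths(scenario_1))   # widths stay close to the max
print("scenario 2 (length sorted):", padded_widths(scenario_2))  # widths shrink batch by batch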

pommedeterresautee commented 2 years ago

Is there some batching applied?

v1nc3nt27 commented 2 years ago

Oh, I completely forgot to mention that. Yes, I use a batch size of 64. This behavior only applies if batching is used.

pommedeterresautee commented 2 years ago

How is each batch built? Is it made of sequences of the exact same length?

v1nc3nt27 commented 2 years ago

The samples are just ordered by character length and then batched, so they may still vary within a batch (but much less than before). The speed-up simply comes from the fact that fewer batches are padded to the model_max_length in that case.

I replaced

https://github.com/ELS-RD/transformer-deploy/blob/1f2d2c1d8d0239fca7679f8c550a954ea1445cfa/src/transformer_deploy/utils/python_tokenizer.py#L58

with

tokens: Dict[str, np.ndarray] = self.tokenizer(query_question, query_answer, return_tensors=TensorType.NUMPY, padding="longest", truncation=True)

and added self.tokenizer.model_max_length = 60 as the last line of initialize().
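
For context, the behavior that change relies on looks like this (distilroberta-base and the question/answer pairs below are placeholders, not the real model or data): padding="longest" pads a batch only to its longest pair, while truncation=True caps everything at tokenizer.model_max_length, which initialize() now sets to 60.

from transformers import AutoTokenizer, TensorType

tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")
tokenizer.model_max_length = 60  # same cap as set at the end of initialize()

questions = ["short question?", "a somewhat longer question about the same topic?"]
answers = ["short answer.", "a much longer answer " + "with many extra words " * 20]

tokens = tokenizer(questions, answers, return_tensors=TensorType.NUMPY,
                   padding="longest", truncation=True)

# the long pair is truncated to the 60-token cap, the short one is padded up to
# the longest pair in the batch instead of to a fixed max length
print(tokens["input_ids"].shape)  # e.g. (2, 60)
print(tokens["input_ids"].dtype)  # int64, which is why model.py casts to int32 for TensorRT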

pommedeterresautee commented 2 years ago

can you provide me with some reproducible code so I test on my side?

v1nc3nt27 commented 2 years ago

Hey @pommedeterresautee, sorry for the long wait - I was on a holiday trip.

I based my script on your demo scripts, but I cannot disclose the model and/or dataset. You can basically use any dataset with two inputs, e.g. a QA dataset. I hope you can make use of it anyway.

I attached the script that calls the inference ensemble hosted in Triton (transformer_onnx_inference or transformer_trt_inference) and the slightly modified model.py for the tokenize endpoint in Triton.

If you experience the same as I do, calling the ONNX model's inference endpoint should become slower if you comment out the length sorting in triton_inference_qa_test.py, while doing the same for the TRT model's inference should make no difference.

triton_inference_qa_test.py

import argparse
import math
import time
import numpy as np
import tritonclient.http
from tqdm import tqdm
from scipy.special import softmax
from transformer_deploy.benchmarks.utils import print_timings, setup_logging, track_infer_time

def _batch(iterable, n=1):
    l = len(iterable)
    for ndx in range(0, l, n):
        yield iterable[ndx:min(ndx + n, l)]

def run_http_sync():
    all_gold = []
    all_pred = []
    triton_client = tritonclient.http.InferenceServerClient(url=url, verbose=False)

    assert triton_client.is_model_ready(
        model_name=model_name, model_version=model_version
    ), f"model {model_name} not yet ready"
    setup_logging()

    model_score = tritonclient.http.InferRequestedOutput(name="output", binary_data=True)
    for b in tqdm(_batch(list(zip(questions, answers, golds)), batch_size), total=math.ceil(len(answers)/batch_size)):
        with track_infer_time(time_buffer):
            topic_b, sent_b, gold_b = zip(*b) 
            all_gold.extend(gold_b)

            query_sent = tritonclient.http.InferInput(name="sent", shape=(len(b),), datatype="BYTES")
            query_topic = tritonclient.http.InferInput(name="topic", shape=(len(b),), datatype="BYTES")

            query_sent.set_data_from_numpy(np.asarray(sent_b, dtype=object))
            query_topic.set_data_from_numpy(np.asarray(topic_b, dtype=object))

            response = triton_client.infer(
                model_name=model_name, model_version=model_version, inputs=[query_topic, query_sent], outputs=[model_score],
                response_compression_algorithm="gzip", request_compression_algorithm="gzip"
            )
            res = response.as_numpy("output")
            scores = softmax(res, axis=1)
            all_pred.extend([np.argmax(pred) for pred in scores])

    return all_gold, all_pred

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="require inference", formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("--triton-model-name", help="Model name in triton server", type=str)
    parser.add_argument("--url", help="", type=str, default="127.0.0.1:8000")
    parser.add_argument("--batch-size", help="", type=int, default=64)
    args, _ = parser.parse_known_args()

    setup_logging()
    model_name = args.triton_model_name
    url = args.url
    model_version = "1"
    batch_size = args.batch_size

    # todo read data
    data = read_data(...)

    answers = data["answers"].tolist()
    golds = data["label"].tolist()
    questions = data["questions"].tolist()
    t = time.time()

    time_buffer = list()

    # comment out the following 4 lines to switch off length-based bucketing
    length_sorted_idx = np.argsort([len(q + a) for q, a in zip(questions, answers)])
    answers = [answers[idx] for idx in length_sorted_idx]
    golds = [golds[idx] for idx in length_sorted_idx]
    questions = [questions[idx] for idx in length_sorted_idx]

    all_gold, all_pred = run_http_sync()

    print_timings(name="triton transformers", timings=time_buffer)
    total_time = time.time() - t
    print("Total time: " + str(total_time))

model.py

import os
from typing import Dict, List

import numpy as np

try:
    # noinspection PyUnresolvedReferences
    import triton_python_backend_utils as pb_utils
except ImportError:
    pass  # triton_python_backend_utils exists only inside Triton Python backend.

from transformers import AutoTokenizer, PreTrainedTokenizer, TensorType

class TritonPythonModel:
    is_tensorrt: bool
    tokenizer: PreTrainedTokenizer

    def initialize(self, args: Dict[str, str]) -> None:
        """
        Initialize the tokenization process
        :param args: arguments from Triton config file
        """
        path: str = os.path.join(args["model_repository"], args["model_version"])
        model_name: str = args["model_name"]
        self.is_tensorrt = "trt" in model_name or "tensorrt" in model_name
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.tokenizer.model_max_length = 60

    def execute(self, requests) -> "List[List[pb_utils.Tensor]]":
        """
        Parse and tokenize each request
        :param requests: 1 or more requests received by Triton server.
        :return: text as input tensors
        """
        responses = []
        # for loop for batch requests (disabled in our case)
        for request in requests:
            # binary data typed back to string
            query_topic = [t.decode("UTF-8") for t in pb_utils.get_input_tensor_by_name(request, "topic").as_numpy().tolist()]
            query_sent = [t.decode("UTF-8") for t in pb_utils.get_input_tensor_by_name(request, "sent").as_numpy().tolist()]

            tokens: Dict[str, np.ndarray] = self.tokenizer(query_topic, query_sent, return_tensors=TensorType.NUMPY,
                                                           padding="longest", truncation=True)
            if self.is_tensorrt:
                # tensorrt uses int32 as input type, ort uses int64
                tokens = {k: v.astype(np.int32) for k, v in tokens.items()}
            # communicate the tokenization results to Triton server
            outputs = list()
            for input_name in self.tokenizer.model_input_names:
                tensor_input = pb_utils.Tensor(input_name, tokens[input_name])
                outputs.append(tensor_input)

            inference_response = pb_utils.InferenceResponse(output_tensors=outputs)
            responses.append(inference_response)

        return responses
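
Note that the "topic" and "sent" input names (and the "output" tensor) have to match both the tokenize model's config.pbtxt in the Triton model repository and the client-side InferInput / InferRequestedOutput names used in triton_inference_qa_test.py above.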