triton-inference-server / server

The Triton Inference Server provides an optimized cloud and edge inferencing solution.
https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/index.html
BSD 3-Clause "New" or "Revised" License

Hundred times slowdown of BERT models #6084

Open guoners opened 1 year ago

guoners commented 1 year ago

Description: I deployed a bert_base model from Hugging Face's transformers library via TorchScript and Triton's PyTorch backend, but I found that GPU utilization is around 0 and performance is far below the local test.

My Triton configuration and performance-test environment are valid: I tested other PyTorch-backend models and observed the expected performance and GPU utilization.

Triton Information: Triton container 23.01 from NGC.

To Reproduce

I got the model from the transformers library and converted it to TorchScript as shown below (model file). I deployed the model to Triton over HTTP with the following model config (model config).

In a rough test, the latency is 700 ms, the RPS (requests per second) is around 1.5, and GPU utilization is around 0, while a local test of the same .pt file (TorchScript model engine) gives around 8 ms latency and 100 RPS with 30% GPU utilization.
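A minimal single-query timing sketch of this kind of test could look like the following (assuming the tritonclient Python package; the tensor names come from the config below and the INT32 dtype from the log, while the iteration count and random token values are illustrative):

import time

import numpy as np
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8000")

# Dummy inputs with the same shape as in the log below: [1, 14]
tokens = np.random.randint(0, 30522, size=(1, 14)).astype(np.int32)
segments = np.array([[0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]], dtype=np.int32)

inputs = [
    httpclient.InferInput("indexed_tokens", list(tokens.shape), "INT32"),
    httpclient.InferInput("segments_ids", list(segments.shape), "INT32"),
]
inputs[0].set_data_from_numpy(tokens)
inputs[1].set_data_from_numpy(segments)
outputs = [httpclient.InferRequestedOutput("logits")]

client.infer("bert_base", inputs, outputs=outputs)  # warm-up request

# Closed-loop sequential requests: average latency and RPS
n = 50
start = time.perf_counter()
for _ in range(n):
    client.infer("bert_base", inputs, outputs=outputs)
elapsed = time.perf_counter() - start
print(f"avg latency: {elapsed / n * 1000:.1f} ms, RPS: {n / elapsed:.1f}")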

Through further observation, I found that the model actually does use the GPU, but the Triton latency is very high (750 ms) while the model's compute time is very low (2 ms), so the average GPU utilization stays low. In the execution log below, the backend adds the response output at 03:18:54.325 but the HTTP server does not release the output buffer until 03:18:55.036, which matches the ~700 ms gap. What causes this gap between Triton latency and compute time? I tested with the TorchScript, ONNX, and TensorRT backends and saw the same performance.

execution log:

I0724 03:18:54.323601 1 infer_request.cc:729] [request id: <id_unknown>] prepared: [0x0x7f13480610a0] request id: , model: bert_base, requested version: -1, actual version: 0, flags: 0x0, correlation id: 0, batch size: 1, priority: 0, timeout (us): 0
original inputs:
[0x0x7f134805b638] input: segments_ids, type: INT32, original shape: [1,14], batch + shape: [1,14], shape: [14]
[0x0x7f1348085618] input: indexed_tokens, type: INT32, original shape: [1,14], batch + shape: [1,14], shape: [14]
override inputs:
inputs:
[0x0x7f1348085618] input: indexed_tokens, type: INT32, original shape: [1,14], batch + shape: [1,14], shape: [14]
[0x0x7f134805b638] input: segments_ids, type: INT32, original shape: [1,14], batch + shape: [1,14], shape: [14]
original requested outputs:
logits
requested outputs:
logits

I0724 03:18:54.323703 1 tensorrt.cc:334] model bert_base, instance bert_base_0, executing 1 requests
I0724 03:18:54.323744 1 instance_state.cc:370] TRITONBACKEND_ModelExecute: Issuing bert_base_0 with 1 requests
I0724 03:18:54.323763 1 instance_state.cc:419] TRITONBACKEND_ModelExecute: Running bert_base_0 with 1 requests
I0724 03:18:54.323825 1 instance_state.cc:1464] Optimization profile default [0] is selected for bert_base_0
I0724 03:18:54.323870 1 pinned_memory_manager.cc:161] pinned memory allocation: size 56, addr 0x7f13ec000090
I0724 03:18:54.323937 1 pinned_memory_manager.cc:161] pinned memory allocation: size 56, addr 0x7f13ec0000d0
I0724 03:18:54.323971 1 instance_state.cc:910] Context with profile default [0] is being executed for bert_base_0
I0724 03:18:54.325730 1 infer_response.cc:167] add response output: output: logits, type: FP32, shape: [1,14,30522]
I0724 03:18:54.325769 1 http_server.cc:1098] HTTP: unable to provide 'logits' in GPU, will use CPU
I0724 03:18:54.325782 1 http_server.cc:1118] HTTP using buffer for: 'logits', size: 1709232, addr: 0x7f0bb8019090
I0724 03:18:54.325795 1 pinned_memory_manager.cc:161] pinned memory allocation: size 1709232, addr 0x7f13ec000110
I0724 03:18:55.036833 1 http_server.cc:1192] HTTP release: size 1709232, addr 0x7f0bb8019090
I0724 03:18:55.036914 1 instance_state.cc:1321] TRITONBACKEND_ModelExecute: model bert_base_0 released 1 requests
I0724 03:18:55.036947 1 pinned_memory_manager.cc:190] pinned memory deallocation: addr 0x7f13ec000110
I0724 03:18:55.036957 1 pinned_memory_manager.cc:190] pinned memory deallocation: addr 0x7f13ec000090
I0724 03:18:55.036973 1 pinned_memory_manager.cc:190] pinned memory deallocation: addr 0x7f13ec0000d0

The model file:

from transformers import BertModel, BertTokenizer, BertConfig, BertForMaskedLM
import torch

enc = BertTokenizer.from_pretrained("bert-base-uncased")
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = enc.tokenize(text)
masked_index = 8
tokenized_text[masked_index] = "[MASK]"
indexed_tokens = enc.convert_tokens_to_ids(tokenized_text)
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])

model = BertForMaskedLM.from_pretrained("bert-base-uncased")

class WrappedBertModel(torch.nn.Module):
    # Wrap BertForMaskedLM so the traced module takes the two input tensors
    # positionally and returns only the logits tensor.
    def __init__(self, bert_model):
        super().__init__()
        self.bert = bert_model

    def forward(self, t, s):
        # t: indexed_tokens tensor, s: segments_ids tensor, both [batch, seq_len]
        outs = self.bert(t, s)
        logits = outs['logits']
        return logits

wrap_model = WrappedBertModel(model).to('cuda:0')
traced_model = torch.jit.trace(wrap_model, [tokens_tensor.to('cuda:0'), 
                                            segments_tensors.to('cuda:0')])

_ = traced_model.to('cuda:0')
torch.jit.save(traced_model, "model_engines/traced_bert_base_mlm.pt")
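For comparison, the ~8 ms local latency was measured with a loop along these lines (a minimal sketch reusing tokens_tensor and segments_tensors from the script above; the warm-up and iteration counts are illustrative, since only the final numbers are reported):

import time

import torch

traced = torch.jit.load("model_engines/traced_bert_base_mlm.pt").to('cuda:0').eval()
t = tokens_tensor.to('cuda:0')
s = segments_tensors.to('cuda:0')

with torch.no_grad():
    for _ in range(10):        # warm-up so CUDA kernels are initialized before timing
        traced(t, s)
    torch.cuda.synchronize()

    n = 100
    start = time.perf_counter()
    for _ in range(n):
        traced(t, s)
    torch.cuda.synchronize()   # wait for queued GPU work before stopping the timer
    elapsed = time.perf_counter() - start

print(f"avg local latency: {elapsed / n * 1000:.2f} ms")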

The model config:

md_config = {
    "name": 'bert_base',
    "platform": 'pytorch_libtorch',
    "max_batch_size": 1,
    "input": [{
        "name": "indexed_tokens",
        "data_type": "TYPE_INT64",
        "dims": [ -1 ]
    },
    {
        "name": "segments_ids",
        "data_type": "TYPE_INT64",
        "dims": [ -1 ]
    }],
    "output": [{
        "name": "logits",
        "data_type": "TYPE_FP32",
        "dims": [ -1, 30522 ]
    }],
    'instance_group': [{
        "count": 1, 
        "kind": "KIND_GPU"
    }]
}
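Written as the config.pbtxt file that Triton reads from the model repository, the same settings would look roughly like this (a sketch of the equivalent protobuf text format):

name: "bert_base"
platform: "pytorch_libtorch"
max_batch_size: 1
input [
  {
    name: "indexed_tokens"
    data_type: TYPE_INT64
    dims: [ -1 ]
  },
  {
    name: "segments_ids"
    data_type: TYPE_INT64
    dims: [ -1 ]
  }
]
output [
  {
    name: "logits"
    data_type: TYPE_FP32
    dims: [ -1, 30522 ]
  }
]
instance_group [
  {
    count: 1
    kind: KIND_GPU
  }
]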

Expected behavior: Could you help me check the process? I want to know why the model barely uses the GPU and how I can fix this. Thanks a lot!

jbkyang-nvi commented 1 year ago

What does your client setup look like?

guoners commented 1 year ago

What does your client setup look like?

I tested single queries with the Triton HTTP inference protocol, and also with a third-party load-testing tool (Locust) using 1 to 10 closed-loop clients (each client sends the next query only after receiving the previous response).
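A closed-loop Locust setup along those lines could look like the following sketch (the class name, sequence length, and payload values are illustrative; the endpoint follows Triton's KServe-v2 HTTP /v2/models/<name>/infer route):

# locustfile.py
import random

from locust import HttpUser, constant, task

class TritonBertUser(HttpUser):
    # Closed loop: send the next query as soon as the previous one returns.
    wait_time = constant(0)

    @task
    def infer(self):
        seq_len = 14
        payload = {
            "inputs": [
                {
                    "name": "indexed_tokens",
                    "shape": [1, seq_len],
                    "datatype": "INT32",
                    "data": [random.randrange(30522) for _ in range(seq_len)],
                },
                {
                    "name": "segments_ids",
                    "shape": [1, seq_len],
                    "datatype": "INT32",
                    "data": [0] * 7 + [1] * 7,
                },
            ],
            "outputs": [{"name": "logits"}],
        }
        self.client.post("/v2/models/bert_base/infer", json=payload)

Run headless with something like locust -f locustfile.py --headless -u 10 -r 10 --host http://localhost:8000, varying -u between 1 and 10 clients.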