ELS-RD / transformer-deploy

Efficient, scalable and enterprise-grade CPU/GPU inference server for 🤗 Hugging Face transformer models 🚀
https://els-rd.github.io/transformer-deploy/
Apache License 2.0

Dynamic batching does not give better latency for Roberta running on TensorRT. #87

Closed: Ki6an closed this issue 2 years ago

Ki6an commented 2 years ago

Hi, I used your build_engine API to convert a RoBERTa model. If I build it with a constant batch size for the input shapes, i.e. (min, optimal, max) -> (1, 1, 1) or (4, 4, 4), the model yields good results (faster than ORT and Torch).

But when I convert it with a dynamic batch size, i.e. (min, optimal, max) -> (1, 4, 4), the model performs really slowly compared to ORT or Torch.

Code to show the problem:

# fast inference, but the engine is locked to a batch size of 4
# shapes are (batch_size, seq_len) tuples for (min, optimal, max)
tensor_shapes = list(zip([4, 4, 4], [1, 128, 128]))

# slow inference (batch size dynamic between 1 and 4)
tensor_shapes = list(zip([1, 4, 4], [1, 128, 128]))

engine: ICudaEngine = build_engine(
    runtime=runtime,
    onnx_file_path=onnx_model_path,
    logger=trt_logger,
    min_shape=tensor_shapes[0],
    optimal_shape=tensor_shapes[1],
    max_shape=tensor_shapes[2],
    workspace_size=workspace_size * 1024**3,
    fp16=not quantization,
    int8=quantization,
    profiling=True,
)

save_engine(engine=engine, engine_file_path=tensorrt_path)
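
For context, here is a minimal sketch of how such an engine can be exercised at inference time with the TensorRT Python API, using torch CUDA tensors as bindings. This is illustrative only (it is not the repo's benchmarking code); the binding order (input_ids, attention_mask, output) and the output binding name "output" are assumptions:

import tensorrt as trt
import torch

def infer_trt(engine, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # inputs are expected to be int32 CUDA tensors of shape (batch, seq_len)
    context = engine.create_execution_context()
    # with dynamic axes, the concrete input shapes must be set on every call
    context.set_binding_shape(engine.get_binding_index("input_ids"), tuple(input_ids.shape))
    context.set_binding_shape(engine.get_binding_index("attention_mask"), tuple(attention_mask.shape))
    out_index = engine.get_binding_index("output")  # assumed output binding name
    out_dtype = torch.float16 if engine.get_binding_dtype(out_index) == trt.DataType.HALF else torch.float32
    output = torch.empty(tuple(context.get_binding_shape(out_index)), dtype=out_dtype, device="cuda")
    # bindings must be ordered by binding index (input_ids, attention_mask, output assumed here)
    bindings = [input_ids.data_ptr(), attention_mask.data_ptr(), output.data_ptr()]
    stream = torch.cuda.current_stream()
    context.execute_async_v2(bindings=bindings, stream_handle=stream.cuda_stream)
    stream.synchronize()
    return output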

The complete build and inference logs for the slow case (engine built with a dynamic batch size):

[06/02/2022-03:19:09] [TRT] [I] [MemUsageChange] Init CUDA: CPU +312, GPU +0, now: CPU 3789, GPU 2470 (MiB)
[06/02/2022-03:19:09] [TRT] [I] [MemUsageChange] Init CUDA: CPU +0, GPU +0, now: CPU 3790, GPU 2470 (MiB)
[06/02/2022-03:19:09] [TRT] [I] [MemUsageSnapshot] Begin constructing builder kernel library: CPU 3790 MiB, GPU 2470 MiB
[06/02/2022-03:19:09] [TRT] [I] [MemUsageSnapshot] End constructing builder kernel library: CPU 3924 MiB, GPU 2504 MiB
[06/02/2022-03:19:09] [TRT] [I] parsing TensorRT model
[libprotobuf WARNING google/protobuf/io/coded_stream.cc:604] Reading dangerously large protocol message.  If the message turns out to be larger than 2147483647 bytes, parsing will be halted for security reasons.  To increase the limit (or to disable these warnings), see CodedInputStream::SetTotalBytesLimit() in google/protobuf/io/coded_stream.h.
[libprotobuf WARNING google/protobuf/io/coded_stream.cc:81] The total number of bytes read was 1418322027
[06/02/2022-03:19:22] [TRT] [W] onnx2trt_utils.cpp:366: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
[06/02/2022-03:19:39] [TRT] [W] Output type must be INT32 for shape outputs
[06/02/2022-03:19:39] [TRT] [W] Output type must be INT32 for shape outputs
[06/02/2022-03:19:39] [TRT] [W] Output type must be INT32 for shape outputs
[06/02/2022-03:19:43] [TRT] [I] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +512, GPU +226, now: CPU 5802, GPU 2730 (MiB)
[06/02/2022-03:19:43] [TRT] [I] [MemUsageChange] Init cuDNN: CPU +116, GPU +52, now: CPU 5918, GPU 2782 (MiB)
[06/02/2022-03:19:43] [TRT] [I] Timing cache disabled. Turning it on will improve builder speed.
[06/02/2022-03:19:43] [TRT] [W] Myelin graph with multiple dynamic values may have poor performance if they differ. Dynamic values are: 
[06/02/2022-03:19:43] [TRT] [W]  (# 1 (SHAPE input_ids))
[06/02/2022-03:19:43] [TRT] [W]  (# 0 (SHAPE attention_mask))
Warning: Slice op (Unnamed Layer_ 90) [Slice]_slice cannot slice along a uniform dimension.
Warning: Slice op (Unnamed Layer_ 90) [Slice]_slice cannot slice along a uniform dimension.
Warning: Slice op (Unnamed Layer_ 90) [Slice]_slice cannot slice along a uniform dimension.
Warning: Slice op (Unnamed Layer_ 90) [Slice]_slice cannot slice along a uniform dimension.
Warning: Slice op (Unnamed Layer_ 90) [Slice]_slice cannot slice along a uniform dimension.
Warning: Slice op (Unnamed Layer_ 90) [Slice]_slice cannot slice along a uniform dimension.
Warning: Slice op (Unnamed Layer_ 90) [Slice]_slice cannot slice along a uniform dimension.
[06/02/2022-03:25:32] [TRT] [W] Myelin graph with multiple dynamic values may have poor performance if they differ. Dynamic values are: 
[06/02/2022-03:25:32] [TRT] [W]  (# 1 (SHAPE input_ids))
[06/02/2022-03:25:32] [TRT] [W]  (# 0 (SHAPE attention_mask))
Warning: Slice op (Unnamed Layer_ 90) [Slice]_slice cannot slice along a uniform dimension.
Warning: Slice op (Unnamed Layer_ 90) [Slice]_slice cannot slice along a uniform dimension.
Warning: Slice op (Unnamed Layer_ 90) [Slice]_slice cannot slice along a uniform dimension.
Warning: Slice op (Unnamed Layer_ 90) [Slice]_slice cannot slice along a uniform dimension.
Warning: Slice op (Unnamed Layer_ 90) [Slice]_slice cannot slice along a uniform dimension.
Warning: Slice op (Unnamed Layer_ 90) [Slice]_slice cannot slice along a uniform dimension.
Warning: Slice op (Unnamed Layer_ 90) [Slice]_slice cannot slice along a uniform dimension.
[06/02/2022-03:30:10] [TRT] [I] Detected 2 inputs and 1 output network tensors.
[06/02/2022-03:30:10] [TRT] [W] Myelin graph with multiple dynamic values may have poor performance if they differ. Dynamic values are: 
[06/02/2022-03:30:10] [TRT] [W]  (# 1 (SHAPE input_ids))
[06/02/2022-03:30:10] [TRT] [W]  (# 0 (SHAPE attention_mask))
[06/02/2022-03:30:32] [TRT] [I] Total Host Persistent Memory: 208
[06/02/2022-03:30:32] [TRT] [I] Total Device Persistent Memory: 0
[06/02/2022-03:30:32] [TRT] [I] Total Scratch Memory: 442827264
[06/02/2022-03:30:32] [TRT] [I] [MemUsageStats] Peak memory usage of TRT CPU/GPU memory allocators: CPU 774 MiB, GPU 2058 MiB
[06/02/2022-03:30:32] [TRT] [I] [BlockAssignment] Algorithm ShiftNTopDown took 0.038945ms to assign 4 blocks to 4 nodes requiring 443041280 bytes.
[06/02/2022-03:30:32] [TRT] [I] Total Activation Memory: 443041280
[06/02/2022-03:30:32] [TRT] [I] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +8, now: CPU 5993, GPU 4298 (MiB)
[06/02/2022-03:30:32] [TRT] [I] [MemUsageChange] Init cuDNN: CPU +0, GPU +8, now: CPU 5993, GPU 4306 (MiB)
[06/02/2022-03:30:32] [TRT] [I] [MemUsageChange] TensorRT-managed allocation in building engine: CPU +0, GPU +1353, now: CPU 0, GPU 1353 (MiB)
[06/02/2022-03:30:33] [TRT] [I] Loaded engine size: 1364 MiB
[06/02/2022-03:30:33] [TRT] [I] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +10, now: CPU 7354, GPU 4282 (MiB)
[06/02/2022-03:30:33] [TRT] [I] [MemUsageChange] Init cuDNN: CPU +1, GPU +8, now: CPU 7355, GPU 4290 (MiB)
[06/02/2022-03:30:33] [TRT] [I] [MemUsageChange] TensorRT-managed allocation in engine deserialization: CPU +0, GPU +1352, now: CPU 0, GPU 1352 (MiB)
[06/02/2022-03:30:38] [TRT] [I] Loaded engine size: 1364 MiB
[06/02/2022-03:30:38] [TRT] [I] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +10, now: CPU 7366, GPU 5636 (MiB)
[06/02/2022-03:30:38] [TRT] [I] [MemUsageChange] Init cuDNN: CPU +1, GPU +8, now: CPU 7367, GPU 5644 (MiB)
[06/02/2022-03:30:38] [TRT] [I] [MemUsageChange] TensorRT-managed allocation in engine deserialization: CPU +0, GPU +1352, now: CPU 0, GPU 2704 (MiB)
[06/02/2022-03:30:38] [TRT] [I] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +10, now: CPU 6002, GPU 5636 (MiB)
[06/02/2022-03:30:38] [TRT] [I] [MemUsageChange] Init cuDNN: CPU +0, GPU +8, now: CPU 6002, GPU 5644 (MiB)
[06/02/2022-03:30:43] [TRT] [I] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +423, now: CPU 0, GPU 3127 (MiB)

latencies in ms
--------------------------------------------------
Pytorch 
--------------------------------------------------
[93.5968, 94.0308, 94.8224, 93.6746, 94.5972, 94.0188, 92.3105, 93.6535, 92.4908, 91.4413]
--------------------------------------------------
Onnxruntime 
 --------------------------------------------------
[81.445, 81.3684, 80.2145, 81.5339, 82.9578, 83.6845, 83.6738, 82.6652, 81.5462, 82.8237]
--------------------------------------------------
TensorRT (FP16) 
 --------------------------------------------------
[426.353, 425.1992, 426.0317, 425.8226, 426.8828, 428.0485, 426.3119, 426.4556, 425.4863, 426.0393]
--------------------------------------------------

Is this the expected behavior?

I want to convert the model with dynamic batching: at inference time, it should handle a variable batch size and still be fast. How can I achieve that?

Any help would be greatly appreciated, thank you in advance.

pommedeterresautee commented 2 years ago

TRT being 4 times slower than Pytorch is not an expected result. What GPU / CUDA / TRT versions are you using? Are you using the Docker version? What is the value of the quantization variable?

When I am surprised by TRT results, one thing I do is run the trtexec utility from the command line. It runs a benchmark at the end of the process and helps to check whether the problem comes from the Python code or from TRT itself. Unfortunately, its output is of no use for accuracy checks, as some nodes need to be kept in FP32 precision to avoid NaN outputs. Also, very recent versions of TRT (>=8.2) have some regressions (at least on GPT-2), so it may also be because of that... 8.4.1 should be released soon according to an Nvidia engineer.

Ki6an commented 2 years ago

Thank you for the quick response.

I'm using the Triton server Docker image, version 22.04, as the base image.

Setup info:
GPU: T4 (15 GB)
CUDA: 11.6
TRT: 8.2.4.2
Model: KoichiYasuoka/roberta-large-english-upos
Quantization: False (currently focusing on the FP16 version)

I also tried with nvidia-tensorrt==8.4 and still face the same issue.

I also built the transformer-deploy container from source: with a constant batch size (1, 1, 1), TRT is faster than ORT and Torch, but when running with a dynamic batch size (1, 4, 4) the process gets terminated with "Killed".

Did you ever see the following warning while converting a transformer model to TRT? If yes, how did the model perform at inference time?

[06/02/2022-03:25:32] [TRT] [W] Myelin graph with multiple dynamic values may have poor performance if they differ. Dynamic values are: 
[06/02/2022-03:25:32] [TRT] [W]  (# 1 (SHAPE input_ids))
[06/02/2022-03:25:32] [TRT] [W]  (# 0 (SHAPE attention_mask))

Ki6an commented 2 years ago

I ran the trtexec command inside the TensorRT 22.04 container:

trtexec --onnx=onnx/model.onnx --saveEngine=tensorrt/model.plan \
  --minShapes=input_ids:1x1,attention_mask:1x1  \
  --optShapes=input_ids:4x128,attention_mask:4x128 \
  --maxShapes=input_ids:4x128,attention_mask:4x128 \
  --workspace=10000 \
  --fp16 \
  --verbose 

and I get the same results... the accuracy is also not good :/

[06/02/2022-15:54:08] [I] === Performance summary ===
[06/02/2022-15:54:08] [I] Throughput: 2.28208 qps
[06/02/2022-15:54:08] [I] Latency: min = 430.552 ms, max = 450.252 ms, mean = 438.171 ms, median = 437.374 ms, percentile(99%) = 450.252 ms
[06/02/2022-15:54:08] [I] End-to-End Host Latency: min = 430.575 ms, max = 450.27 ms, mean = 438.188 ms, median = 437.391 ms, percentile(99%) = 450.27 ms
[06/02/2022-15:54:08] [I] Enqueue Time: min = 430.497 ms, max = 450.193 ms, mean = 438.105 ms, median = 437.305 ms, percentile(99%) = 450.193 ms
[06/02/2022-15:54:08] [I] H2D Latency: min = 0.00878906 ms, max = 0.020874 ms, mean = 0.0117126 ms, median = 0.010437 ms, percentile(99%) = 0.020874 ms
[06/02/2022-15:54:08] [I] GPU Compute Time: min = 430.491 ms, max = 450.203 ms, mean = 438.119 ms, median = 437.324 ms, percentile(99%) = 450.203 ms
[06/02/2022-15:54:08] [I] D2H Latency: min = 0.0397949 ms, max = 0.0415039 ms, mean = 0.0405151 ms, median = 0.0405273 ms, percentile(99%) = 0.0415039 ms

pommedeterresautee commented 2 years ago

Thank you, the error message is interesting. I will dig into this on my side.

pommedeterresautee commented 2 years ago

I have been able to reproduce your results. The warning about dynamic values means that for best performance you should only use one dynamic axis (batch or seq len). https://forums.developer.nvidia.com/t/myelin-graph-error-when-converting-to-trt-engine-inference/202019/6
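
For illustration, a sketch of what "only one dynamic axis" could look like with the build_engine call from the first post (shapes are (batch, seq_len) tuples; pick one of the two profiles, the exact values are just examples):

# option A: fixed batch of 4, dynamic sequence length 1-128
tensor_shapes = list(zip([4, 4, 4], [1, 128, 128]))  # min=(4, 1), opt=(4, 128), max=(4, 128)

# option B: dynamic batch 1-4, fixed sequence length of 128
tensor_shapes = list(zip([1, 4, 4], [128, 128, 128]))  # min=(1, 128), opt=(4, 128), max=(4, 128)

engine = build_engine(
    runtime=runtime,
    onnx_file_path=onnx_model_path,
    logger=trt_logger,
    min_shape=tensor_shapes[0],
    optimal_shape=tensor_shapes[1],
    max_shape=tensor_shapes[2],
    workspace_size=workspace_size * 1024**3,
    fp16=not quantization,
    int8=quantization,
)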

I tried keeping the batch size fixed and the results are much better :-)

❯ convert_model -m KoichiYasuoka/roberta-large-english-upos  --backend tensorrt --seq-len 1 128 128 --batch 4 4 4 --task token-classification --verbose -w 20000 --fast
06/03/2022 17:40:23 INFO     running with commands: Namespace(model='KoichiYasuoka/roberta-large-english-upos', tokenizer=None, task='token-classification', auth_token=None, batch_size=[4, 4, 4], seq_len=[1, 128, 128], quantization=False, workspace_size=20000, output='triton_models', name='transformer', verbose=True, fast=True, backend=['tensorrt'], device=None, nb_threads=1, nb_instances=1, warmup=10, nb_measures=1000, seed=123, atol=0.3)
06/03/2022 17:40:31 INFO     axis: ['input_ids', 'attention_mask']
06/03/2022 17:40:52 INFO     running Pytorch (FP32) benchmark
06/03/2022 17:41:11 INFO     cleaning up
06/03/2022 17:41:11 INFO     preparing TensorRT (FP16) benchmark
[06/03/2022-17:41:11] [TRT] [I] [MemUsageChange] Init CUDA: CPU +448, GPU +0, now: CPU 5671, GPU 2671 (MiB)
[06/03/2022-17:41:11] [TRT] [I] [MemUsageChange] Init CUDA: CPU +0, GPU +0, now: CPU 5671, GPU 2671 (MiB)
[06/03/2022-17:41:11] [TRT] [I] [MemUsageSnapshot] Begin constructing builder kernel library: CPU 5671 MiB, GPU 2671 MiB
[06/03/2022-17:41:11] [TRT] [I] [MemUsageSnapshot] End constructing builder kernel library: CPU 5825 MiB, GPU 2713 MiB
[06/03/2022-17:41:11] [TRT] [I] parsing TensorRT model
[libprotobuf WARNING google/protobuf/io/coded_stream.cc:604] Reading dangerously large protocol message.  If the message turns out to be larger than 2147483647 bytes, parsing will be halted for security reasons.  To increase the limit (or to disable these warnings), see CodedInputStream::SetTotalBytesLimit() in google/protobuf/io/coded_stream.h.
[libprotobuf WARNING google/protobuf/io/coded_stream.cc:81] The total number of bytes read was 1418322093
[06/03/2022-17:41:13] [TRT] [W] onnx2trt_utils.cpp:366: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
[06/03/2022-17:41:26] [TRT] [W] Output type must be INT32 for shape outputs
[06/03/2022-17:41:26] [TRT] [W] Output type must be INT32 for shape outputs
[06/03/2022-17:41:26] [TRT] [W] Output type must be INT32 for shape outputs
[06/03/2022-17:41:26] [TRT] [W] building engine. depending on model size this may take a while
[06/03/2022-17:41:29] [TRT] [I] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +838, GPU +362, now: CPU 8029, GPU 3086 (MiB)
[06/03/2022-17:41:29] [TRT] [I] [MemUsageChange] Init cuDNN: CPU +127, GPU +60, now: CPU 8156, GPU 3146 (MiB)
[06/03/2022-17:41:29] [TRT] [I] Timing cache disabled. Turning it on will improve builder speed.
Warning: Slice op (Unnamed Layer_ 91) [Slice]_slice cannot slice along a uniform dimension.
Warning: Slice op (Unnamed Layer_ 91) [Slice]_slice cannot slice along a uniform dimension.
Warning: Slice op (Unnamed Layer_ 91) [Slice]_slice cannot slice along a uniform dimension.
Warning: Slice op (Unnamed Layer_ 91) [Slice]_slice cannot slice along a uniform dimension.
Warning: Slice op (Unnamed Layer_ 91) [Slice]_slice cannot slice along a uniform dimension.
Warning: Slice op (Unnamed Layer_ 91) [Slice]_slice cannot slice along a uniform dimension.
Warning: Slice op (Unnamed Layer_ 91) [Slice]_slice cannot slice along a uniform dimension.
Warning: Slice op (Unnamed Layer_ 91) [Slice]_slice cannot slice along a uniform dimension.
Warning: Slice op (Unnamed Layer_ 91) [Slice]_slice cannot slice along a uniform dimension.
Warning: Slice op (Unnamed Layer_ 91) [Slice]_slice cannot slice along a uniform dimension.
Warning: Slice op (Unnamed Layer_ 91) [Slice]_slice cannot slice along a uniform dimension.
Warning: Slice op (Unnamed Layer_ 91) [Slice]_slice cannot slice along a uniform dimension.
Warning: Slice op (Unnamed Layer_ 91) [Slice]_slice cannot slice along a uniform dimension.
Warning: Slice op (Unnamed Layer_ 91) [Slice]_slice cannot slice along a uniform dimension.
Warning: Slice op (Unnamed Layer_ 91) [Slice]_slice cannot slice along a uniform dimension.
Warning: Slice op (Unnamed Layer_ 91) [Slice]_slice cannot slice along a uniform dimension.
[06/03/2022-17:45:37] [TRT] [I] Detected 2 inputs and 1 output network tensors.
[06/03/2022-17:45:59] [TRT] [I] Total Host Persistent Memory: 240
[06/03/2022-17:45:59] [TRT] [I] Total Device Persistent Memory: 0
[06/03/2022-17:45:59] [TRT] [I] Total Scratch Memory: 151762432
[06/03/2022-17:45:59] [TRT] [I] [MemUsageStats] Peak memory usage of TRT CPU/GPU memory allocators: CPU 774 MiB, GPU 1643 MiB
[06/03/2022-17:45:59] [TRT] [I] [BlockAssignment] Algorithm ShiftNTopDown took 0.037546ms to assign 3 blocks to 3 nodes requiring 151975936 bytes.
[06/03/2022-17:45:59] [TRT] [I] Total Activation Memory: 151975936
[06/03/2022-17:45:59] [TRT] [I] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +8, now: CPU 8170, GPU 3883 (MiB)
[06/03/2022-17:45:59] [TRT] [I] [MemUsageChange] Init cuDNN: CPU +0, GPU +8, now: CPU 8170, GPU 3891 (MiB)
[06/03/2022-17:45:59] [TRT] [I] [MemUsageChange] TensorRT-managed allocation in building engine: CPU +0, GPU +677, now: CPU 0, GPU 677 (MiB)
[06/03/2022-17:45:59] [TRT] [I] Loaded engine size: 689 MiB
[06/03/2022-17:46:00] [TRT] [I] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +8, now: CPU 8857, GPU 3867 (MiB)
[06/03/2022-17:46:00] [TRT] [I] [MemUsageChange] Init cuDNN: CPU +0, GPU +8, now: CPU 8857, GPU 3875 (MiB)
[06/03/2022-17:46:00] [TRT] [I] [MemUsageChange] TensorRT-managed allocation in engine deserialization: CPU +0, GPU +676, now: CPU 0, GPU 676 (MiB)
[06/03/2022-17:46:00] [TRT] [W] building engine took 273.7 seconds
[06/03/2022-17:46:02] [TRT] [I] Loaded engine size: 689 MiB
[06/03/2022-17:46:02] [TRT] [I] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +8, now: CPU 8870, GPU 4545 (MiB)
[06/03/2022-17:46:02] [TRT] [I] [MemUsageChange] Init cuDNN: CPU +0, GPU +8, now: CPU 8870, GPU 4553 (MiB)
[06/03/2022-17:46:02] [TRT] [I] [MemUsageChange] TensorRT-managed allocation in engine deserialization: CPU +0, GPU +676, now: CPU 0, GPU 1352 (MiB)
[06/03/2022-17:46:02] [TRT] [I] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +8, now: CPU 8181, GPU 4545 (MiB)
[06/03/2022-17:46:02] [TRT] [I] [MemUsageChange] Init cuDNN: CPU +0, GPU +8, now: CPU 8181, GPU 4553 (MiB)
[06/03/2022-17:46:06] [TRT] [I] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +145, now: CPU 0, GPU 1497 (MiB)
06/03/2022 17:46:06 INFO     running TensorRT (FP16) benchmark
Inference done on NVIDIA GeForce RTX 3090
latencies:
[Pytorch (FP32)] mean=17.37ms, sd=0.81ms, min=16.50ms, max=24.00ms, median=17.20ms, 95p=18.49ms, 99p=20.91ms
[TensorRT (FP16)] mean=7.96ms, sd=0.15ms, min=7.82ms, max=9.32ms, median=7.94ms, 95p=8.15ms, 99p=8.65ms
Each inference engine output is within 0.3 tolerance compared to Pytorch output

I don't know your use case, but we had something similar, and in our own use case we added fake data to make the axis fixed... It only works if most of your batches are full.

I had already noticed this behavior, but never in these proportions.


For the record, when batch and seq len are both dynamic:

Inference done on NVIDIA GeForce RTX 3090
latencies:
[Pytorch (FP32)] mean=17.03ms, sd=0.85ms, min=16.51ms, max=31.24ms, median=16.81ms, 95p=17.88ms, 99p=21.03ms
[Pytorch (FP16)] mean=17.24ms, sd=0.45ms, min=16.83ms, max=20.76ms, median=17.08ms, 95p=18.24ms, 99p=18.82ms
[TensorRT (FP16)] mean=179.61ms, sd=5.90ms, min=169.19ms, max=202.33ms, median=178.15ms, 95p=192.51ms, 99p=197.33ms

Ki6an commented 2 years ago

Thank you for looking into this, really appreciate it.

for best performance you should only use one dynamic axis (batch or seq-len).

Does this only apply to this model, or does it apply to all transformer models?

I don't know your use case, but we had something similar, and in our own use case we added fake data to make the axis fixed... It only works if most of your batches are full.

We are doing the same: we converted the model with a fixed batch size of (4, 4, 4), so it now accepts exactly 4 sequences per batch. If the input batch has fewer than 4 items, we fill the remaining slots with empty strings to make it full.
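
A minimal sketch of that padding trick, assuming the engine was built with a fixed batch size of 4, that tokenizer is a Hugging Face tokenizer loaded elsewhere, and that run_engine is a hypothetical wrapper around the TensorRT call:

ENGINE_BATCH_SIZE = 4

def predict(texts):
    n = len(texts)
    assert 0 < n <= ENGINE_BATCH_SIZE
    # pad the batch with empty strings so the engine always sees 4 sequences
    padded_texts = texts + [""] * (ENGINE_BATCH_SIZE - n)
    inputs = tokenizer(padded_texts, padding=True, truncation=True, max_length=128, return_tensors="pt")
    outputs = run_engine(inputs["input_ids"].int().cuda(), inputs["attention_mask"].int().cuda())
    # drop the predictions that correspond to the fake rows
    return outputs[:n]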

pommedeterresautee commented 2 years ago

My understanding is that the Myelin engine is used for all transformer models (I have never seen any other engine mentioned in the TRT logs).

sam-writer commented 2 years ago

@pommedeterresautee do you know if it is better to pad sequences or batches if both are variable?

pommedeterresautee commented 2 years ago

It depends on your use case and definitely requires a real benchmark. In my own use case, I get something like 10 batches of documents, with 9 of the same size and 1 a bit smaller. Adding fake docs to the last one costs me very little compared to the boost offered by TRT. On the other hand, my docs range from 100 to 400 tokens, so padding the sequence length would cost me a lot.

If you are working on very short documents (1-30 tokens), I would say that a fixed seq len is a viable option, but again, it requires real measurements, depends on your GPU, etc.
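
For the other direction (fixed sequence length, dynamic batch), the padding can be pushed into the tokenizer call; a small sketch, assuming a Hugging Face tokenizer and a 128-token profile:

# every batch gets the same sequence length, so only the batch axis remains dynamic
inputs = tokenizer(texts, padding="max_length", max_length=128, truncation=True, return_tensors="pt")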

sam-writer commented 2 years ago

Thank you for the detailed answer

Ki6an commented 2 years ago

Upgrading to nvidia-tensorrt==8.4.1.5 solves this issue.
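
(For reference, at the time this typically meant pip install nvidia-pyindex followed by pip install nvidia-tensorrt==8.4.1.5, pulling the wheel from NVIDIA's package index; adapt to your own environment.)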

pommedeterresautee commented 2 years ago

Thank you for the info