NVIDIA / TensorRT-LLM

TensorRT-LLM provides users with an easy-to-use Python API to define Large Language Models (LLMs) and build TensorRT engines that contain state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs. TensorRT-LLM also contains components to create Python and C++ runtimes that execute those TensorRT engines.
https://nvidia.github.io/TensorRT-LLM
Apache License 2.0

Failed to build TRT-LLM engine for quantized INT8 BERT model #1614

Closed · hibagus closed this 4 months ago

hibagus commented 4 months ago

Hi there,

I am trying to build a custom configuration of BERT for performance measurement only. Following the README for GPT, I am using generate_checkpoint_config.py to generate the BERT config with the following command:

python3 generate_checkpoint_config.py \
    --output_path  bert-int8.json \
    --architecture BertModel \
    --dtype float16 \
    --vocab_size 30522 \
    --max_position_embeddings 768 \
    --hidden_size 768 \
    --intermediate_size 3072 \
    --num_hidden_layers 1 \
    --num_attention_heads 12 \
    --hidden_act gelu \
    --tp_size 1 \
    --pp_size 1 \
    --quant_algo W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN \
    --kv_cache_quant_algo INT8
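
As a sanity check before building (a sketch, not part of TensorRT-LLM; the "architecture" and "quantization" field names assume the 0.9.0 unified-checkpoint layout), the generated JSON can be inspected like this:

import json

# Inspect the config emitted by generate_checkpoint_config.py.
# Field names assume the 0.9.0 unified-checkpoint layout and may
# differ across releases.
with open("bert-int8.json") as f:
    cfg = json.load(f)

print(cfg.get("architecture"))   # expected: "BertModel"
print(cfg.get("quantization"))   # expected: the quant algorithms above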

Then, I run the trtllm-build command as follows:

trtllm-build --model_config bert-int8.json \
             --remove_input_padding enable \
             --strongly_typed \
             --workers 1 \
             --max_batch_size 1 \
             --max_input_len 384 \
             --tp_size 1 \
             --pp_size 1 \
             --output_dir bert-int8-trtlm

It throws an error:

RuntimeError: Unsupported model architecture: BertModel

Looking at __init__.py, the MODEL_MAP does not include BERT.
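
For reference, a quick sketch (against tensorrt-llm 0.9.0, where MODEL_MAP lives in tensorrt_llm/models/__init__.py) that lists the architectures trtllm-build accepts:

import tensorrt_llm.models as models

# MODEL_MAP maps the "architecture" string from the checkpoint config
# to a model class; trtllm-build raises the error above for anything
# not listed.
print(sorted(models.MODEL_MAP.keys()))
print("BertModel" in models.MODEL_MAP)  # False on 0.9.0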

Is BERT not officially supported? Do you have any suggestions on how to proceed?

Thanks!

Below are the package versions I am using, for reference.

Package                  Version
------------------------ -----------------
accelerate               0.27.2
aiohttp                  3.9.5
aiosignal                1.3.1
annotated-types          0.6.0
async-timeout            4.0.3
attrs                    23.2.0
build                    1.2.1
certifi                  2024.2.2
charset-normalizer       3.3.2
cloudpickle              3.0.0
colored                  2.2.4
coloredlogs              15.0.1
cuda-python              12.4.0
datasets                 2.19.1
diffusers                0.15.0
dill                     0.3.8
evaluate                 0.4.2
filelock                 3.14.0
flatbuffers              24.3.25
frozenlist               1.4.1
fsspec                   2024.3.1
h5py                     3.10.0
huggingface-hub          0.23.0
humanfriendly            10.0
idna                     3.7
importlib_metadata       7.1.0
janus                    1.0.0
Jinja2                   3.1.4
joblib                   1.4.2
lark                     1.1.9
markdown-it-py           3.0.0
MarkupSafe               2.1.5
mdurl                    0.1.2
mpi4py                   3.1.4
mpmath                   1.3.0
multidict                6.0.5
multiprocess             0.70.16
networkx                 3.3
ninja                    1.11.1.1
numpy                    1.26.4
nvidia-ammo              0.7.4
nvidia-cublas-cu12       12.1.3.1
nvidia-cuda-cupti-cu12   12.1.105
nvidia-cuda-nvrtc-cu12   12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12        8.9.2.26
nvidia-cufft-cu12        11.0.2.54
nvidia-curand-cu12       10.3.2.106
nvidia-cusolver-cu12     11.4.5.107
nvidia-cusparse-cu12     12.1.0.106
nvidia-modelopt          0.11.2
nvidia-nccl-cu12         2.19.3
nvidia-nvjitlink-cu12    12.4.127
nvidia-nvtx-cu12         12.1.105
onnx                     1.16.0
onnx-graphsurgeon        0.5.2
onnxruntime              1.16.3
optimum                  1.19.2
packaging                24.0
pandas                   2.2.2
pillow                   10.3.0
pip                      24.0
polygraphy               0.49.9
protobuf                 5.26.1
psutil                   5.9.8
PuLP                     2.8.0
pyarrow                  16.1.0
pyarrow-hotfix           0.6
pydantic                 2.7.1
pydantic_core            2.18.2
Pygments                 2.18.0
pynvml                   11.5.0
pyproject_hooks          1.1.0
python-dateutil          2.9.0.post0
pytz                     2024.1
PyYAML                   6.0.1
regex                    2024.5.10
requests                 2.31.0
rich                     13.7.1
safetensors              0.4.3
scikit-learn             1.4.2
scipy                    1.13.0
sentencepiece            0.2.0
setuptools               69.5.1
six                      1.16.0
StrEnum                  0.4.15
sympy                    1.12
tensorrt                 9.3.0.post12.dev1
tensorrt-bindings        9.3.0.post12.dev1
tensorrt-libs            9.3.0.post12.dev1
tensorrt-llm             0.9.0
threadpoolctl            3.5.0
tokenizers               0.15.2
tomli                    2.0.1
torch                    2.2.2
torchinfo                1.8.0
torchsummary             1.5.1
tqdm                     4.66.4
transformers             4.38.2
triton                   2.2.0
typing_extensions        4.11.0
tzdata                   2024.1
urllib3                  2.2.1
wheel                    0.43.0
xxhash                   3.4.1
yarl                     1.9.4
zipp                     3.18.1
byshiue commented 4 months ago

We don't support INT8 BERT in TensorRT-LLM because it is already supported directly in TensorRT.
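
For context, a minimal sketch of the "directly in TensorRT" route using the TensorRT Python API. The ONNX path bert.onnx is hypothetical, and a real INT8 build additionally needs either an int8 calibrator or a Q/DQ-quantized ONNX graph (plus an optimization profile if the export uses dynamic shapes):

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)

# Parse a BERT ONNX export (the path is hypothetical).
with open("bert.onnx", "rb") as f:
    if not parser.parse(f.read()):
        for i in range(parser.num_errors):
            print(parser.get_error(i))
        raise RuntimeError("failed to parse bert.onnx")

config = builder.create_builder_config()
config.set_flag(trt.BuilderFlag.INT8)  # enable INT8 tactics
config.set_flag(trt.BuilderFlag.FP16)  # FP16 fallback where INT8 kernels are missing

# Without Q/DQ nodes in the graph, set config.int8_calibrator here.
engine = builder.build_serialized_network(network, config)
with open("bert-int8.engine", "wb") as f:
    f.write(engine)

TensorRT's demo/BERT sample in the TensorRT OSS repository walks through the full INT8 path, including calibration.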